├── paper.pdf
├── code
│   ├── bash
│   │   ├── aste_14res.sh
│   │   ├── aste_15res.sh
│   │   ├── aste_16res.sh
│   │   ├── aste_14lap.sh
│   │   └── aste.sh
│   ├── model
│   │   ├── table.py
│   │   ├── table_encoder
│   │   │   └── resnet.py
│   │   ├── bdtf_model.py
│   │   ├── matching_layer.py
│   │   └── seq2mat.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── aste_result.py
│   │   └── aste_datamodule.py
│   └── aste_train.py
├── README.md
└── README_EN.md
/paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HITSZ-HLT/BDTF-ABSA/HEAD/paper.pdf
--------------------------------------------------------------------------------
/code/bash/aste_14res.sh:
--------------------------------------------------------------------------------
1 | bash bash/aste.sh -s 50 -c 0 -d V2/14res -l 5 -t resnet -n 2 -z 0.3 -e tensorcontext -D 64
--------------------------------------------------------------------------------
/code/bash/aste_15res.sh:
--------------------------------------------------------------------------------
1 | bash bash/aste.sh -s 50 -c 0 -d V2/15res -l 4 -t resnet -n 2 -z 0.3 -e tensorcontext -D 64
--------------------------------------------------------------------------------
/code/bash/aste_16res.sh:
--------------------------------------------------------------------------------
1 | bash bash/aste.sh -s 50 -c 0 -d V2/16res -l 5 -t resnet -n 2 -z 0.3 -e tensorcontext -D 64
--------------------------------------------------------------------------------
/code/bash/aste_14lap.sh:
--------------------------------------------------------------------------------
1 | bash bash/aste.sh -s 40 -c 0 -d V2/14lap -l 3 -t resnet -n 2 -z 0.3 -e tensorcontext -D 64
2 |
--------------------------------------------------------------------------------
/code/model/table.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .seq2mat import *
4 | from .table_encoder.resnet import ResNet
5 |
6 | class TableEncoder(nn.Module):
7 | def __init__(self, config):
8 | super().__init__()
9 |
10 | self.config = config
11 | if config.seq2mat == 'tensor':
12 | self.seq2mat = TensorSeq2Mat(config)
13 | elif config.seq2mat == 'tensorcontext':
14 | self.seq2mat = TensorcontextSeq2Mat(config)
15 | elif config.seq2mat == 'context':
16 | self.seq2mat = ContextSeq2Mat(config)
17 | else:
18 | self.seq2mat = Seq2Mat(config)
19 |
20 | if config.table_encoder != 'none':
21 | self.layer = nn.ModuleList([ResNet(config) for _ in range(config.num_table_layers)])
22 |
23 | def forward(self, seq, mask):
24 | table = self.seq2mat(seq, seq)
25 |
26 | if self.config.table_encoder == 'none':
27 | return table
28 |
29 | for layer_module in self.layer:
30 | table = layer_module(table)
31 |
32 | return table
33 |
--------------------------------------------------------------------------------
/code/bash/aste.sh:
--------------------------------------------------------------------------------
1 | while getopts ':d:s:c:t:n:l:z:e:D:' opt
2 | do
3 | case $opt in
4 | d)
5 | dataset="$OPTARG" ;;
6 | s)
7 | seed="$OPTARG" ;;
8 | c)
9 | CUDA_IDS="$OPTARG" ;;
10 | t)
11 | table_encoder="$OPTARG" ;;
12 | n)
13 | num_table_layers="$OPTARG" ;;
14 | l)
15 | learning_rate="$OPTARG" ;;
16 | z)
17 | span_pruning="$OPTARG" ;;
18 | e)
19 | seq2mat="$OPTARG" ;;
20 | D)
21 | num_d="$OPTARG" ;;
22 | ?)
23 | exit 1;;
24 | esac
25 | done
26 |
27 |
28 | if [ ! "${table_encoder}" ]
29 | then
30 | table_encoder=resnet
31 | fi
32 |
33 |
34 | if [ ! "${num_table_layers}" ]
35 | then
36 | num_table_layers=2
37 | fi
38 |
39 |
40 | gradient_clip_val=1
41 | warmup_steps=100
42 | weight_decay=0.01
43 | precision=16
44 | batch_size=4
45 | data_dir="../data/aste_data_bert/${dataset}"
46 |
47 |
48 | CUDA_VISIBLE_DEVICES=${CUDA_IDS} python3 aste_train.py \
49 | --gpus=1 \
50 | --precision=${precision} \
51 | --data_dir ${data_dir} \
52 | --model_name_or_path 'bert-base-uncased' \
53 | --output_dir ../output/ASTE/${dataset}/ \
54 | --learning_rate ${learning_rate}e-5 \
55 | --train_batch_size ${batch_size} \
56 | --eval_batch_size ${batch_size} \
57 | --seed $seed \
58 | --warmup_steps ${warmup_steps} \
59 | --lr_scheduler linear \
60 | --gradient_clip_val ${gradient_clip_val} \
61 | --weight_decay ${weight_decay} \
62 | --max_seq_length -1 \
63 | --max_epochs 10 \
64 | --cuda_ids ${CUDA_IDS} \
65 | --do_train \
66 | --table_encoder ${table_encoder} \
67 | --num_table_layers ${num_table_layers} \
68 | --span_pruning ${span_pruning} \
69 | --seq2mat ${seq2mat} \
70 | --num_d ${num_d}
--------------------------------------------------------------------------------
/code/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import os
4 |
5 |
6 |
7 | class NpEncoder(json.JSONEncoder):
8 | def default(self, obj):
9 | if isinstance(obj, np.integer):
10 | return int(obj)
11 | elif isinstance(obj, np.floating):
12 | return float(obj)
13 | elif isinstance(obj, np.ndarray):
14 | return obj.tolist()
15 | else:
16 | return super(NpEncoder, self).default(obj)
17 |
18 |
19 | def load_json(file_name):
20 | with open(file_name, mode='r', encoding='utf-8-sig') as f:
21 | return json.load(f)
22 |
23 |
24 | def append_json(file_name, obj, mode='a'):
25 | mkdir_if_not_exist(file_name)
26 | with open(file_name, mode=mode, encoding='utf-8') as f:
27 | if type(obj) is dict:
28 | string = json.dumps(obj)
29 | elif type(obj) is list:
30 | string = ' '.join([str(item) for item in obj])
31 | elif type(obj) is str:
32 | string = obj
33 | else:
34 | raise Exception()
35 |
36 | string = string + '\n'
37 | f.write(string)
38 |
39 |
40 | def mkdir_if_not_exist(path):
41 | dir_name, file_name = os.path.split(path)
42 | if not os.path.exists(dir_name):
43 | os.makedirs(dir_name)
44 |
45 |
46 | def save_json(json_obj, file_name):
47 | mkdir_if_not_exist(file_name)
48 | with open(file_name, mode='w+', encoding='utf-8-sig') as f:
49 | json.dump(json_obj, f, indent=4, cls=NpEncoder)
50 |
51 |
52 |
53 | def params_count(model):
54 | """
55 | Compute the number of parameters.
56 | Args:
57 | model (model): model to count the number of parameters.
58 | """
59 | return np.sum([p.numel() for p in model.parameters()]).item()
60 |
61 |
62 | 
63 | def yield_data_file(data_dir):
64 |     for file_name in os.listdir(data_dir):
65 |         yield os.path.join(data_dir, file_name)
66 | 
--------------------------------------------------------------------------------
/code/model/table_encoder/resnet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | from transformers.models.t5.modeling_t5 import T5LayerNorm
4 | from einops import rearrange
5 |
6 | class ResNet(nn.Module):
7 | def __init__(self, config):
8 | super().__init__()
9 | if config.model_type == 'bart':
10 | layer_norm_eps = 1e-12
11 | elif config.model_type == 't5':
12 | layer_norm_eps = 1e-12
13 | else:
14 | layer_norm_eps = config.layer_norm_eps
15 | self.conv1 = nn.Conv2d(
16 | in_channels =config.hidden_size,
17 | out_channels=config.hidden_size,
18 | kernel_size =(1, 1),
19 | padding=0
20 | )
21 | self.norm1 = T5LayerNorm(config.hidden_size, layer_norm_eps)
22 |
23 | self.conv2 = nn.Conv2d(
24 | in_channels =config.hidden_size,
25 | out_channels=config.hidden_size,
26 | kernel_size =(3, 3),
27 | padding=1
28 | )
29 | self.norm2 = T5LayerNorm(config.hidden_size, layer_norm_eps)
30 |
31 | self.conv3 = nn.Conv2d(
32 | in_channels =config.hidden_size,
33 | out_channels=config.hidden_size,
34 | kernel_size =(1, 1),
35 | padding=0
36 | )
37 | self.norm3 = T5LayerNorm(config.hidden_size, layer_norm_eps)
38 |
39 | def layer_forward(self, x, conv, norm):
40 | x = conv(x)
41 | n = x.size(-1)
42 | x = rearrange(x, 'b d m n -> b (m n) d')
43 | x = norm(x)
44 | x = F.relu(x)
45 | x = rearrange(x, 'b (m n) d -> b d m n', n=n)
46 | return x
47 |
48 | def forward(self, x_input, **kwargs):
49 | x = rearrange(x_input, 'b m n d -> b d m n')
50 | x = self.layer_forward(x, self.conv1, self.norm1)
51 | x = self.layer_forward(x, self.conv2, self.norm2)
52 | x = self.layer_forward(x, self.conv3, self.norm3)
53 | x = rearrange(x, 'b d m n -> b m n d')
54 | return x + x_input
55 |
--------------------------------------------------------------------------------
/code/model/bdtf_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from transformers import BertModel, BertPreTrainedModel
4 | from .table import TableEncoder
5 | from .matching_layer import MatchingLayer
6 |
7 |
8 | class BDTFModel(BertPreTrainedModel):
9 | def __init__(self, config):
10 | super().__init__(config)
11 |
12 | self.bert = BertModel(config)
13 | self.table_encoder = TableEncoder(config)
14 | self.inference = InferenceLayer(config)
15 | self.matching = MatchingLayer(config)
16 | self.init_weights()
17 |
18 | def forward(self, input_ids, attention_mask, ids,
19 | start_label_masks, end_label_masks,
20 | t_start_labels=None, t_end_labels=None,
21 | o_start_labels=None, o_end_labels=None,
22 | table_labels_S=None, table_labels_E=None,
23 | polarity_labels=None, pairs_true=None):
24 |
25 | seq = self.bert(input_ids, attention_mask)[0]
26 | table = self.table_encoder(seq, attention_mask)
27 |
28 | output = self.inference(table, attention_mask, table_labels_S, table_labels_E)
29 |
30 | output['ids'] = ids
31 |
32 | output = self.matching(output, table, pairs_true, seq)
33 | return output
34 |
35 |
36 | class InferenceLayer(nn.Module):
37 | def __init__(self,config):
38 | super().__init__()
39 | self.config = config
40 | self.cls_linear_S = nn.Linear(768,1)
41 | self.cls_linear_E = nn.Linear(768,1)
42 |
43 |     def span_pruning(self, pred, z, attention_mask):
44 |         mask_length = attention_mask.sum(dim=1)-2  # number of real tokens, excluding [CLS] and [SEP]
45 |         length = ((attention_mask.sum(dim=1)-2)*z).long()  # keep roughly a fraction z of candidate cells
46 |         length[length<5] = 5  # but always keep at least 5 candidates
47 |         max_length = mask_length**2  # cannot keep more cells than the table contains
48 |         for i in range(length.shape[0]):
49 |             if length[i] > max_length[i]:
50 |                 length[i] = max_length[i]
51 |         batch_size = attention_mask.shape[0]
52 |         pred_sort,_ = pred.view(batch_size, -1).sort(descending=True)
53 |         batchs = torch.arange(batch_size).to('cuda')
54 |         topkth = pred_sort[batchs, length-1].unsqueeze(1)  # score of the length-th highest cell in each sample
55 |         return pred >= (topkth.view(batch_size,1,1))  # keep every cell scoring at least that threshold
56 |
57 | def forward(self, table, attention_mask, table_labels_S, table_labels_E):
58 | outputs = {}
59 |
60 | logits_S = torch.squeeze(self.cls_linear_S(table), 3)
61 | logits_E = torch.squeeze(self.cls_linear_E(table), 3)
62 |
63 |
64 | loss_func = nn.BCEWithLogitsLoss(weight=(table_labels_S>=0))
65 |
66 | outputs['table_loss_S'] = loss_func(logits_S, table_labels_S.float())
67 | outputs['table_loss_E'] = loss_func(logits_E, table_labels_E.float())
68 |
69 | S_pred = torch.sigmoid(logits_S) * (table_labels_S>=0)
70 | E_pred = torch.sigmoid(logits_E) * (table_labels_S>=0)
71 |
72 | if self.config.span_pruning != 0:
73 | table_predict_S = self.span_pruning(S_pred, self.config.span_pruning, attention_mask)
74 | table_predict_E = self.span_pruning(E_pred, self.config.span_pruning, attention_mask)
75 | else:
76 | table_predict_S = S_pred>0.5
77 | table_predict_E = E_pred>0.5
78 | outputs['table_predict_S'] = table_predict_S
79 | outputs['table_predict_E'] = table_predict_E
80 | outputs['table_labels_S'] = table_labels_S
81 | outputs['table_labels_E'] = table_labels_E
82 | return outputs
83 |
84 |
--------------------------------------------------------------------------------
/code/model/matching_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | class MatchingLayer(nn.Module):
6 | def __init__(self, config):
7 | super().__init__()
8 | self.config = config
9 | self.linear = nn.Linear(config.hidden_size*3, 4)
10 |
11 | def gene_pred(self, batch_size, S_preds, E_preds, pairs_true):
12 | all_pred = [[] for i in range(batch_size)]
13 | pred_label = [[] for i in range(batch_size)]
14 | pred_maxlen = 0
15 | for i in range(batch_size):
16 | S_pred = torch.nonzero(S_preds[i]).cpu().numpy()
17 | E_pred = torch.nonzero(E_preds[i]).cpu().numpy()
18 |
19 | for (s0, s1) in S_pred:
20 | for (e0, e1) in E_pred:
21 | if s0 <= e0 and s1 <= e1:
22 | sentiment = 0
23 | for j in range(len(pairs_true[i])):
24 | p = pairs_true[i][j]
25 | if [s0-1, e0, s1-1, e1] == p[:4]:
26 | sentiment = p[4]
27 | pred_label[i].append(sentiment)
28 | all_pred[i].append([s0-1, e0, s1-1, e1])
29 | if len(all_pred[i])>pred_maxlen:
30 | pred_maxlen = len(all_pred[i])
31 |
32 | for i in range(batch_size):
33 | for j in range(len(all_pred[i]), pred_maxlen):
34 | pred_label[i].append(-1)
35 | pred_label = torch.tensor(pred_label).to('cuda')
36 |
37 | return all_pred, pred_label, pred_maxlen
38 |
39 | def input_encoding(self, batch_size, pairs, maxlen, table, seq):
40 | input_ret = torch.zeros([batch_size, maxlen, self.config.hidden_size*3]).to('cuda')
41 | for i in range(batch_size):
42 | j = 0
43 | for (s0, e0, s1, e1) in pairs[i]:
44 | S = table[i, s0+1, s1+1, :]
45 | E = table[i, e0, e1, :]
46 | R = torch.max(torch.max(table[i, s0+1:e0+1, s1+1:e1+1, :], dim=1)[0], dim=0)[0]
47 | input_ret[i, j, :] = torch.cat([S, E, R])
48 | j += 1
49 | return input_ret
50 |
51 |
52 | def forward(self, outputs, Table, pairs_true, seq):
53 | seq = seq.clone().detach()
54 | table = Table.clone()
55 | batch_size = table.size(0)
56 |
57 | all_pred, pred_label, pred_maxlen = self.gene_pred(batch_size, outputs['table_predict_S'], outputs['table_predict_E'], pairs_true)
58 | pred_input = self.input_encoding(batch_size, all_pred, pred_maxlen, table, seq)
59 | pred_output = self.linear(pred_input)
60 | loss_func = nn.CrossEntropyLoss(ignore_index = -1)
61 |
62 | loss_input = pred_output
63 | loss_label = pred_label
64 |
65 | if loss_input.shape[1] == 0:
66 | loss_input = torch.zeros([batch_size, 1, 2])
67 | loss_label = torch.zeros([batch_size, 1])-1
68 |
69 | outputs['pair_loss'] = loss_func(loss_input.transpose(1, 2), loss_label.long())
70 |
71 | pairs_logits = F.softmax(pred_output, dim=2)
72 | if pairs_logits.shape[1] == 0:
73 | outputs['pairs_preds'] = []
74 | return outputs
75 |
76 | pairs_pred = pairs_logits.argmax(dim=2)
77 |
78 | outputs['pairs_preds'] = []
79 | for i in range(batch_size):
80 | for j in range(len(all_pred[i])):
81 | if pairs_pred[i][j] >= 1:
82 | se = all_pred[i][j]
83 | outputs['pairs_preds'].append((i, se[0], se[1], se[2], se[3], pairs_pred[i][j].item()))
84 | return outputs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [**中文说明**](https://github.com/HITSZ-HLT/BDTF-ASTE/) | [**English**](https://github.com/HITSZ-HLT/BDTF-ASTE/blob/master/README_EN.md)
2 |
3 |
4 | # BDTF-ASTE
5 |
6 | This repository releases the code of the following paper:
7 | 
8 | - Title: [Boundary-Driven Table-Filling for Aspect Sentiment Triplet Extraction](https://aclanthology.org/2022.emnlp-main.435/)
9 | - Authors: Yice Zhang∗, Yifan Yang∗, Yihui Li, Bin Liang, Shiwei Chen, Yixue Dang, Ming Yang, and Ruifeng Xu
10 | - Conference: EMNLP-2022 Main (Long)
11 |
12 | ## Overview of the Work
13 | 
14 | ### The ASTE Task
15 | 
16 | This paper addresses the Aspect Sentiment Triplet Extraction (ASTE) task within Aspect-Based Sentiment Analysis (ABSA).
17 | As shown in the figure below, ASTE aims to extract the aspect sentiment triplets that express opinions in user reviews. Each triplet consists of three parts (a toy example follows the list):
18 | - Aspect Term: the target of the sentiment, usually an aspect of the reviewed entity (a restaurant or a product); also called the aspect word or attribute term.
19 | - Opinion Term: the word or phrase that concretely expresses the sentiment; also called the opinion word or sentiment word.
20 | - Sentiment Polarity: the sentiment the user expresses towards the aspect term, with label space `{POS, NEG, NEU}`.
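For illustration, a triplet for the made-up review "The pizza was really delicious" could look as follows. The sentence, the spans, and the index convention are illustrative assumptions; only the `[t_start, t_end, o_start, o_end, polarity]` layout mirrors the `pairs` field handled in `code/utils/aste_datamodule.py`:

```python
# Made-up review; token positions:  0:The  1:pizza  2:was  3:really  4:delicious
sentence = "The pizza was really delicious"

# One triplet: aspect "pizza", opinion "really delicious", polarity POS.
triplet = ("pizza", "really delicious", "POS")

# The JSON data files store each triplet in the `pairs` field as
# [t_start, t_end, o_start, o_end, polarity]; the start/end convention
# used below (inclusive start, exclusive end) is an assumption for illustration.
pair_entry = [1, 2, 3, 5, "POS"]
```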
21 |
22 |

23 |
24 | ### Limitations of Previous Methods
25 | 
26 | Previous methods model this task as a table-filling problem. As shown in the figure below, each element of the 2D table represents the relation between a pair of words. These methods first extract aspects and opinions along the diagonal, then use the aspects and opinions to locate the corresponding relation regions, and finally determine the relation between each aspect and opinion by voting. This formulation has several problems, two of which are particularly evident:
27 | 1. Relation inconsistency: decomposing the relation between an aspect and an opinion into word-to-word relations introduces potential inconsistencies among the word-level predictions.
28 | 2. Boundary insensitivity: if the boundary of an aspect or opinion contains a small error, the voting result for their relation will most likely stay unchanged, which encourages the model to produce outputs with boundary errors, such as ('dogs', 'top notch', `POS`).
29 |
30 | 
31 |
32 |
33 | Previous work has tried span-based methods to address the relation-inconsistency problem. This is a feasible idea, but such methods discard the fine-grained word-level information that is exactly the advantage of table filling.
34 |
35 | ### The Proposed Approach
36 | 
37 | To address the two problems above, this paper proposes Boundary-Driven Table-Filling (BDTF). As shown in the figure below, BDTF turns an aspect sentiment triplet into a relation region in the 2D table, thereby recasting ASTE as the detection and classification of relation regions. Classifying an entire relation region resolves the relation-inconsistency problem, and relation regions with boundary errors can be removed by classifying them as invalid.
38 |
39 | 
40 |
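As a minimal sketch of how one triplet maps to a relation region, the two boundary tables below mark the region's upper-left and lower-right corners. This mirrors the `table_label` construction in `code/utils/aste_datamodule.py`, but the toy sentence, spans, and index convention are made up for illustration:

```python
import torch

# Toy sentence (token positions): 0:the  1:pizza  2:was  3:really  4:delicious
# Hypothetical triplet: aspect "pizza" (1), opinion "really delicious" (3-4), POS.
n = 5
table_S = torch.zeros(n, n, dtype=torch.long)  # upper-left corners of relation regions
table_E = torch.zeros(n, n, dtype=torch.long)  # lower-right corners of relation regions

a_start, a_end = 1, 1  # aspect span  (inclusive indices, assumed for illustration)
o_start, o_end = 3, 4  # opinion span (inclusive indices, assumed for illustration)

table_S[a_start, o_start] = 1  # the region starts at (aspect start, opinion start)
table_E[a_end, o_end] = 1      # the region ends   at (aspect end,   opinion end)

# The model predicts these two corner tables; every (start, end) corner pair encloses a
# candidate region, which is then classified as POS / NEG / NEU / invalid.
print(table_S)
print(table_E)
```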
41 | In addition, this paper proposes a relation representation learning method to learn the 2D table representation. It consists of three parts (see the sketch below):
42 | - First, the review text is fed into `BERT` to learn word-level contextualized representations.
43 | - Then, relation representations are constructed from the word representations through a tensor-based operation. The relation representations of all word pairs in the text form a 2D table whose elements are vectors.
44 | - Finally, a CNN is used to model the 2D table.
45 | The learned 2D representation is then used to locate and classify the relation regions.
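These three steps roughly correspond to the flow of `TableEncoder` in `code/model/table.py`. The sketch below compresses that flow into a few lines; the shapes, the random inputs, and the single convolution are illustrative stand-ins rather than the actual configuration:

```python
import torch
from torch import nn

# Illustrative shapes only: B = batch size, L = sequence length, H = hidden size.
B, L, H = 2, 8, 64

word_repr = torch.randn(B, L, H)  # step 1: word-level representations (BERT output in the real code)

# Step 2: pair every two word representations into an L x L table of relation vectors;
# the repository uses the (tensor-based) Seq2Mat modules in model/seq2mat.py here.
xi, xj = torch.broadcast_tensors(word_repr[:, :, None], word_repr[:, None, :])
table = nn.Linear(2 * H, H)(torch.cat([xi, xj], dim=-1))  # [B, L, L, H]

# Step 3: refine the table with convolutional (ResNet-style) layers, as in
# model/table_encoder/resnet.py; a single Conv2d stands in for that stack.
conv = nn.Conv2d(H, H, kernel_size=3, padding=1)
table = conv(table.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

print(table.shape)  # torch.Size([2, 8, 8, 64])
```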
46 |
47 | Overall, the model framework of the proposed approach is shown in the figure below.
48 |
49 | 
50 |
51 | ### Experimental Results
52 | 
53 | The main experimental results of our approach are shown in the table below; see the paper for a detailed analysis.
54 |
55 | 
56 |
57 | ## Running the Code
58 | ### Requirements
59 |
60 | - transformers==4.15.0
61 | - pytorch==1.7.1
62 | - einops==0.4.0
63 | - torchmetrics==0.7.0
64 | - tntorch==1.0.1
65 | - pytorch-lightning==1.3.5
66 |
67 | ### Code Structure
68 |
69 | ```
70 | ├── code
71 | │   ├── utils
72 | │   │   ├── __init__.py
73 | │   │   ├── aste_datamodule.py
74 | │   │   └── aste_result.py
75 | │   ├── model
76 | │   │   ├── seq2mat.py
77 | │   │   ├── table.py
78 | │   │   ├── table_encoder
79 | │   │   │   └── resnet.py
80 | │   │   └── bdtf_model.py
81 | │   ├── aste_train.py
82 | │   └── bash
83 | │       ├── aste.sh
84 | │       ├── aste_14res.sh
85 | │       ├── aste_14lap.sh
86 | │       ├── aste_15res.sh
87 | │       └── aste_16res.sh
88 | └── data
89 |     └── aste_data_bert
90 |         ├── V1
91 |         │   ├── 14res
92 |         │   │   ├── train.json
93 |         │   │   ├── dev.json
94 |         │   │   └── test.json
95 |         │   ├── 14lap/...
96 |         │   ├── 15res/...
97 |         │   └── 16res/...
98 |         └── V2/...
99 | ```
100 |
101 | ### Running the Code
102 | 
103 | In the `code` directory:
104 | - run `chmod +x bash/*`;
105 | - run `bash/aste_14lap.sh`.
106 | 
107 | Below is the result of running aste_14lap.sh; the random seed is 40 and the computing device is an A100.
108 |
109 | 
110 |
111 | Running aste_14lap.sh on a V100 gives the following result.
112 |
113 | 
114 |
115 | Note that the performance reported in the paper is the average over runs with 5 different random seeds, so it may differ somewhat from a single run.
116 |
117 | ## If you have any questions, please open an `issue` or contact me
118 |
119 | - email: `zhangyc_hit@163.com`
120 |
121 |
122 |
--------------------------------------------------------------------------------
/code/utils/aste_result.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from . import append_json, save_json, mkdir_if_not_exist
4 |
5 | polarity_map = {
6 | 'NEG': 0,
7 | 'NEU': 1,
8 | 'POS': 2
9 | }
10 | polarity_map_reversed = {
11 | 0: 'NEG',
12 | 1: 'NEU',
13 | 2: 'POS'
14 | }
15 |
16 | class F1_Measure:
17 | def __init__(self):
18 | self.pred_set = set()
19 | self.true_set = set()
20 |
21 | def pred_inc(self, idx, preds):
22 | for pred in preds:
23 | self.pred_set.add((idx, tuple(pred)))
24 |
25 | def true_inc(self, idx, trues):
26 | for true in trues:
27 | self.true_set.add((idx, tuple(true)))
28 |
29 | def report(self):
30 | self.f1, self.p, self.r = self.cal_f1(self.pred_set, self.true_set)
31 | return self.f1
32 |
33 | def __getitem__(self, key):
34 | if hasattr(self, key):
35 | return getattr(self, key)
36 | else:
37 | raise NotImplementedError
38 |
39 | def cal_f1(self, pred_set, true_set):
40 | intersection = pred_set.intersection(true_set)
41 | _p = len(intersection) / len(pred_set) if pred_set else 1
42 | _r = len(intersection) / (len(true_set)) if true_set else 1
43 | f1 = 2 * _p * _r / (_p + _r) if _p + _r else 0
44 | return f1, _p, _r
45 |
46 |
47 |
48 | class NER_F1_Measure(F1_Measure):
49 | def __init__(self, entity_types):
50 | super().__init__()
51 | self.entity_types = entity_types
52 |
53 | def report(self):
54 | for entity_type in self.entity_types:
55 | name = '_'.join(entity_type)
56 | pred_set = self.filter(self.pred_set, entity_type)
57 | true_set = self.filter(self.true_set, entity_type)
58 |
59 | f1, p, r = self.cal_f1(pred_set, true_set)
60 | setattr(self, f'{name}_f1', f1)
61 |
62 | def filter(self, set_, entity_type):
63 | return_set = set()
64 | for type_ in entity_type:
65 | return_set.update(set([it for it in set_ if it[1][0] == type_]))
66 | return return_set
67 |
68 |
69 | class Result:
70 | def __init__(self, result_json):
71 | self.result_json = result_json
72 |         self.detailed_metrics = None
73 | self.monitor = 0
74 |
75 | def __setitem__(self, key, value):
76 | self.result_json[key] = value
77 |
78 | def __ge__(self, other):
79 | return self.monitor >= other.monitor
80 |
81 | def __gt__(self, other):
82 | return self.monitor > other.monitor
83 |
84 | @classmethod
85 | def parse_from(cls, all_preds, examples):
86 | result_json = {}
87 | examples = {example['ID']: example for example in examples}
88 |
89 | for preds in all_preds:
90 |
91 | for ID in preds['ids']:
92 | example = examples[ID]
93 | pairs_true = []
94 | for pp in example['pairs']:
95 | pl = polarity_map[pp[4]]+1
96 | pairs_true.append([pp[0],pp[1],pp[2],pp[3],pl])
97 | result_json[ID] = {
98 | 'ID': ID,
99 | 'sentence': example['sentence'],
100 | 'pairs': pairs_true,
101 | 'tokens': str(example['tokens']),
102 | 'pair_preds': set(),
103 | }
104 | for (i, a_start, a_end, b_start, b_end, pol) in preds['pair_preds']:
105 | ID = preds['ids'][i]
106 | result_json[ID]['pair_preds'].add((a_start, a_end, b_start, b_end, pol))
107 | return cls(result_json)
108 |
109 |
110 | def cal_metric(self):
111 | b_pair_f1 = F1_Measure()
112 | for item in self.result_json.values():
113 | for pair_f1 in (b_pair_f1, ):
114 | pair_f1.true_inc(item['ID'], item['pairs'])
115 |
116 | b_pair_f1.pred_inc(item['ID'], item['pair_preds'])
117 |
118 | b_pair_f1.report()
119 |
120 | detailed_metrics = {
121 | 'pair_f1': b_pair_f1['f1'],
122 | 'pair_p':b_pair_f1['p'],
123 | 'pair_r':b_pair_f1['r'],
124 | }
125 |
126 | self.detailed_metrics = detailed_metrics
127 | self.monitor = b_pair_f1['f1']
128 |
129 | def report(self):
130 | print(f'monitor: {self.monitor:.4f}', end=" ** ")
131 | for metric_names in (('pair_f1','pair_p','pair_r'),):
132 | for metric_name in metric_names:
133 | value = self.detailed_metrics[metric_name] if metric_name in self.detailed_metrics else 0
134 | print(f'{metric_name}: {value:.4f}', end=' | ')
135 | print()
136 |
137 | def save(self, dir_name, args):
138 | mkdir_if_not_exist(dir_name)
139 | current_time = time.strftime("%Y-%m-%d %H_%M_%S", time.localtime())
140 | current_day = time.strftime("%Y-%m-%d", time.localtime())
141 |
142 | result_file_name = os.path.join(dir_name, f'val_results_{self.monitor*10000:4.0f}_{current_time}.txt')
143 | performance_dir = os.path.join(os.path.dirname(os.path.dirname(dir_name)), 'performance')
144 |
145 | performance_dir = os.path.join(performance_dir, current_day)
146 | performance_file_name = os.path.join(performance_dir, f'{args.cuda_ids}.txt')
147 |
148 | for key, item in self.result_json.items():
149 | item['pair_preds'] = list(item['pair_preds'])
150 |
151 | save_json(list(self.result_json.values()), result_file_name)
152 | print('## save result to', result_file_name)
153 |
154 | description = f'{args.data_dir}, lr={args.learning_rate}, seed={args.seed}, model_name_or_path={args.model_name_or_path}'
155 | detailed_metrics = {k: (v if type(v) in (int, float) else v.item()) for k,v in self.detailed_metrics.items()}
156 |
157 | append_json(performance_file_name, time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
158 | append_json(performance_file_name, f'{description} {self.monitor*10000:4.0f}')
159 | append_json(performance_file_name, f'{args.span_pruning} {args.seq2mat} {args.num_d} {args.table_encoder} {args.num_table_layers}')
160 | append_json(performance_file_name, detailed_metrics)
161 | append_json(performance_file_name, '')
162 |
163 |
164 |
--------------------------------------------------------------------------------
/code/model/seq2mat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from transformers.models.t5.modeling_t5 import T5LayerNorm
4 | from transformers.activations import ACT2FN
5 |
6 |
7 | class Seq2Mat(nn.Module):
8 | def __init__(self, config):
9 | super().__init__()
10 | self.config = config
11 | self.W = nn.Linear(config.hidden_size*2, config.hidden_size)
12 | self.norm = T5LayerNorm(config.hidden_size, config.layer_norm_eps)
13 | self.activation = ACT2FN[config.hidden_act]
14 |
15 | def forward(self, x, y):
16 | """
17 | x,y: [B, L, H] => [B, L, L, H]
18 | """
19 | x, y = torch.broadcast_tensors(x[:, :, None], y[:, None, :])
20 | t = torch.cat([x, y], dim=-1)
21 | t = self.W(t)
22 | t = self.activation(t)
23 | return t
24 |
25 |
26 | class ContextSeq2Mat(nn.Module):
27 | def __init__(self, config):
28 | super().__init__()
29 | self.config = config
30 | self.W = nn.Linear(config.hidden_size*3, config.hidden_size)
31 | self.norm = T5LayerNorm(config.hidden_size, config.layer_norm_eps)
32 | self.activation = ACT2FN[config.hidden_act]
33 |
34 | def forward(self, x, y):
35 | """
36 | x,y: [B, L, H] => [B, L, L, H]
37 | """
38 | xmat = x.clone()
39 | batch_size = xmat.shape[0]
40 | x, y = torch.broadcast_tensors(x[:, :, None], y[:, None, :])
41 |
42 | max_len = xmat.shape[1]
43 | xmat_t = xmat.transpose(1, 2)
44 | context = torch.ones_like(x).to('cuda')
45 | for i in range(max_len):
46 | diag = x.diagonal(dim1=1, dim2=2, offset=-i)
47 | xmat_t = torch.max(xmat_t[:, :, :max_len-i], diag)
48 | bb = [[b] for b in range(batch_size)]
49 | linexup = [[j for j in range(max_len-i)] for b in range(batch_size)]
50 | lineyup = [[j+i for j in range(max_len-i)] for b in range(batch_size)]
51 | linexdown = [[j+i for j in range(max_len-i)] for b in range(batch_size)]
52 | lineydown = [[j for j in range(max_len-i)] for b in range(batch_size)]
53 | context[bb, linexup, lineyup, :] = xmat_t.permute(0, 2, 1)
54 | context[bb, linexdown, lineydown, :] = xmat_t.permute(0, 2, 1)
55 |
56 | t = torch.cat([x, y, context], dim=-1)
57 | t = self.W(t)
58 | t = self.activation(t)
59 | return t
60 |
61 |
62 | class TensorSeq2Mat(nn.Module):
63 |     """
64 |     Reference: Socher R., Perelygin A., Wu J., et al. Recursive deep models for semantic compositionality over a sentiment treebank. In Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing. 2013: 1631-1642.
65 |     """
66 | def __init__(self, config):
67 | super().__init__()
68 | self.config = config
69 | self.h = config.num_attention_heads
70 | self.d = config.num_d
71 | self.W = nn.Linear(2*config.hidden_size+self.d, config.hidden_size)
72 | self.V = nn.Parameter(torch.Tensor(self.d, config.hidden_size, config.hidden_size))
73 | self.norm = T5LayerNorm(config.hidden_size, config.layer_norm_eps)
74 | self.activation = ACT2FN[config.hidden_act]
75 | self.init_weights()
76 |
77 | def init_weights(self):
78 | self.V.data.normal_(mean=0.0, std=self.config.initializer_range)
79 |
80 | def rntn(self, x, y):
81 | t = torch.cat([x, y], dim=-1)
82 | xv = torch.einsum('b m n p, k p d -> b m n k d', x, self.V)
83 | xvy = torch.einsum('b m n k d, b m n d -> b m n k', xv, y)
84 | t = torch.cat([t, xvy], dim=-1)
85 | tw = self.W(t)
86 | return tw
87 |
88 | def forward(self, x, y):
89 | """
90 | x,y: [B, L, H] => [B, L, L, H]
91 | """
92 | seq = x
93 | x, y = torch.broadcast_tensors(x[:, :, None], y[:, None, :])
94 | t = self.rntn(x, y)
95 | t = self.activation(t)
96 | return t
97 |
98 |
99 | class TensorcontextSeq2Mat(nn.Module):
100 |     """
101 |     Reference: Socher R., Perelygin A., Wu J., et al. Recursive deep models for semantic compositionality over a sentiment treebank. In Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing. 2013: 1631-1642.
102 |     """
103 | def __init__(self, config):
104 | super().__init__()
105 | self.config = config
106 | self.h = config.num_attention_heads
107 | self.d = config.num_d
108 | self.W = nn.Linear(3*config.hidden_size+self.d, config.hidden_size)
109 | self.V = nn.Parameter(torch.Tensor(self.d, config.hidden_size, config.hidden_size))
110 | self.norm = T5LayerNorm(config.hidden_size, config.layer_norm_eps)
111 | self.activation = ACT2FN[config.hidden_act]
112 | self.init_weights()
113 |
114 | def init_weights(self):
115 | if self.config.model_type=='bart' or self.config.model_type=='t5':
116 | self.V.data.normal_(mean=0.0, std=0.02)
117 | else:
118 | self.V.data.normal_(mean=0.0, std=self.config.initializer_range)
119 |
120 | def rntn(self, x, y, xmat):
121 | max_len = xmat.shape[1]
122 | xmat_t = xmat.transpose(1, 2)
123 | batch_size = xmat.shape[0]
124 |         context = torch.ones_like(x).to('cuda')  # context[:, i, j] will hold the max-pooled word representations between positions i and j
125 |         for i in range(max_len):  # i is the distance between the two words of a pair
126 |             diag = x.diagonal(dim1=1, dim2=2, offset=-i)  # representations that extend each running span by one word
127 |             xmat_t = torch.max(xmat_t[:, :, :max_len-i], diag)  # running element-wise max over spans of length i+1
128 | bb = [[b] for b in range(batch_size)]
129 | linexup = [[j for j in range(max_len-i)] for b in range(batch_size)]
130 | lineyup = [[j+i for j in range(max_len-i)] for b in range(batch_size)]
131 | linexdown = [[j+i for j in range(max_len-i)] for b in range(batch_size)]
132 | lineydown = [[j for j in range(max_len-i)] for b in range(batch_size)]
133 | context[bb, linexup, lineyup, :] = xmat_t.permute(0, 2, 1)
134 | context[bb, linexdown, lineydown, :] = xmat_t.permute(0, 2, 1)
135 |
136 | t = torch.cat([x, y, context], dim=-1)
137 |         xvy = torch.einsum('b m n p, k p d, b m n d -> b m n k', x, self.V, y)  # bilinear tensor interaction between the paired word representations
138 | t = torch.cat([t, xvy], dim=-1)
139 | tw = self.W(t)
140 | return tw
141 |
142 | def forward(self, x, y):
143 | """
144 | x,y: [B, L, H] => [B, L, L, H]
145 | """
146 | xmat = x
147 | x, y = torch.broadcast_tensors(x[:, :, None], y[:, None, :])
148 | t = self.rntn(x, y, xmat)
149 | t = self.activation(t)
150 | return t
151 |
--------------------------------------------------------------------------------
/README_EN.md:
--------------------------------------------------------------------------------
1 | [**中文说明**](https://github.com/HITSZ-HLT/BDTF-ASTE/) | [**English**](https://github.com/HITSZ-HLT/BDTF-ASTE/blob/master/README_EN.md)
2 |
3 | # BDTF-ASTE
4 |
5 | This repository releases the code of the following paper:
6 |
7 | - Title: [Boundary-Driven Table-Filling for Aspect Sentiment Triplet Extraction](https://aclanthology.org/2022.emnlp-main.435/)
8 | - Authors: Yice Zhang∗, Yifan Yang∗, Yihui Li, Bin Liang, Shiwei Chen, Yixue Dang, Ming Yang, and Ruifeng Xu
9 | - Conference: EMNLP-2022 Main (Long)
10 |
11 | ## Brief Introduction of Our Paper
12 |
13 | ### The ASTE Task
14 |
15 | The task that this paper addresses is Aspect Sentiment Triplet Extraction (ASTE), an important task in Aspect-Based Sentiment Analysis (ABSA).
16 | As shown in the figure below, ASTE aims to extract the aspect terms along with the corresponding opinion terms and the expressed sentiments in the review.
17 | Specifically, a triplet is defined as (aspect term, opinion term, sentiment polarity):
18 | - **Aspect term**: the target of an opinion, usually an aspect of an entity (a restaurant or product).
19 | - **Opinion term**: the word or phrase that specifically expresses the sentiment.
20 | - **Sentiment polarity**: a specific category in `{POS, NEG, NEU}`.
21 |
22 | 
23 |
24 | ### Previous Methods and Limitations
25 |
26 | Previous methods tackle the ASTE task through a table-filling approach, where the triplets are represented by a two-dimensional (2D) table of word-pair relations.
27 | In this approach, aspect terms and opinion terms are extracted through the diagonal elements of the table, and sentiments are treated as relation tags that are represented by the non-diagonal elements of the table.
28 | This formalization enables joint learning of different subtasks in ASTE, achieving superior performance over the pipeline approach.
29 |
30 | However, the previous table formalization suffers from relation inconsistency and boundary insensitivity when dealing with multi-word aspect terms and opinion terms. It decomposes the relation between an aspect term and an opinion term into the relations between the corresponding aspect words and opinion words.
31 | In other words, a term-level relation is represented by several word-level relation tags. The relation tags in the table are assigned independently, which leads to potential inconsistencies in the predictions of the word-level relations.
32 | In addition, when there are minor boundary errors in the aspect term or opinion term, the voting result for the term-level relation may stay unchanged, encouraging the model to produce wrong predictions. Researchers try to solve this problem through a span-based method, but their method discards fine-grained word-level information, which is the advantage of the table-filling approach.
33 |
34 | 
35 |
36 | ### Proposed Approach
37 |
38 | This paper proposes a Boundary-Driven Table-Filling (BDTF) approach for ASTE to overcome the above issues.
39 | In BDTF, a triplet is represented as a relation region in the 2D table, which is shown in the figure below.
40 | In this way, it extracts triplets by directly detecting and classifying the relation regions in a 2D table.
41 | Classification over the entire relation region ensures relation consistency, and those relation regions with boundary errors can be removed by being classified as invalid.
42 |
43 | 
44 |
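A condensed sketch of this detect-then-classify step, loosely following `InferenceLayer` (`code/model/bdtf_model.py`) and `MatchingLayer` (`code/model/matching_layer.py`); the random table, the untrained linear heads, and the 0.8 threshold are illustrative (the repository instead keeps a top fraction of scores via span pruning):

```python
import torch
from torch import nn

L, H = 6, 64
table = torch.randn(L, L, H)  # relation table for one sentence (illustrative)

# 1) Detect candidate region corners with two scoring heads.
score_S = torch.sigmoid(nn.Linear(H, 1)(table)).squeeze(-1)  # upper-left corner scores
score_E = torch.sigmoid(nn.Linear(H, 1)(table)).squeeze(-1)  # lower-right corner scores
starts = (score_S > 0.8).nonzero().tolist()
ends = (score_E > 0.8).nonzero().tolist()

# 2) Pair every start corner with every end corner below and to its right, pool the
#    region representation, and classify it (0 = invalid, 1..3 = sentiment classes).
classifier = nn.Linear(3 * H, 4)
for (s0, s1) in starts:
    for (e0, e1) in ends:
        if s0 <= e0 and s1 <= e1:
            region = table[s0:e0 + 1, s1:e1 + 1]
            feat = torch.cat([table[s0, s1], table[e0, e1], region.flatten(0, 1).max(0).values])
            print((s0, e0, s1, e1), classifier(feat).argmax().item())
```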
45 |
46 | In addition, this paper also develops an effective relation representation learning approach to learn the table representation.
47 | This consists of three parts:
48 | - We first learn the word-level contextualized representations of the input review through a pre-trained language model.
49 | - Then we adopt a tensor-based operation to construct the relation-level representations to fully exploit the word-to-word interactions.
50 | - Finally, we model relation-to-relation interactions through a multi-layer convolution-based encoder to enhance the relation-level representations.
51 |
52 | The relation representations of every pair of words in the review together form a 2D relation matrix, which serves as the table representation for BDTF.
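The tensor-based operation is essentially a neural-tensor (RNTN-style) composition of the two word representations. Below is a stripped-down version of the `rntn` step in `TensorSeq2Mat`/`TensorcontextSeq2Mat` (`code/model/seq2mat.py`); the dimensions and random weights are small illustrative stand-ins (the real model uses BERT's hidden size, and `d` is set by `--num_d`):

```python
import torch
from torch import nn

B, L, H, d = 2, 8, 64, 8   # illustrative sizes; d plays the role of --num_d
x = torch.randn(B, L, H)   # word representations from the pre-trained encoder

# Broadcast the sequence against itself so that position (i, j) holds the pair (w_i, w_j).
xi, xj = torch.broadcast_tensors(x[:, :, None], x[:, None, :])  # both [B, L, L, H]

# Bilinear tensor term: one H x H interaction matrix per output dimension k.
V = torch.randn(d, H, H) * 0.02
bilinear = torch.einsum('bmnp,kpd,bmnd->bmnk', xi, V, xj)  # [B, L, L, d]

# Concatenate the linear and bilinear parts and project back to the hidden size
# (the real module additionally applies the configured activation, e.g. GELU).
W = nn.Linear(2 * H + d, H)
relation_table = W(torch.cat([xi, xj, bilinear], dim=-1))  # [B, L, L, H]
print(relation_table.shape)  # torch.Size([2, 8, 8, 64])
```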
53 |
54 | The proposed approach is briefly presented in the figure below.
55 |
56 | 
57 |
58 | ### Experimental Results
59 |
60 | The main results are listed in the table below. See the paper for a detailed analysis.
61 |
62 | 
63 |
64 | ## How to Run
65 |
66 | ### Requirements
67 |
68 | - transformers==4.15.0
69 | - pytorch==1.7.1
70 | - einops==0.4.0
71 | - torchmetrics==0.7.0
72 | - tntorch==1.0.1
73 | - pytorch-lightning==1.3.5
74 |
75 | ### Files
76 |
77 | ```
78 | ├── code
79 | │   ├── utils
80 | │   │   ├── __init__.py
81 | │   │   ├── aste_datamodule.py
82 | │   │   └── aste_result.py
83 | │   ├── model
84 | │   │   ├── seq2mat.py
85 | │   │   ├── table.py
86 | │   │   ├── table_encoder
87 | │   │   │   └── resnet.py
88 | │   │   └── bdtf_model.py
89 | │   ├── aste_train.py
90 | │   └── bash
91 | │       ├── aste.sh
92 | │       ├── aste_14res.sh
93 | │       ├── aste_14lap.sh
94 | │       ├── aste_15res.sh
95 | │       └── aste_16res.sh
96 | └── data
97 |     └── aste_data_bert
98 |         ├── V1
99 |         │   ├── 14res
100 |         │   │   ├── train.json
101 |         │   │   ├── dev.json
102 |         │   │   └── test.json
103 |         │   ├── 14lap/...
104 |         │   ├── 15res/...
105 |         │   └── 16res/...
106 |         └── V2/...
107 | ```
108 |
109 | ### Run our code!
110 |
111 | Enter `code` and
112 | - execute `chmod +x bash/*`,
113 | - execute `bash/aste_14lap.sh`.
114 |
115 | Result of aste_14lap.sh (the random seed is set to 40 and the computing device is an A100):
116 |
117 | 
118 |
119 | Result of aste_14lap.sh (the random seed is set to 40 and the computing device is a V100):
120 |
121 | 
122 |
123 | Note that the performance reported in the paper is the average of 5 runs with 5 different random seeds, which may differ somewhat from a single run.
124 |
125 | ## If you have any questions, please raise an `issue` or contact me
126 |
127 | - email: `zhangyc_hit@163.com`
128 |
129 |
130 |
131 |
132 |
--------------------------------------------------------------------------------
/code/utils/aste_datamodule.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | import pytorch_lightning as pl
4 | from transformers import AutoTokenizer
5 |
6 | import os
7 | from . import load_json
8 |
9 | polarity_map = {
10 | 'NEG': 0,
11 | 'NEU': 1,
12 | 'POS': 2
13 | }
14 |
15 | polarity_map_reversed = {
16 | 0: 'NEG',
17 | 1: 'NEU',
18 | 2: 'POS'
19 | }
20 |
21 |
22 | class Example:
23 | def __init__(self, data, max_length=-1):
24 | self.data = data
25 | self.max_length = max_length
26 | self.data['tokens'] = eval(str(self.data['tokens']))
27 |
28 | def __getitem__(self, key):
29 | return self.data[key]
30 |
31 | def t_entities(self):
32 | return [tuple(entity[:3]) for entity in self['entities'] if entity[0]=='target']
33 |
34 | def o_entities(self):
35 | return [tuple(entity[:3]) for entity in self['entities'] if entity[0]=='opinion']
36 |
37 | def entity_label(self, target_oriented, length):
38 | entities = self.t_entities() if target_oriented else self.o_entities()
39 | return Example.make_start_end_labels(entities, length)
40 |
41 | def table_label(self, length, ty, id_len):
42 | label = [[-1 for _ in range(length)] for _ in range(length)]
43 | id_len = id_len.item()
44 |
45 | for i in range(1, id_len-1):
46 | for j in range(1, id_len-1):
47 | label[i][j] = 0
48 |
49 | for t_start, t_end, o_start, o_end, pol in self['pairs']:
50 | if ty == 'S':
51 | label[t_start+1][o_start+1] = 1
52 | elif ty == 'E':
53 | label[t_end][o_end] = 1
54 | return label
55 |
56 | @staticmethod
57 | def make_start_end_labels(entities, length, plus_one=True):
58 | start_label = [0] * length
59 | end_label = [0] * length
60 |
61 | for (t, s, e) in entities:
62 | if plus_one:
63 | s, e = s+1, e+1
64 |
65 | if s < length:
66 | start_label[s] = 1
67 |
68 | if e-1 < length:
69 | end_label[e-1] = 1
70 |
71 | return start_label, end_label
72 |
73 |
74 | class DataCollatorForASTE:
75 | def __init__(self, tokenizer, max_seq_length):
76 | self.tokenizer = tokenizer
77 | self.max_seq_length = max_seq_length
78 |
79 |
80 | def __call__(self, examples):
81 |
82 | batch = self.tokenizer_function(examples)
83 |
84 | length = batch['input_ids'].size(1)
85 |
86 | batch['t_start_labels'], batch['t_end_labels'] = self.start_end_labels(examples, True, length)
87 | batch['o_start_labels'], batch['o_end_labels'] = self.start_end_labels(examples, False, length)
88 | batch['example_ids'] = [example['ID'] for example in examples]
89 | batch['table_labels_S'] = torch.tensor([examples[i].table_label(length, 'S', (batch['input_ids'][i]>0).sum()) for i in range(len(examples))], dtype=torch.long)
90 | batch['table_labels_E'] = torch.tensor([examples[i].table_label(length, 'E', (batch['input_ids'][i]>0).sum()) for i in range(len(examples))], dtype=torch.long)
91 |
92 | al = [example['pairs'] for example in examples]
93 | pairs_ret = []
94 | for pairs in al:
95 | pairs_chg = []
96 | for p in pairs:
97 | pairs_chg.append([p[0],p[1],p[2], p[3], polarity_map[p[4]]+1])
98 | pairs_ret.append(pairs_chg)
99 | batch['pairs_true'] = pairs_ret
100 |
101 | return {
102 | 'ids': batch['example_ids'],
103 | 'input_ids' : batch['input_ids'],
104 | 'attention_mask': batch['attention_mask'],
105 |
106 | 't_start_labels': batch['t_start_labels'],
107 | 't_end_labels' : batch['t_end_labels'],
108 | 'o_start_labels': batch['o_start_labels'],
109 | 'o_end_labels' : batch['o_end_labels'],
110 |
111 | 'start_label_masks': batch['start_label_masks'],
112 | 'end_label_masks' : batch['end_label_masks'],
113 |
114 | 'table_labels_S' : batch['table_labels_S'],
115 | 'table_labels_E' : batch['table_labels_E'],
116 | 'pairs_true' : batch['pairs_true'],
117 | }
118 |
119 | def start_end_labels(self, examples, target_oriented, length):
120 | start_labels = []
121 | end_labels = []
122 |
123 | for example in examples:
124 | start_label, end_label = example.entity_label(target_oriented, length)
125 | start_labels.append(start_label)
126 | end_labels.append(end_label)
127 |
128 | start_labels = torch.tensor(start_labels, dtype=torch.long)
129 | end_labels = torch.tensor(end_labels, dtype=torch.long)
130 |
131 | return start_labels, end_labels
132 |
133 | def tokenizer_function(self, examples):
134 | text = [example['sentence'] for example in examples]
135 | kwargs = {
136 | 'text': text,
137 | 'return_tensors': 'pt'
138 | }
139 |
140 | if self.max_seq_length in (-1, 'longest'):
141 | kwargs['padding'] = True
142 | else:
143 | kwargs['padding'] = 'max_length'
144 | kwargs['max_length'] = self.max_seq_length
145 | kwargs['truncation'] = True
146 |
147 | batch_encodings = self.tokenizer(**kwargs)
148 | length = batch_encodings['input_ids'].size(1)
149 |
150 | start_label_masks = []
151 | end_label_masks = []
152 |
153 | for i in range(len(examples)):
154 | encoding = batch_encodings[i]
155 | word_ids = encoding.word_ids
156 | type_ids = encoding.type_ids
157 |
158 | start_label_mask = [(1 if type_ids[i]==0 else 0) for i in range(length)]
159 | end_label_mask = [(1 if type_ids[i]==0 else 0) for i in range(length)]
160 |
161 | for token_idx in range(length):
162 | current_word_idx = word_ids[token_idx]
163 | prev_word_idx = word_ids[token_idx-1] if token_idx-1 > 0 else None
164 | next_word_idx = word_ids[token_idx+1] if token_idx+1 < length else None
165 |
166 | if prev_word_idx is not None and current_word_idx == prev_word_idx:
167 | start_label_mask[token_idx] = 0
168 | if next_word_idx is not None and current_word_idx == next_word_idx:
169 | end_label_mask[token_idx] = 0
170 |
171 | start_label_masks.append(start_label_mask)
172 | end_label_masks.append(end_label_mask)
173 |
174 | batch_encodings = dict(batch_encodings)
175 | batch_encodings['start_label_masks'] = torch.tensor(start_label_masks, dtype=torch.long)
176 | batch_encodings['end_label_masks'] = torch.tensor(end_label_masks, dtype=torch.long)
177 |
178 | return batch_encodings
179 |
180 |
181 | class ASTEDataModule(pl.LightningDataModule):
182 | def __init__(self,
183 | model_name_or_path: str='',
184 | max_seq_length: int = -1,
185 | train_batch_size: int = 32,
186 | eval_batch_size: int = 32,
187 | data_dir: str = '',
188 | num_workers: int = 4,
189 | cuda_ids: int = -1,
190 | ):
191 |
192 | super().__init__()
193 |
194 | self.model_name_or_path = model_name_or_path
195 | self.max_seq_length = max_seq_length if max_seq_length > 0 else 'longest'
196 | self.train_batch_size = train_batch_size
197 | self.eval_batch_size = eval_batch_size
198 |
199 | self.data_dir = data_dir
200 | self.num_workers = num_workers
201 | self.cuda_ids = cuda_ids
202 |
203 | self.table_num_labels = 6 # 4
204 |
205 | try:
206 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
207 | except:
208 | self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', use_fast=True)
209 |
210 | def load_dataset(self):
211 | train_file_name = os.path.join(self.data_dir, 'train.json')
212 | dev_file_name = os.path.join(self.data_dir, 'dev.json')
213 | test_file_name = os.path.join(self.data_dir, 'test.json')
214 |
215 | if not os.path.exists(dev_file_name):
216 | dev_file_name = test_file_name
217 |
218 | train_examples = [Example(data, self.max_seq_length) for data in load_json(train_file_name)]
219 | dev_examples = [Example(data, self.max_seq_length) for data in load_json(dev_file_name)]
220 | test_examples = [Example(data, self.max_seq_length) for data in load_json(test_file_name)]
221 |
222 | self.raw_datasets = {
223 | 'train': train_examples,
224 | 'dev' : dev_examples,
225 | 'test' : test_examples
226 | }
227 |
228 | def get_dataloader(self, mode, batch_size, shuffle):
229 | dataloader = DataLoader(
230 | dataset=self.raw_datasets[mode],
231 | batch_size=batch_size,
232 | shuffle=shuffle,
233 | num_workers=self.num_workers,
234 | collate_fn=DataCollatorForASTE(tokenizer=self.tokenizer,
235 | max_seq_length=self.max_seq_length),
236 | pin_memory=True,
237 | prefetch_factor=16
238 | )
239 |
240 | print(mode, len(dataloader))
241 | return dataloader
242 |
243 | def train_dataloader(self):
244 | return self.get_dataloader('train', self.train_batch_size, shuffle=True)
245 |
246 | def val_dataloader(self):
247 | return self.get_dataloader("dev", self.eval_batch_size, shuffle=False)
248 |
249 | def test_dataloader(self):
250 | return self.get_dataloader("test", self.eval_batch_size, shuffle=False)
251 |
--------------------------------------------------------------------------------
/code/aste_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | import argparse
5 | import random
6 | import numpy as np
7 |
8 |
9 | import pytorch_lightning as pl
10 |
11 | pl.seed_everything(42)
12 |
13 | from transformers import AutoTokenizer, AutoConfig
14 | from transformers.optimization import AdamW
15 | from pytorch_lightning.utilities import rank_zero_info
16 | from transformers import (
17 | get_linear_schedule_with_warmup,
18 | get_cosine_schedule_with_warmup,
19 | get_cosine_with_hard_restarts_schedule_with_warmup,
20 | get_polynomial_decay_schedule_with_warmup,
21 | get_constant_schedule_with_warmup
22 | )
23 |
24 | arg_to_scheduler = {
25 | 'linear': get_linear_schedule_with_warmup,
26 | 'cosine': get_cosine_schedule_with_warmup,
27 | 'cosine_w_restarts': get_cosine_with_hard_restarts_schedule_with_warmup,
28 | 'polynomial': get_polynomial_decay_schedule_with_warmup,
29 | 'constant': get_constant_schedule_with_warmup,
30 | }
31 |
32 |
33 |
34 | from model.bdtf_model import BDTFModel
35 | from utils.aste_datamodule import ASTEDataModule
36 | from utils.aste_result import Result
37 | from utils import params_count
38 |
39 |
40 | logger = logging.getLogger(__name__)
41 |
42 |
43 | class ASTE(pl.LightningModule):
44 | def __init__(self, hparams, data_module):
45 | super().__init__()
46 | self.save_hyperparameters(hparams)
47 | self.data_module = data_module
48 |
49 | self.config = AutoConfig.from_pretrained(self.hparams.model_name_or_path)
50 | self.config.table_num_labels = self.data_module.table_num_labels
51 | self.config.table_encoder = self.hparams.table_encoder
52 | self.config.num_table_layers = self.hparams.num_table_layers
53 | self.config.span_pruning = self.hparams.span_pruning
54 | self.config.seq2mat = self.hparams.seq2mat
55 | self.config.num_d = self.hparams.num_d
56 |
57 | self.model = BDTFModel.from_pretrained(self.hparams.model_name_or_path, config=self.config)
58 |
59 | print(self.model.config)
60 | print('---------------------------------------')
61 | print('total params_count:', params_count(self.model))
62 | print('---------------------------------------')
63 |
64 | @pl.utilities.rank_zero_only
65 | def save_model(self):
66 | dir_name = os.path.join(self.hparams.output_dir, str(self.hparams.cuda_ids), 'model')
67 | print(f'## save model to {dir_name}')
68 | self.model.save_pretrained(dir_name)
69 |
70 | def load_model(self):
71 | dir_name = os.path.join(self.hparams.output_dir, str(self.hparams.cuda_ids), 'model')
72 | print(f'## load model to {dir_name}')
73 | self.model = BDTFModel.from_pretrained(dir_name)
74 |
75 | def forward(self, **inputs):
76 | outputs = self.model(**inputs)
77 | return outputs
78 |
79 | def training_step(self, batch, batch_idx):
80 | outputs = self(**batch)
81 | loss = outputs['table_loss_S'] + outputs['table_loss_E'] + outputs['pair_loss']
82 |
83 | self.log('train_loss', loss)
84 | return loss
85 |
86 | def validation_step(self, batch, batch_idx):
87 | outputs = self(**batch)
88 | loss = outputs['table_loss_S'] + outputs['table_loss_E'] + outputs['pair_loss']
89 |
90 | self.log('valid_loss', loss)
91 |
92 | return {
93 | 'ids': outputs['ids'],
94 | 'table_predict_S': outputs['table_predict_S'],
95 | 'table_predict_E': outputs['table_predict_E'],
96 | 'table_labels_S': outputs['table_labels_S'],
97 | 'table_labels_E': outputs['table_labels_E'],
98 | 'pair_preds': outputs['pairs_preds']
99 | }
100 |
101 | def validation_epoch_end(self, outputs):
102 | examples = self.data_module.raw_datasets['dev']
103 |
104 | self.current_val_result = Result.parse_from(outputs, examples)
105 | self.current_val_result.cal_metric()
106 |
107 | if not hasattr(self, 'best_val_result'):
108 | self.best_val_result = self.current_val_result
109 |
110 | elif self.best_val_result < self.current_val_result:
111 | self.best_val_result = self.current_val_result
112 | self.save_model()
113 |
114 | def test_step(self, batch, batch_idx):
115 | return self.validation_step(batch,batch_idx)
116 |
117 | def test_epoch_end(self, outputs):
118 | examples = self.data_module.raw_datasets['test']
119 |
120 | self.test_result = Result.parse_from(outputs, examples)
121 | self.test_result.cal_metric()
122 |
123 | def save_test_result(self):
124 | dir_name = os.path.join(self.hparams.output_dir, 'result')
125 | self.test_result.save(dir_name, self.hparams)
126 |
127 | def setup(self, stage):
128 | if stage == 'fit':
129 | self.train_loader = self.train_dataloader()
130 | ngpus = (len(self.hparams.gpus.split(',')) if type(self.hparams.gpus) is str else self.hparams.gpus)
131 | effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * ngpus
132 | dataset_size = len(self.train_loader.dataset)
133 | self.total_steps = (dataset_size / effective_batch_size) * self.hparams.max_epochs
134 |
135 | def get_lr_scheduler(self):
136 | get_scheduler_func = arg_to_scheduler[self.hparams.lr_scheduler]
137 | if self.hparams.lr_scheduler == 'constant':
138 | scheduler = get_scheduler_func(self.opt, num_warmup_steps=self.hparams.warmup_steps)
139 | else:
140 | scheduler = get_scheduler_func(self.opt, num_warmup_steps=self.hparams.warmup_steps,
141 | num_training_steps=self.total_steps)
142 | scheduler = {'scheduler': scheduler, 'interval': 'step', 'frequency': 1}
143 | return scheduler
144 |
145 | def configure_optimizers(self):
146 | no_decay = ['bias', 'LayerNorm.weight']
147 |
148 | def has_keywords(n, keywords):
149 | return any(nd in n for nd in keywords)
150 |
151 | optimizer_grouped_parameters = [
152 | {
153 | 'params': [p for n, p in self.model.named_parameters() if not has_keywords(n, no_decay)],
154 | 'lr': self.hparams.learning_rate,
155 | 'weight_decay': 0
156 | },
157 | {
158 | 'params': [p for n, p in self.model.named_parameters() if has_keywords(n, no_decay)],
159 | 'lr': self.hparams.learning_rate,
160 | 'weight_decay': self.hparams.weight_decay
161 | }
162 | ]
163 |
164 | optimizer = AdamW(optimizer_grouped_parameters, eps=self.hparams.adam_epsilon)
165 | self.opt = optimizer
166 | scheduler = self.get_lr_scheduler()
167 |
168 | return [optimizer], [scheduler]
169 |
170 | @staticmethod
171 | def add_model_specific_args(parser):
172 | parser.add_argument("--learning_rate", default=1e-5, type=float)
173 | parser.add_argument("--adam_epsilon", default=1e-8, type=float)
174 | parser.add_argument("--warmup_steps", default=0, type=int)
175 | parser.add_argument("--weight_decay", default=0., type=float)
176 | parser.add_argument("--lr_scheduler", type=str)
177 |
178 | parser.add_argument("--seed", default=42, type=int)
179 | parser.add_argument("--output_dir", type=str)
180 | parser.add_argument("--do_train", action='store_true')
181 |
182 | parser.add_argument("--table_encoder", type=str, default='resnet', choices=['resnet','none'])
183 | parser.add_argument("--num_table_layers", type=int, default=2)
184 | parser.add_argument("--span_pruning", type=float, default=0.3)
185 | parser.add_argument("--seq2mat", type=str, default='none',choices=['none','tensor','context','tensorcontext'])
186 | parser.add_argument("--num_d", type=int, default=64)
187 | return parser
188 |
189 |
190 | class LoggingCallback(pl.Callback):
191 | def on_validation_end(self, trainer, pl_module):
192 | metrics = trainer.callback_metrics
193 | metrics = {k:(v.detach() if type(v) is torch.Tensor else v) for k,v in metrics.items()}
194 | rank_zero_info(metrics)
195 |
196 |
197 | class LoggingCallback(pl.Callback):
198 | def on_validation_end(self, trainer, pl_module):
199 | print('-------------------------------------------------------------------------------------------------------------------\n[current]\t', end='')
200 | pl_module.current_val_result.report()
201 |
202 | print('[best]\t\t', end='')
203 | pl_module.best_val_result.report()
204 | print('-------------------------------------------------------------------------------------------------------------------\n')
205 |
206 | def on_test_end(self, trainer, pl_module):
207 | pl_module.test_result.report()
208 | pl_module.save_test_result()
209 |
210 |
211 | def main():
212 | parser = argparse.ArgumentParser()
213 | parser = pl.Trainer.add_argparse_args(parser)
214 | parser = ASTE.add_model_specific_args(parser)
215 | parser = ASTEDataModule.add_argparse_args(parser)
216 |
217 | args = parser.parse_args()
218 | pl.seed_everything(args.seed)
219 |
220 | if args.learning_rate >= 1:
221 | args.learning_rate /= 1e5
222 |
223 |
224 | data_module = ASTEDataModule.from_argparse_args(args)
225 | data_module.load_dataset()
226 | model = ASTE(args, data_module)
227 |
228 | logging_callback = LoggingCallback()
229 |
230 | kwargs = {
231 | 'weights_summary': None,
232 | 'callbacks': [logging_callback],
233 | 'logger': True,
234 | 'checkpoint_callback': False,
235 | 'num_sanity_val_steps': 5 if args.do_train else 0,
236 | }
237 |
238 | trainer = pl.Trainer.from_argparse_args(args, **kwargs)
239 |
240 | trainer.fit(model, datamodule=data_module)
241 | model.load_model()
242 | trainer.test(model, datamodule=data_module)
243 |
244 | if __name__ == '__main__':
245 | main()
--------------------------------------------------------------------------------