├── README.md
├── models
│   ├── BertParser.py
│   ├── CharParser.py
│   ├── __init__.py
│   ├── callbacks.py
│   └── metrics.py
├── modules
│   ├── __init__.py
│   └── pipe.py
├── requirements.txt
├── train.py
└── train_bert.py

/README.md:
--------------------------------------------------------------------------------
 1 | ## A Unified Model for Joint Chinese Word Segmentation and Dependency Parsing
 2 | 
 3 | This is the code for the paper [A Unified Model for Joint Chinese Word Segmentation and Dependency Parsing](https://arxiv.org/abs/1904.04697).
 4 | 
 5 | ### Requirements
 6 | This project depends on the natural language processing Python package
 7 | [fastNLP](https://github.com/fastnlp/fastNLP). You can install it with
 8 | the following command
 9 | 
10 | ```bash
11 | pip install fastNLP
12 | ```
13 | 
14 | 
15 | ### Data
16 | Your data should be in the following format
17 | ```
18 | 1 中国 _ NR NR _ 4 nn _ _
19 | 2 残疾人 _ NN NN _ 4 nn _ _
20 | 3 体育 _ NN NN _ 4 nn _ _
21 | 4 事业 _ NN NN _ 5 nsubj _ _
22 | 5 方兴未艾 _ VV VV _ 0 root _ _
23 | 
24 | 1 新华社 _ NR NR _ 12 dep _ _
25 | ```
26 | The 1st, 3rd, 6th and 7th columns (counting from 0) should be the words, POS tags,
27 | dependency heads and dependency labels, respectively. An empty line separates
28 | two instances.
29 | 
30 | Your data should be organized in the following directory structure
31 | ```
32 | -JointCwsParser
33 |     ...
34 |     -train.py
35 |     -train_bert.py
36 |     -data
37 |         -ctb5
38 |             -train.conll
39 |             -dev.conll
40 |             -test.conll
41 |         -ctb7
42 |             -...
43 |         -ctb9
44 |             -...
45 | ```
46 | We use code from https://github.com/hankcs/TreebankPreprocessing to convert the original format into the CoNLL format.
47 | 
48 | 
49 | ### Run the code
50 | You can run training directly with
51 | ```
52 | python train.py --dataset ctb5
53 | ```
54 | or
55 | ```
56 | python train_bert.py --dataset ctb5
57 | ```
58 | fastNLP will download the pretrained embeddings or BERT weights automatically.
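If you want to inspect the processed data outside of the training scripts, the following sketch mirrors what `train.py` does with the pipe defined in `modules/pipe.py` (the data path is only an example; point it at the folder that holds your `train.conll`/`dev.conll`/`test.conll` files):

```python
from modules.pipe import CTBxJointPipe

# Read the train/dev/test .conll files of one dataset into a fastNLP DataBundle
# containing the character-level fields (chars, bigrams, trigrams, char_heads, ...)
# that the joint parser is trained on.
data = CTBxJointPipe().process_from_file('data/ctb5')
print(data)  # lists the datasets and vocabularies consumed by train.py / train_bert.py
```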
59 | 60 | -------------------------------------------------------------------------------- /models/BertParser.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # 使用Bert,上面直接接一个biaffine 4 | 5 | from torch import nn 6 | from fastNLP.modules.dropout import TimestepDropout 7 | import torch 8 | 9 | from fastNLP.models.biaffine_parser import ArcBiaffine, LabelBilinear, BiaffineParser 10 | import torch.nn.functional as F 11 | 12 | 13 | class BertParser(BiaffineParser): 14 | def __init__(self, embed, num_label, arc_mlp_size=500, label_mlp_size=100, dropout=0.5, use_greedy_infer=False, app_index=0): 15 | super(BiaffineParser, self).__init__() 16 | 17 | self.embed = embed 18 | 19 | self.mlp = nn.Sequential(nn.Linear(self.embed.embed_size, arc_mlp_size * 2 + label_mlp_size * 2), 20 | nn.LeakyReLU(0.1), 21 | TimestepDropout(p=dropout),) 22 | self.arc_mlp_size = arc_mlp_size 23 | self.label_mlp_size = label_mlp_size 24 | self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) 25 | self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) 26 | self.use_greedy_infer = use_greedy_infer 27 | self.reset_parameters() 28 | 29 | self.app_index = app_index 30 | self.num_label = num_label 31 | if self.app_index != 0: 32 | raise ValueError("现在app_index必须等于0") 33 | 34 | self.dropout = nn.Dropout(dropout) 35 | 36 | def reset_parameters(self): 37 | for name, m in self.named_modules(): 38 | if 'embed' in name: 39 | pass 40 | elif hasattr(m, 'reset_parameters') or hasattr(m, 'init_param'): 41 | pass 42 | else: 43 | for p in m.parameters(): 44 | if len(p.size())>1: 45 | nn.init.xavier_normal_(p, gain=0.1) 46 | else: 47 | nn.init.uniform_(p, -0.1, 0.1) 48 | 49 | def _forward(self, chars, gold_heads=None, char_labels=None): 50 | batch_size, max_len = chars.shape 51 | 52 | feats = self.embed(chars) 53 | mask = chars.ne(0) 54 | feats = self.dropout(feats) 55 | feats = self.mlp(feats) 56 | arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size 57 | arc_dep, arc_head = feats[:,:,:arc_sz], feats[:,:,arc_sz:2*arc_sz] 58 | label_dep, label_head = feats[:,:,2*arc_sz:2*arc_sz+label_sz], feats[:,:,2*arc_sz+label_sz:] 59 | 60 | arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] 61 | 62 | if gold_heads is None or not self.training: 63 | # use greedy decoding in training 64 | if self.training or self.use_greedy_infer: 65 | heads = self.greedy_decoder(arc_pred, mask) 66 | else: 67 | heads = self.mst_decoder(arc_pred, mask) 68 | head_pred = heads 69 | else: 70 | assert self.training # must be training mode 71 | if gold_heads is None: 72 | heads = self.greedy_decoder(arc_pred, mask) 73 | head_pred = heads 74 | else: 75 | head_pred = None 76 | heads = gold_heads 77 | 78 | batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=chars.device).unsqueeze(1) 79 | label_head = label_head[batch_range, heads].contiguous() 80 | label_pred = self.label_predictor(label_head, label_dep) # [N, max_len, num_label] 81 | # 这里限制一下,只有当head为下一个时,才能预测app这个label 82 | arange_index = torch.arange(1, max_len+1, dtype=torch.long, device=chars.device).unsqueeze(0)\ 83 | .repeat(batch_size, 1) # batch_size x max_len 84 | app_masks = heads.ne(arange_index) # batch_size x max_len, 为1的位置不可以预测app 85 | app_masks = app_masks.unsqueeze(2).repeat(1, 1, self.num_label) 86 | app_masks[:, :, 1:] = 0 87 | label_pred = label_pred.masked_fill(app_masks, float('-inf')) 88 | if gold_heads is not None: 89 | res_dict = {'loss':self.loss(arc_pred, label_pred, 
gold_heads, char_labels, mask)} 90 | else: 91 | res_dict = {'label_preds': label_pred.max(2)[1], 'head_preds': head_pred} 92 | return res_dict 93 | 94 | def forward(self, chars, char_heads, char_labels): 95 | return self._forward(chars, gold_heads=char_heads, char_labels=char_labels) 96 | 97 | @staticmethod 98 | def loss(arc_pred, label_pred, arc_true, label_true, mask): 99 | """ 100 | Compute loss. 101 | 102 | :param arc_pred: [batch_size, seq_len, seq_len] 103 | :param label_pred: [batch_size, seq_len, n_tags] 104 | :param arc_true: [batch_size, seq_len] 105 | :param label_true: [batch_size, seq_len] 106 | :param mask: [batch_size, seq_len] 107 | :return: loss value 108 | """ 109 | 110 | batch_size, seq_len, _ = arc_pred.shape 111 | flip_mask = (mask == 0) 112 | # _arc_pred = arc_pred.clone() 113 | _arc_pred = arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf')) 114 | 115 | arc_true.data[:, 0].fill_(-1) 116 | label_true.data[:, 0].fill_(-1) 117 | 118 | arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1) 119 | label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1) 120 | 121 | return arc_nll + label_nll 122 | 123 | def predict(self, chars): 124 | """ 125 | 126 | max_len是包含root的 127 | 128 | :param chars: batch_size x max_len 129 | :return: 130 | """ 131 | res = self._forward(chars, gold_heads=None) 132 | return res -------------------------------------------------------------------------------- /models/CharParser.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | from fastNLP.models.biaffine_parser import BiaffineParser 5 | from fastNLP.models.biaffine_parser import ArcBiaffine, LabelBilinear 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from fastNLP.modules.dropout import TimestepDropout 13 | from fastNLP.modules.encoder.variational_rnn import VarLSTM 14 | from fastNLP import seq_len_to_mask 15 | from fastNLP.embeddings import Embedding 16 | 17 | 18 | def drop_input_independent(word_embeddings, dropout_emb): 19 | batch_size, seq_length, _ = word_embeddings.size() 20 | word_masks = word_embeddings.new(batch_size, seq_length).fill_(1 - dropout_emb) 21 | word_masks = torch.bernoulli(word_masks) 22 | word_masks = word_masks.unsqueeze(dim=2) 23 | word_embeddings = word_embeddings * word_masks 24 | 25 | return word_embeddings 26 | 27 | 28 | class CharBiaffineParser(BiaffineParser): 29 | def __init__(self, char_vocab_size, 30 | emb_dim, 31 | bigram_vocab_size, 32 | trigram_vocab_size, 33 | num_label, 34 | rnn_layers=3, 35 | rnn_hidden_size=800, #单向的数量 36 | arc_mlp_size=500, 37 | label_mlp_size=100, 38 | dropout=0.3, 39 | encoder='lstm', 40 | use_greedy_infer=False, 41 | app_index = 0, 42 | pre_chars_embed=None, 43 | pre_bigrams_embed=None, 44 | pre_trigrams_embed=None): 45 | 46 | 47 | super(BiaffineParser, self).__init__() 48 | rnn_out_size = 2 * rnn_hidden_size 49 | self.char_embed = Embedding((char_vocab_size, emb_dim)) 50 | self.bigram_embed = Embedding((bigram_vocab_size, emb_dim)) 51 | self.trigram_embed = Embedding((trigram_vocab_size, emb_dim)) 52 | if pre_chars_embed: 53 | self.pre_char_embed = Embedding(pre_chars_embed) 54 | self.pre_char_embed.requires_grad = False 55 | if pre_bigrams_embed: 56 | self.pre_bigram_embed = Embedding(pre_bigrams_embed) 57 | self.pre_bigram_embed.requires_grad = False 58 | if pre_trigrams_embed: 59 | self.pre_trigram_embed = 
Embedding(pre_trigrams_embed) 60 | self.pre_trigram_embed.requires_grad = False 61 | self.timestep_drop = TimestepDropout(dropout) 62 | self.encoder_name = encoder 63 | 64 | if encoder == 'var-lstm': 65 | self.encoder = VarLSTM(input_size=emb_dim*3, 66 | hidden_size=rnn_hidden_size, 67 | num_layers=rnn_layers, 68 | bias=True, 69 | batch_first=True, 70 | input_dropout=dropout, 71 | hidden_dropout=dropout, 72 | bidirectional=True) 73 | elif encoder == 'lstm': 74 | self.encoder = nn.LSTM(input_size=emb_dim*3, 75 | hidden_size=rnn_hidden_size, 76 | num_layers=rnn_layers, 77 | bias=True, 78 | batch_first=True, 79 | dropout=dropout, 80 | bidirectional=True) 81 | 82 | else: 83 | raise ValueError('unsupported encoder type: {}'.format(encoder)) 84 | 85 | self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2), 86 | nn.LeakyReLU(0.1), 87 | TimestepDropout(p=dropout),) 88 | self.arc_mlp_size = arc_mlp_size 89 | self.label_mlp_size = label_mlp_size 90 | self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) 91 | self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) 92 | self.use_greedy_infer = use_greedy_infer 93 | self.reset_parameters() 94 | self.dropout = dropout 95 | 96 | self.app_index = app_index 97 | self.num_label = num_label 98 | if self.app_index != 0: 99 | raise ValueError("现在app_index必须等于0") 100 | 101 | def reset_parameters(self): 102 | for name, m in self.named_modules(): 103 | if 'embed' in name: 104 | pass 105 | elif hasattr(m, 'reset_parameters') or hasattr(m, 'init_param'): 106 | pass 107 | else: 108 | for p in m.parameters(): 109 | if len(p.size())>1: 110 | nn.init.xavier_normal_(p, gain=0.1) 111 | else: 112 | nn.init.uniform_(p, -0.1, 0.1) 113 | 114 | def forward(self, chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=None, pre_bigrams=None, 115 | pre_trigrams=None): 116 | """ 117 | max_len是包含root的 118 | :param chars: batch_size x max_len 119 | :param ngrams: batch_size x max_len*ngram_per_char 120 | :param seq_lens: batch_size 121 | :param gold_heads: batch_size x max_len 122 | :param pre_chars: batch_size x max_len 123 | :param pre_ngrams: batch_size x max_len*ngram_per_char 124 | :return dict: parsing results 125 | arc_pred: [batch_size, seq_len, seq_len] 126 | label_pred: [batch_size, seq_len, seq_len] 127 | mask: [batch_size, seq_len] 128 | head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads 129 | """ 130 | # prepare embeddings 131 | batch_size, seq_len = chars.shape 132 | # print('forward {} {}'.format(batch_size, seq_len)) 133 | 134 | # get sequence mask 135 | mask = seq_len_to_mask(seq_lens).long() 136 | 137 | chars = self.char_embed(chars) # [N,L] -> [N,L,C_0] 138 | bigrams = self.bigram_embed(bigrams) # [N,L] -> [N,L,C_1] 139 | trigrams = self.trigram_embed(trigrams) 140 | 141 | if pre_chars is not None: 142 | pre_chars = self.pre_char_embed(pre_chars) 143 | # pre_chars = self.pre_char_fc(pre_chars) 144 | chars = pre_chars + chars 145 | if pre_bigrams is not None: 146 | pre_bigrams = self.pre_bigram_embed(pre_bigrams) 147 | # pre_bigrams = self.pre_bigram_fc(pre_bigrams) 148 | bigrams = bigrams + pre_bigrams 149 | if pre_trigrams is not None: 150 | pre_trigrams = self.pre_trigram_embed(pre_trigrams) 151 | # pre_trigrams = self.pre_trigram_fc(pre_trigrams) 152 | trigrams = trigrams + pre_trigrams 153 | 154 | x = torch.cat([chars, bigrams, trigrams], dim=2) # -> [N,L,C] 155 | 156 | # encoder, extract features 157 | if self.training: 158 | x = 
drop_input_independent(x, self.dropout) 159 | sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) 160 | x = x[sort_idx] 161 | x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) 162 | feat, _ = self.encoder(x) # -> [N,L,C] 163 | feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) 164 | _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) 165 | feat = feat[unsort_idx] 166 | feat = self.timestep_drop(feat) 167 | 168 | # for arc biaffine 169 | # mlp, reduce dim 170 | feat = self.mlp(feat) 171 | arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size 172 | arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz] 173 | label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:] 174 | 175 | # biaffine arc classifier 176 | arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] 177 | 178 | # use gold or predicted arc to predict label 179 | if gold_heads is None or not self.training: 180 | # use greedy decoding in training 181 | if self.training or self.use_greedy_infer: 182 | heads = self.greedy_decoder(arc_pred, mask) 183 | else: 184 | heads = self.mst_decoder(arc_pred, mask) 185 | head_pred = heads 186 | else: 187 | assert self.training # must be training mode 188 | if gold_heads is None: 189 | heads = self.greedy_decoder(arc_pred, mask) 190 | head_pred = heads 191 | else: 192 | head_pred = None 193 | heads = gold_heads 194 | # heads: batch_size x max_len 195 | 196 | batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=chars.device).unsqueeze(1) 197 | label_head = label_head[batch_range, heads].contiguous() 198 | label_pred = self.label_predictor(label_head, label_dep) # [N, max_len, num_label] 199 | # 这里限制一下,只有当head为下一个时,才能预测app这个label 200 | arange_index = torch.arange(1, seq_len+1, dtype=torch.long, device=chars.device).unsqueeze(0)\ 201 | .repeat(batch_size, 1) # batch_size x max_len 202 | app_masks = heads.ne(arange_index) # batch_size x max_len, 为1的位置不可以预测app 203 | app_masks = app_masks.unsqueeze(2).repeat(1, 1, self.num_label) 204 | app_masks[:, :, 1:] = 0 205 | label_pred = label_pred.masked_fill(app_masks, -np.inf) 206 | 207 | res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask} 208 | if head_pred is not None: 209 | res_dict['head_pred'] = head_pred 210 | return res_dict 211 | 212 | @staticmethod 213 | def loss(arc_pred, label_pred, arc_true, label_true, mask): 214 | """ 215 | Compute loss. 
216 | 217 | :param arc_pred: [batch_size, seq_len, seq_len] 218 | :param label_pred: [batch_size, seq_len, n_tags] 219 | :param arc_true: [batch_size, seq_len] 220 | :param label_true: [batch_size, seq_len] 221 | :param mask: [batch_size, seq_len] 222 | :return: loss value 223 | """ 224 | 225 | batch_size, seq_len, _ = arc_pred.shape 226 | flip_mask = (mask == 0) 227 | # _arc_pred = arc_pred.clone() 228 | _arc_pred = arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf')) 229 | 230 | arc_true.data[:, 0].fill_(-1) 231 | label_true.data[:, 0].fill_(-1) 232 | 233 | arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1) 234 | label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1) 235 | 236 | return arc_nll + label_nll 237 | 238 | def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars, pre_bigrams, pre_trigrams): 239 | """ 240 | 241 | max_len是包含root的 242 | 243 | :param chars: batch_size x max_len 244 | :param ngrams: batch_size x max_len*ngram_per_char 245 | :param seq_lens: batch_size 246 | :param pre_chars: batch_size x max_len 247 | :param pre_ngrams: batch_size x max_len*ngram_per_cha 248 | :return: 249 | """ 250 | res = self(chars, bigrams, trigrams, seq_lens, pre_chars=pre_chars, pre_bigrams=pre_bigrams, 251 | pre_trigrams=pre_trigrams, gold_heads=None) 252 | output = {} 253 | output['arc_pred'] = res.pop('head_pred') 254 | _, label_pred = res.pop('label_pred').max(2) 255 | output['label_pred'] = label_pred 256 | return output 257 | 258 | class CharParser(nn.Module): 259 | def __init__(self, char_vocab_size, 260 | emb_dim, 261 | bigram_vocab_size, 262 | trigram_vocab_size, 263 | num_label, 264 | rnn_layers=3, 265 | rnn_hidden_size=400, #单向的数量 266 | arc_mlp_size=500, 267 | label_mlp_size=100, 268 | dropout=0.3, 269 | encoder='var-lstm', 270 | use_greedy_infer=False, 271 | app_index = 0, 272 | pre_chars_embed=None, 273 | pre_bigrams_embed=None, 274 | pre_trigrams_embed=None): 275 | super().__init__() 276 | 277 | self.parser = CharBiaffineParser(char_vocab_size, 278 | emb_dim, 279 | bigram_vocab_size, 280 | trigram_vocab_size, 281 | num_label, 282 | rnn_layers, 283 | rnn_hidden_size, #单向的数量 284 | arc_mlp_size, 285 | label_mlp_size, 286 | dropout, 287 | encoder, 288 | use_greedy_infer, 289 | app_index, 290 | pre_chars_embed=pre_chars_embed, 291 | pre_bigrams_embed=pre_bigrams_embed, 292 | pre_trigrams_embed=pre_trigrams_embed) 293 | 294 | def forward(self, chars, bigrams, trigrams, seq_lens, char_heads, char_labels, pre_chars=None, pre_bigrams=None, 295 | pre_trigrams=None): 296 | res_dict = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=char_heads, pre_chars=pre_chars, 297 | pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams) 298 | arc_pred = res_dict['arc_pred'] 299 | label_pred = res_dict['label_pred'] 300 | masks = res_dict['mask'] 301 | loss = self.parser.loss(arc_pred, label_pred, char_heads, char_labels, masks) 302 | return {'loss': loss} 303 | 304 | def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars=None, pre_bigrams=None, pre_trigrams=None): 305 | res = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=pre_chars, 306 | pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams) 307 | output = {} 308 | output['head_preds'] = res.pop('head_pred') 309 | _, label_pred = res.pop('label_pred').max(2) 310 | output['label_preds'] = label_pred 311 | return output 312 | -------------------------------------------------------------------------------- 
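Usage note (not part of the original source): `CharParser` is normally driven by fastNLP's `Trainer` in `train.py`. The commented sketch below summarizes, under that assumption, how the model is constructed there and which dictionary keys its `forward` and `predict` return.

```python
# Construction mirrors train.py; the vocabularies come from CTBxJointPipe in modules/pipe.py.
# model = CharParser(char_vocab_size=len(chars_vocab), emb_dim=100,
#                    bigram_vocab_size=len(bigrams_vocab), trigram_vocab_size=len(trigrams_vocab),
#                    num_label=len(char_labels_vocab), rnn_hidden_size=400, encoder='var-lstm',
#                    app_index=char_labels_vocab['APP'])
#
# Training step (all index tensors are batch_size x max_len, where position 0 is the root):
#   model(chars, bigrams, trigrams, seq_lens, char_heads, char_labels)  -> {'loss': ...}
# Inference:
#   model.predict(chars, bigrams, trigrams, seq_lens)  -> {'head_preds': ..., 'label_preds': ...}
```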
/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fastnlp/JointCwsParser/20f1121f3c19f359fdb6c34b503748c46c858a56/models/__init__.py -------------------------------------------------------------------------------- /models/callbacks.py: -------------------------------------------------------------------------------- 1 | 2 | from fastNLP.core.callback import Callback 3 | import torch 4 | from torch import nn 5 | 6 | class OptimizerCallback(Callback): 7 | def __init__(self, optimizer, scheduler, update_every=4): 8 | super().__init__() 9 | 10 | self._optimizer = optimizer 11 | self.scheduler = scheduler 12 | self._update_every = update_every 13 | 14 | def on_backward_end(self): 15 | if self.step % self._update_every==0: 16 | # nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5) 17 | # self._optimizer.step() 18 | self.scheduler.step() 19 | # self.model.zero_grad() 20 | 21 | 22 | class DevCallback(Callback): 23 | def __init__(self, tester, metric_key='u_f1'): 24 | super().__init__() 25 | self.tester = tester 26 | setattr(tester, 'verbose', 0) 27 | 28 | self.metric_key = metric_key 29 | 30 | self.record_best = False 31 | self.best_eval_value = 0 32 | self.best_eval_res = None 33 | 34 | self.best_dev_res = None # 存取dev的表现 35 | 36 | def on_valid_begin(self): 37 | eval_res = self.tester.test() 38 | metric_name = self.tester.metrics[0].__class__.__name__ 39 | metric_value = eval_res[metric_name][self.metric_key] 40 | if metric_value>self.best_eval_value: 41 | self.best_eval_value = metric_value 42 | self.best_epoch = self.trainer.epoch 43 | self.record_best = True 44 | self.best_eval_res = eval_res 45 | self.test_eval_res = eval_res 46 | eval_str = "Epoch {}/{}. \n".format(self.trainer.epoch, self.n_epochs) + \ 47 | self.tester._format_eval_results(eval_res) 48 | self.pbar.write(eval_str) 49 | 50 | def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): 51 | if self.record_best: 52 | self.best_dev_res = eval_result 53 | self.record_best = False 54 | if is_better_eval: 55 | self.best_dev_res_on_dev = eval_result 56 | self.best_test_res_on_dev = self.test_eval_res 57 | self.dev_epoch = self.epoch 58 | 59 | def on_train_end(self): 60 | print("Got best test performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.best_epoch, 61 | self.tester._format_eval_results(self.best_eval_res), 62 | self.tester._format_eval_results(self.best_dev_res))) 63 | print("Got best dev performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.dev_epoch, 64 | self.tester._format_eval_results(self.best_test_res_on_dev), 65 | self.tester._format_eval_results(self.best_dev_res_on_dev))) 66 | 67 | 68 | from fastNLP import Callback, Tester, DataSet 69 | 70 | 71 | class EvaluateCallback(Callback): 72 | """ 73 | 通过使用该Callback可以使得Trainer在evaluate dev之外还可以evaluate其它数据集,比如测试集。每一次验证dev之前都会先验证EvaluateCallback 74 | 中的数据。 75 | """ 76 | 77 | def __init__(self, data=None, tester=None): 78 | """ 79 | :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用Trainer中的metric对数据进行验证。如果需要传入多个 80 | DataSet请通过dict的方式传入。 81 | :param ~fastNLP.Tester,Dict[~fastNLP.DataSet] tester: Tester对象, 通过使用Tester对象,可以使得验证的metric与Trainer中 82 | 的metric不一样。 83 | """ 84 | super().__init__() 85 | self.datasets = {} 86 | self.testers = {} 87 | self.best_test_metric_sofar = 0 88 | self.best_test_sofar = None 89 | self.best_test_epoch = 0 90 | self.best_dev_test = None 91 | self.best_dev_epoch = 0 92 | if tester is not None: 93 | if 
isinstance(tester, dict): 94 | for name, test in tester.items(): 95 | if not isinstance(test, Tester): 96 | raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.") 97 | self.testers['tester-' + name] = test 98 | if isinstance(tester, Tester): 99 | self.testers['tester-test'] = tester 100 | for tester in self.testers.values(): 101 | setattr(tester, 'verbose', 0) 102 | 103 | if isinstance(data, dict): 104 | for key, value in data.items(): 105 | assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." 106 | for key, value in data.items(): 107 | self.datasets['data-' + key] = value 108 | elif isinstance(data, DataSet): 109 | self.datasets['data-test'] = data 110 | elif data is not None: 111 | raise TypeError("data receives dict[DataSet] or DataSet object.") 112 | 113 | def on_train_begin(self): 114 | if len(self.datasets) > 0 and self.trainer.dev_data is None: 115 | raise RuntimeError("Trainer has no dev data, you cannot pass extra DataSet to do evaluation.") 116 | 117 | if len(self.datasets) > 0: 118 | for key, data in self.datasets.items(): 119 | tester = Tester(data=data, model=self.model, 120 | batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), 121 | metrics=self.trainer.metrics, verbose=0, 122 | use_tqdm=self.trainer.test_use_tqdm) 123 | self.testers[key] = tester 124 | 125 | def on_valid_end(self, eval_result, metric_key, optimizer, better_result): 126 | if len(self.testers) > 0: 127 | for idx, (key, tester) in enumerate(self.testers.items()): 128 | try: 129 | eval_result = tester.test() 130 | if idx == 0: 131 | indicator, indicator_val = _check_eval_results(eval_result) 132 | if indicator_val>self.best_test_metric_sofar: 133 | self.best_test_metric_sofar = indicator_val 134 | self.best_test_epoch = self.epoch 135 | self.best_test_sofar = eval_result 136 | if better_result: 137 | self.best_dev_test = eval_result 138 | self.best_dev_epoch = self.epoch 139 | self.logger.info("EvaluateCallback evaluation on {}:".format(key)) 140 | self.logger.info(tester._format_eval_results(eval_result)) 141 | except Exception as e: 142 | self.logger.error("Exception happens when evaluate on DataSet named `{}`.".format(key)) 143 | raise e 144 | 145 | def on_train_end(self): 146 | if self.best_test_sofar: 147 | self.logger.info("Best test performance(may not correspond to the best dev performance):{} achieved at Epoch:{}.".format(self.best_test_sofar, self.best_test_epoch)) 148 | if self.best_dev_test: 149 | self.logger.info("Best test performance(correspond to the best dev performance):{} achieved at Epoch:{}.".format(self.best_dev_test, self.best_dev_epoch)) 150 | 151 | 152 | def _check_eval_results(metrics, metric_key=None): 153 | # metrics: tester返回的结果 154 | # metric_key: 一个用来做筛选的指标,来自Trainer的初始化 155 | if isinstance(metrics, tuple): 156 | loss, metrics = metrics 157 | 158 | if isinstance(metrics, dict): 159 | metric_dict = list(metrics.values())[0] # 取第一个metric 160 | 161 | if metric_key is None: 162 | indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] 163 | else: 164 | # metric_key is set 165 | if metric_key not in metric_dict: 166 | raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}") 167 | indicator_val = metric_dict[metric_key] 168 | indicator = metric_key 169 | else: 170 | raise RuntimeError("Invalid metrics type. 
Expect {}, got {}".format((tuple, dict), type(metrics))) 171 | return indicator, indicator_val 172 | -------------------------------------------------------------------------------- /models/metrics.py: -------------------------------------------------------------------------------- 1 | from fastNLP.core.metrics import MetricBase 2 | from fastNLP.core.utils import seq_len_to_mask 3 | import torch 4 | 5 | 6 | class SegAppCharParseF1Metric(MetricBase): 7 | # 8 | def __init__(self, app_index): 9 | super().__init__() 10 | self.app_index = app_index 11 | 12 | self.parse_head_tp = 0 13 | self.parse_label_tp = 0 14 | self.rec_tol = 0 15 | self.pre_tol = 0 16 | 17 | def evaluate(self, gold_word_pairs, gold_label_word_pairs, head_preds, label_preds, seq_lens, 18 | pun_masks): 19 | """ 20 | 21 | max_len是不包含root的character的长度 22 | :param gold_word_pairs: List[List[((head_start, head_end), (dep_start, dep_end)), ...]], batch_size 23 | :param gold_label_word_pairs: List[List[((head_start, head_end), label, (dep_start, dep_end)), ...]], batch_size 24 | :param head_preds: batch_size x max_len 25 | :param label_preds: batch_size x max_len 26 | :param seq_lens: 27 | :param pun_masks: batch_size x 28 | :return: 29 | """ 30 | # 去掉root 31 | head_preds = head_preds[:, 1:].tolist() 32 | label_preds = label_preds[:, 1:].tolist() 33 | seq_lens = (seq_lens - 1).tolist() 34 | 35 | # 先解码出words,POS,heads, labels, 对应的character范围 36 | for b in range(len(head_preds)): 37 | seq_len = seq_lens[b] 38 | head_pred = head_preds[b][:seq_len] 39 | label_pred = label_preds[b][:seq_len] 40 | 41 | words = [] # 存放[word_start, word_end),相对起始位置,不考虑root 42 | heads = [] 43 | labels = [] 44 | ranges = [] # 对应该char是第几个word,长度是seq_len+1 45 | word_idx = 0 46 | word_start_idx = 0 47 | for idx, (label, head) in enumerate(zip(label_pred, head_pred)): 48 | ranges.append(word_idx) 49 | if label == self.app_index: 50 | pass 51 | else: 52 | labels.append(label) 53 | heads.append(head) 54 | words.append((word_start_idx, idx+1)) 55 | word_start_idx = idx+1 56 | word_idx += 1 57 | 58 | head_dep_tuple = [] # head在前面 59 | head_label_dep_tuple = [] 60 | for idx, head in enumerate(heads): 61 | span = words[idx] 62 | if span[0]==span[1]-1 and pun_masks[b, span[0]]: 63 | continue # exclude punctuations 64 | if head == 0: 65 | head_dep_tuple.append((('root', words[idx]))) 66 | head_label_dep_tuple.append(('root', labels[idx], words[idx])) 67 | else: 68 | head_word_idx = ranges[head-1] 69 | head_word_span = words[head_word_idx] 70 | head_dep_tuple.append(((head_word_span, words[idx]))) 71 | head_label_dep_tuple.append((head_word_span, labels[idx], words[idx])) 72 | gold_head_dep_tuple = set([(tuple(pair[0]) if not isinstance(pair[0], str) else pair[0], 73 | tuple(pair[1]) if not isinstance(pair[1], str) else pair[1]) for pair in gold_word_pairs[b]]) 74 | gold_head_label_dep_tuple = set([(tuple(pair[0]) if not isinstance(pair[0], str) else pair[0], 75 | pair[1], 76 | tuple(pair[2]) if not isinstance(pair[2], str) else pair[2]) for pair in gold_label_word_pairs[b]]) 77 | 78 | for head_dep, head_label_dep in zip(head_dep_tuple, head_label_dep_tuple): 79 | if head_dep in gold_head_dep_tuple: 80 | self.parse_head_tp += 1 81 | if head_label_dep in gold_head_label_dep_tuple: 82 | self.parse_label_tp += 1 83 | self.pre_tol += len(head_dep_tuple) 84 | self.rec_tol += len(gold_head_dep_tuple) 85 | 86 | def get_metric(self, reset=True): 87 | u_p = self.parse_head_tp / self.pre_tol 88 | u_r = self.parse_head_tp / self.rec_tol 89 | u_f = 2*u_p*u_r/(1e-6 + u_p + u_r) 90 | 
l_p = self.parse_label_tp / self.pre_tol 91 | l_r = self.parse_label_tp / self.rec_tol 92 | l_f = 2*l_p*l_r/(1e-6 + l_p + l_r) 93 | 94 | if reset: 95 | self.parse_head_tp = 0 96 | self.parse_label_tp = 0 97 | self.rec_tol = 0 98 | self.pre_tol = 0 99 | 100 | return {'u_f1': round(u_f, 4), 'u_p': round(u_p, 4), 'u_r/uas':round(u_r, 4), 101 | 'l_f1': round(l_f, 4), 'l_p': round(l_p, 4), 'l_r/las': round(l_r, 4)} 102 | 103 | 104 | class CWSMetric(MetricBase): 105 | def __init__(self, app_index): 106 | super().__init__() 107 | self.app_index = app_index 108 | self.pre = 0 109 | self.rec = 0 110 | self.tp = 0 111 | 112 | def evaluate(self, seg_targets, seg_masks, label_preds, seq_lens): 113 | """ 114 | 115 | :param seg_targets: batch_size x max_len, 每个位置预测的是该word的长度-1,在word结束的地方。 116 | :param seg_masks: batch_size x max_len,只有在word结束的地方为1 117 | :param label_preds: batch_size x max_len 118 | :param seq_lens: batch_size 119 | :return: 120 | """ 121 | 122 | pred_masks = torch.zeros_like(seg_masks) 123 | pred_segs = torch.zeros_like(seg_targets) 124 | 125 | seq_lens = (seq_lens - 1).tolist() 126 | for idx, label_pred in enumerate(label_preds[:, 1:].tolist()): 127 | seq_len = seq_lens[idx] 128 | label_pred = label_pred[:seq_len] 129 | word_len = 0 130 | for l_i, label in enumerate(label_pred): 131 | if label==self.app_index and l_i!=len(label_pred)-1: 132 | word_len += 1 133 | else: 134 | pred_segs[idx, l_i] = word_len # 这个词的长度为word_len 135 | pred_masks[idx, l_i] = 1 136 | word_len = 0 137 | 138 | right_mask = seg_targets.eq(pred_segs) # 对长度的预测一致 139 | self.rec += seg_masks.sum().item() 140 | self.pre += pred_masks.sum().item() 141 | # 且pred和target在同一个地方有值 142 | self.tp += (right_mask.__and__(pred_masks.bool().__and__(seg_masks.bool()))).sum().item() 143 | 144 | def get_metric(self, reset=True): 145 | res = {} 146 | res['rec'] = round(self.tp/(self.rec+1e-6), 4) 147 | res['pre'] = round(self.tp/(self.pre+1e-6), 4) 148 | res['f1'] = round(2*res['rec']*res['pre']/(res['pre'] + res['rec'] + 1e-6), 4) 149 | 150 | if reset: 151 | self.pre = 0 152 | self.rec = 0 153 | self.tp = 0 154 | 155 | return res 156 | 157 | 158 | class ParserMetric(MetricBase): 159 | def __init__(self, ): 160 | super().__init__() 161 | self.num_arc = 0 162 | self.num_label = 0 163 | self.num_sample = 0 164 | 165 | def get_metric(self, reset=True): 166 | res = {'UAS': round(self.num_arc*1.0 / self.num_sample, 4), 167 | 'LAS': round(self.num_label*1.0 / self.num_sample, 4)} 168 | if reset: 169 | self.num_sample = self.num_label = self.num_arc = 0 170 | return res 171 | 172 | def evaluate(self, head_preds, label_preds, heads, labels, seq_lens=None): 173 | """Evaluate the performance of prediction. 
174 | """ 175 | if seq_lens is None: 176 | seq_mask = head_preds.new_ones(head_preds.size(), dtype=torch.bool) 177 | else: 178 | seq_mask = seq_len_to_mask(seq_lens.long()) 179 | # mask out tag 180 | seq_mask[:, 0] = 0 181 | head_pred_correct = (head_preds == heads).__and__(seq_mask) 182 | label_pred_correct = (label_preds == labels).__and__(head_pred_correct) 183 | self.num_arc += head_pred_correct.float().sum().item() 184 | self.num_label += label_pred_correct.float().sum().item() 185 | self.num_sample += seq_mask.sum().item() 186 | 187 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fastnlp/JointCwsParser/20f1121f3c19f359fdb6c34b503748c46c858a56/modules/__init__.py -------------------------------------------------------------------------------- /modules/pipe.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from fastNLP.io.data_bundle import DataBundle 4 | from fastNLP.io import ConllLoader, Pipe 5 | import numpy as np 6 | 7 | from itertools import chain 8 | from fastNLP import DataSet, Vocabulary 9 | from functools import partial 10 | from fastNLP.io.utils import check_loader_paths 11 | 12 | class CTBxJointPipe(Pipe): 13 | """ 14 | 文件夹下应该具有以下的文件结构 15 | -train.conllx 16 | -dev.conllx 17 | -test.conllx 18 | 每个文件中的内容如下(空格隔开不同的句子, 共有) 19 | 1 费孝通 _ NR NR _ 3 nsubjpass _ _ 20 | 2 被 _ SB SB _ 3 pass _ _ 21 | 3 授予 _ VV VV _ 0 root _ _ 22 | 4 麦格赛赛 _ NR NR _ 5 nn _ _ 23 | 5 奖 _ NN NN _ 3 dobj _ _ 24 | 25 | 1 新华社 _ NR NR _ 7 dep _ _ 26 | 2 马尼拉 _ NR NR _ 7 dep _ _ 27 | 3 8月 _ NT NT _ 7 dep _ _ 28 | 4 31日 _ NT NT _ 7 dep _ _ 29 | ... 30 | 31 | """ 32 | def __init__(self): 33 | self._loader = ConllLoader(headers=['words', 'pos_tags', 'heads', 'labels'], indexes=[1, 3, 6, 7]) 34 | 35 | def load(self, path:str): 36 | """ 37 | 给定一个文件路径,将数据读取为DataSet格式。DataSet中包含以下的内容 38 | words: list[str] 39 | pos_tags: list[str] 40 | heads: list[int] 41 | labels: list[str] 42 | 43 | :param path: 44 | :return: 45 | """ 46 | dataset = self._loader._load(path) 47 | dataset.heads.int() 48 | return dataset 49 | 50 | def process_from_file(self, paths): 51 | """ 52 | 53 | :param paths: 54 | :return: 55 | Dataset包含以下的field 56 | chars: 57 | bigrams: 58 | trigrams: 59 | pre_chars: 60 | pre_bigrams: 61 | pre_trigrams: 62 | seg_targets: 63 | seg_masks: 64 | seq_lens: 65 | char_labels: 66 | char_heads: 67 | gold_word_pairs: 68 | seg_targets: 69 | seg_masks: 70 | char_labels: 71 | char_heads: 72 | pun_masks: 73 | gold_label_word_pairs: 74 | """ 75 | paths = check_loader_paths(paths) 76 | data = DataBundle() 77 | 78 | for name, path in paths.items(): 79 | dataset = self.load(path) 80 | data.datasets[name] = dataset 81 | 82 | char_labels_vocab = Vocabulary(padding=None, unknown=None) 83 | 84 | def process(dataset, char_label_vocab): 85 | dataset.apply(add_word_lst, new_field_name='word_lst') 86 | dataset.apply(lambda x: list(chain(*x['word_lst'])), new_field_name='chars') 87 | dataset.apply(add_bigram, field_name='chars', new_field_name='bigrams') 88 | dataset.apply(add_trigram, field_name='chars', new_field_name='trigrams') 89 | dataset.apply(add_char_heads, new_field_name='char_heads') 90 | dataset.apply(add_char_labels, new_field_name='char_labels') 91 | dataset.apply(add_segs, new_field_name='seg_targets') 92 | dataset.apply(add_mask, new_field_name='seg_masks') 93 | dataset.add_seq_len('chars', 
new_field_name='seq_lens') 94 | dataset.apply(add_pun_masks, new_field_name='pun_masks') 95 | if len(char_label_vocab.word_count)==0: 96 | char_label_vocab.from_dataset(dataset, field_name='char_labels') 97 | char_label_vocab.index_dataset(dataset, field_name='char_labels') 98 | new_dataset = add_root(dataset) 99 | new_dataset.apply(add_word_pairs, new_field_name='gold_word_pairs', ignore_type=True) 100 | global add_label_word_pairs 101 | add_label_word_pairs = partial(add_label_word_pairs, label_vocab=char_label_vocab) 102 | new_dataset.apply(add_label_word_pairs, new_field_name='gold_label_word_pairs', ignore_type=True) 103 | 104 | new_dataset.set_pad_val('char_labels', -1) 105 | new_dataset.set_pad_val('char_heads', -1) 106 | 107 | return new_dataset 108 | 109 | for name in list(paths.keys()): 110 | dataset = data.datasets[name] 111 | dataset = process(dataset, char_labels_vocab) 112 | data.datasets[name] = dataset 113 | 114 | data.vocabs['char_labels'] = char_labels_vocab 115 | 116 | char_vocab = Vocabulary(min_freq=2).from_dataset(data.datasets['train'], field_name='chars', 117 | no_create_entry_dataset=[data.get_dataset('dev'), 118 | data.get_dataset('test')]) 119 | bigram_vocab = Vocabulary(min_freq=3).from_dataset(data.datasets['train'], field_name='bigrams', 120 | no_create_entry_dataset=[data.get_dataset('dev'), 121 | data.get_dataset('test')]) 122 | trigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='trigrams', 123 | no_create_entry_dataset=[data.get_dataset('dev'), 124 | data.get_dataset('test')]) 125 | 126 | for name in ['chars', 'bigrams', 'trigrams']: 127 | vocab = Vocabulary().from_dataset(field_name=name, no_create_entry_dataset=list(data.datasets.values())) 128 | vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name='pre_' + name) 129 | data.vocabs['pre_{}'.format(name)] = vocab 130 | 131 | for name, vocab in zip(['chars', 'bigrams', 'trigrams'], 132 | [char_vocab, bigram_vocab, trigram_vocab]): 133 | vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name=name) 134 | data.vocabs[name] = vocab 135 | 136 | for name, dataset in data.datasets.items(): 137 | dataset.set_input('chars', 'bigrams', 'trigrams', 'seq_lens', 'char_labels', 'char_heads', 'pre_chars', 138 | 'pre_bigrams', 'pre_trigrams') 139 | dataset.set_target('gold_word_pairs', 'seq_lens', 'seg_targets', 'seg_masks', 'char_labels', 140 | 'char_heads', 141 | 'pun_masks', 'gold_label_word_pairs') 142 | 143 | return data 144 | 145 | 146 | def add_label_word_pairs(instance, label_vocab): 147 | # List[List[((head_start, head_end], (dep_start, dep_end]), ...]] 148 | word_end_indexes = np.array(list(map(len, instance['word_lst']))) 149 | word_end_indexes = np.cumsum(word_end_indexes).tolist() 150 | word_end_indexes.insert(0, 0) 151 | word_pairs = [] 152 | labels = instance['labels'] 153 | pos_tags = instance['pos_tags'] 154 | for idx, head in enumerate(instance['heads']): 155 | if pos_tags[idx]=='PU': # 如果是标点符号,就不记录 156 | continue 157 | label = label_vocab.to_index(labels[idx]) 158 | if head==0: 159 | word_pairs.append((('root', label, (word_end_indexes[idx], word_end_indexes[idx+1])))) 160 | else: 161 | word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]), label, 162 | (word_end_indexes[idx], word_end_indexes[idx + 1]))) 163 | return word_pairs 164 | 165 | def add_word_pairs(instance): 166 | # List[List[((head_start, head_end], (dep_start, dep_end]), ...]] 167 | word_end_indexes = np.array(list(map(len, 
instance['word_lst']))) 168 | word_end_indexes = np.cumsum(word_end_indexes).tolist() 169 | word_end_indexes.insert(0, 0) 170 | word_pairs = [] 171 | pos_tags = instance['pos_tags'] 172 | for idx, head in enumerate(instance['heads']): 173 | if pos_tags[idx]=='PU': # 如果是标点符号,就不记录 174 | continue 175 | if head==0: 176 | word_pairs.append((('root', (word_end_indexes[idx], word_end_indexes[idx+1])))) 177 | else: 178 | word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]), 179 | (word_end_indexes[idx], word_end_indexes[idx + 1]))) 180 | return word_pairs 181 | 182 | def add_root(dataset): 183 | new_dataset = DataSet() 184 | for sample in dataset: 185 | chars = ['char_root'] + sample['chars'] 186 | bigrams = ['bigram_root'] + sample['bigrams'] 187 | trigrams = ['trigram_root'] + sample['trigrams'] 188 | seq_lens = sample['seq_lens']+1 189 | char_labels = [0] + sample['char_labels'] 190 | char_heads = [0] + sample['char_heads'] 191 | sample['chars'] = chars 192 | sample['bigrams'] = bigrams 193 | sample['trigrams'] = trigrams 194 | sample['seq_lens'] = seq_lens 195 | sample['char_labels'] = char_labels 196 | sample['char_heads'] = char_heads 197 | new_dataset.append(sample) 198 | return new_dataset 199 | 200 | def add_pun_masks(instance): 201 | tags = instance['pos_tags'] 202 | pun_masks = [] 203 | for word, tag in zip(instance['words'], tags): 204 | if tag=='PU': 205 | pun_masks.extend([1]*len(word)) 206 | else: 207 | pun_masks.extend([0]*len(word)) 208 | return pun_masks 209 | 210 | def add_word_lst(instance): 211 | words = instance['words'] 212 | word_lst = [list(word) for word in words] 213 | return word_lst 214 | 215 | def add_bigram(instance): 216 | chars = instance['chars'] 217 | length = len(chars) 218 | chars = chars + [''] 219 | bigrams = [] 220 | for i in range(length): 221 | bigrams.append(''.join(chars[i:i + 2])) 222 | return bigrams 223 | 224 | def add_trigram(instance): 225 | chars = instance['chars'] 226 | length = len(chars) 227 | chars = chars + [''] * 2 228 | trigrams = [] 229 | for i in range(length): 230 | trigrams.append(''.join(chars[i:i + 3])) 231 | return trigrams 232 | 233 | def add_char_heads(instance): 234 | words = instance['word_lst'] 235 | heads = instance['heads'] 236 | char_heads = [] 237 | char_index = 1 # 因此存在root节点所以需要从1开始 238 | head_end_indexes = np.cumsum(list(map(len, words))).tolist() + [0] # 因为root是0,0-1=-1 239 | for word, head in zip(words, heads): 240 | char_head = [] 241 | if len(word)>1: 242 | char_head.append(char_index+1) 243 | char_index += 1 244 | for _ in range(len(word)-2): 245 | char_index += 1 246 | char_head.append(char_index) 247 | char_index += 1 248 | char_head.append(head_end_indexes[head-1]) 249 | char_heads.extend(char_head) 250 | return char_heads 251 | 252 | def add_char_labels(instance): 253 | """ 254 | 将word_lst中的数据按照下面的方式设置label 255 | 比如"复旦大学 位于 ", 对应的分词是"B M M E B E", 则对应的dependency是"复(dep)->旦(head)", "旦(dep)->大(head)".. 
256 | 对应的label是'app', 'app', 'app', , 而学的label就是复旦大学这个词的dependency label 257 | :param instance: 258 | :return: 259 | """ 260 | words = instance['word_lst'] 261 | labels = instance['labels'] 262 | char_labels = [] 263 | for word, label in zip(words, labels): 264 | for _ in range(len(word)-1): 265 | char_labels.append('APP') 266 | char_labels.append(label) 267 | return char_labels 268 | 269 | # add seg_targets 270 | def add_segs(instance): 271 | words = instance['word_lst'] 272 | segs = [0]*len(instance['chars']) 273 | index = 0 274 | for word in words: 275 | index = index + len(word) - 1 276 | segs[index] = len(word)-1 277 | index = index + 1 278 | return segs 279 | 280 | # add target_masks 281 | def add_mask(instance): 282 | words = instance['word_lst'] 283 | mask = [] 284 | for word in words: 285 | mask.extend([0] * (len(word) - 1)) 286 | mask.append(1) 287 | return mask 288 | 289 | 290 | 291 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastNLP=0.5 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from modules.pipe import CTBxJointPipe 2 | from fastNLP.embeddings.static_embedding import StaticEmbedding 3 | from torch import nn 4 | from functools import partial 5 | from models.CharParser import CharParser 6 | from models.metrics import SegAppCharParseF1Metric, CWSMetric 7 | from fastNLP import BucketSampler, Trainer 8 | from torch import optim 9 | from fastNLP import GradientClipCallback 10 | from fastNLP import cache_results 11 | import argparse 12 | from models.callbacks import EvaluateCallback 13 | 14 | 15 | uniform_init = partial(nn.init.normal_, std=0.02) 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--dataset', type=str, choices=['ctb5', 'ctb7', 'ctb9'], default='ctb5') 19 | args = parser.parse_args() 20 | 21 | data_name = args.dataset 22 | ###################################################hyper 23 | lr = 0.002 # 0.01~0.001 24 | dropout = 0.33 # 0.3~0.6 25 | arc_mlp_size = 500 # 200, 300 26 | rnn_hidden_size = 400 # 200, 300, 400 27 | rnn_layers = 3 # 2, 3 28 | encoder = 'var-lstm' # var-lstm, lstm 29 | batch_size = 128 30 | update_every = 1 31 | n_epochs = 100 32 | 33 | weight_decay = 0 # 1e-5, 1e-6, 0 34 | emb_size = 100 # 64 , 100 35 | label_mlp_size = 100 36 | ####################################################hyper 37 | data_folder = f'../data/{data_name}' # 填写在数据所在文件夹, 文件夹下应该有train, dev, test等三个文件 38 | device = 0 39 | 40 | @cache_results('caches/{}.pkl'.format(data_name), _refresh=False) 41 | def get_data(): 42 | data = CTBxJointPipe().process_from_file(data_folder) 43 | char_labels_vocab = data.vocabs['char_labels'] 44 | 45 | pre_chars_vocab = data.vocabs['pre_chars'] 46 | pre_bigrams_vocab = data.vocabs['pre_bigrams'] 47 | pre_trigrams_vocab = data.vocabs['pre_trigrams'] 48 | 49 | chars_vocab = data.vocabs['chars'] 50 | bigrams_vocab = data.vocabs['bigrams'] 51 | trigrams_vocab = data.vocabs['trigrams'] 52 | pre_chars_embed = StaticEmbedding(pre_chars_vocab, 53 | model_dir_or_name='cn-char-fastnlp-100d', 54 | init_method=uniform_init, normalize=False) 55 | pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data / pre_chars_embed.embedding.weight.data.std() 56 | pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab, 57 | model_dir_or_name='cn-bi-fastnlp-100d', 58 | 
init_method=uniform_init, normalize=False) 59 | pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data / pre_bigrams_embed.embedding.weight.data.std() 60 | pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab, 61 | model_dir_or_name='cn-tri-fastnlp-100d', 62 | init_method=uniform_init, normalize=False) 63 | pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data / pre_trigrams_embed.embedding.weight.data.std() 64 | 65 | return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data 66 | 67 | chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data() 68 | 69 | print(data) 70 | model = CharParser(char_vocab_size=len(chars_vocab), 71 | emb_dim=emb_size, 72 | bigram_vocab_size=len(bigrams_vocab), 73 | trigram_vocab_size=len(trigrams_vocab), 74 | num_label=len(char_labels_vocab), 75 | rnn_layers=rnn_layers, 76 | rnn_hidden_size=rnn_hidden_size, 77 | arc_mlp_size=arc_mlp_size, 78 | label_mlp_size=label_mlp_size, 79 | dropout=dropout, 80 | encoder=encoder, 81 | use_greedy_infer=False, 82 | app_index=char_labels_vocab['APP'], 83 | pre_chars_embed=pre_chars_embed, 84 | pre_bigrams_embed=pre_bigrams_embed, 85 | pre_trigrams_embed=pre_trigrams_embed) 86 | 87 | metric1 = SegAppCharParseF1Metric(char_labels_vocab['APP']) 88 | metric2 = CWSMetric(char_labels_vocab['APP']) 89 | metrics = [metric1, metric2] 90 | 91 | optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=lr, 92 | weight_decay=weight_decay, betas=[0.9, 0.9]) 93 | 94 | sampler = BucketSampler(seq_len_field_name='seq_lens') 95 | callbacks = [] 96 | 97 | from fastNLP.core.callback import Callback 98 | from torch.optim.lr_scheduler import LambdaLR 99 | class SchedulerCallback(Callback): 100 | def __init__(self, scheduler): 101 | super().__init__() 102 | self.scheduler = scheduler 103 | 104 | def on_step_end(self): 105 | if self.step % self.update_every==0: 106 | self.scheduler.step() 107 | 108 | scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000)) 109 | scheduler_callback = SchedulerCallback(scheduler) 110 | 111 | callbacks.append(scheduler_callback) 112 | callbacks.append(GradientClipCallback(clip_type='value', clip_value=5)) 113 | callbacks.append(EvaluateCallback(data.get_dataset('test'))) 114 | 115 | trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, 116 | print_every=3, 117 | validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer, 118 | check_code_level=0, metric_key='u_f1', sampler=sampler, num_workers=2, use_tqdm=True, 119 | device=device, callbacks=callbacks, update_every=update_every, dev_batch_size=256) 120 | trainer.train(load_best_model=False) 121 | -------------------------------------------------------------------------------- /train_bert.py: -------------------------------------------------------------------------------- 1 | from modules.pipe import CTBxJointPipe 2 | from fastNLP.embeddings import BertEmbedding 3 | from torch import nn 4 | from functools import partial 5 | from models.BertParser import BertParser 6 | from models.metrics import SegAppCharParseF1Metric, CWSMetric 7 | from fastNLP import BucketSampler, Trainer 8 | from torch import optim 9 | from fastNLP import GradientClipCallback, WarmupCallback 10 | from fastNLP import cache_results 11 | import argparse 12 | 
from models.callbacks import EvaluateCallback 13 | 14 | 15 | uniform_init = partial(nn.init.normal_, std=0.02) 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--dataset', type=str, choices=['ctb5', 'ctb7', 'ctb9'], default='ctb5') 19 | args = parser.parse_args() 20 | 21 | data_name = args.dataset 22 | ###################################################hyper 23 | # 需要变动的超参放到这里 24 | lr = 2e-5 # 0.01~0.001 25 | dropout = 0.5 # 0.3~0.6 26 | arc_mlp_size = 500 # 200, 300 27 | encoder = 'bert' 28 | batch_size = 6 29 | update_every = 1 30 | n_epochs = 5 31 | 32 | label_mlp_size = 100 33 | ####################################################hyper 34 | data_folder = f'../data/{data_name}' 35 | device = 0 36 | 37 | @cache_results('caches/{}_bert.pkl'.format(data_name), _refresh=False) 38 | def get_data(): 39 | data = CTBxJointPipe().process_from_file(data_folder) 40 | data.delete_field('bigrams') 41 | data.delete_field('trigrams') 42 | data.delete_field('chars') 43 | data.rename_field('pre_chars', 'chars') 44 | data.delete_field('pre_bigrams') 45 | data.delete_field('pre_trigrams') 46 | bert_embed = BertEmbedding(data.get_vocab('chars'), model_dir_or_name='cn', requires_grad=True) 47 | return data, bert_embed 48 | 49 | data, bert_embed = get_data() 50 | 51 | print(data) 52 | model = BertParser(embed=bert_embed, num_label=len(data.get_vocab('char_labels')), arc_mlp_size=arc_mlp_size, 53 | label_mlp_size=label_mlp_size, dropout=dropout, 54 | use_greedy_infer=False, 55 | app_index=0) 56 | 57 | metric1 = SegAppCharParseF1Metric(data.get_vocab('char_labels')['APP']) 58 | metric2 = CWSMetric(data.get_vocab('char_labels')['APP']) 59 | metrics = [metric1, metric2] 60 | 61 | optimizer = optim.AdamW([param for param in model.parameters() if param.requires_grad], lr=lr, 62 | weight_decay=1e-2) 63 | 64 | sampler = BucketSampler(seq_len_field_name='seq_lens') 65 | callbacks = [] 66 | 67 | warmup_callback = WarmupCallback(schedule='linear') 68 | 69 | callbacks.append(warmup_callback) 70 | callbacks.append(GradientClipCallback(clip_type='value', clip_value=5)) 71 | callbacks.append(EvaluateCallback(data.get_dataset('test'))) 72 | 73 | trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, 74 | print_every=3, 75 | validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer, 76 | check_code_level=0, metric_key='u_f1', sampler=sampler, num_workers=2, use_tqdm=True, 77 | device=device, callbacks=callbacks, update_every=update_every, dev_batch_size=6) 78 | trainer.train(load_best_model=False) 79 | --------------------------------------------------------------------------------