├── README.md ├── models ├── TENER.py └── __init__.py ├── modules ├── TransformerEmbedding.py ├── __init__.py ├── callbacks.py ├── pipe.py ├── relative_transformer.py ├── transformer.py └── utils.py ├── requirements.txt ├── train_elmo_en.py ├── train_tener_cn.py └── train_tener_en.py /README.md: -------------------------------------------------------------------------------- 1 | ## TENER: Adapting Transformer Encoder for Named Entity Recognition 2 | 3 | 4 | This is the code for the paper [TENER](https://arxiv.org/abs/1911.04474). 5 | 6 | TENER (Transformer Encoder for Named Entity Recognition) is a Transformer-based model which 7 | aims to tackle the NER task. Compared with the naive Transformer, we 8 | found that relative position embeddings are quite important for the NER task. Experiments 9 | on the English and Chinese NER datasets prove their effectiveness. 10 | 11 | #### Requirements 12 | This project needs the natural language processing Python package 13 | [fastNLP](https://github.com/fastnlp/fastNLP). You can install it with 14 | the following command: 15 | 16 | ```bash 17 | pip install fastNLP 18 | ``` 19 | 20 | #### Run the code 21 | 22 | (1) Prepare the English dataset. 23 | 24 | ##### Conll2003 25 | 26 | Your file should look like the following (the first token on a line 27 | is the word, the last token is the NER tag): 28 | 29 | ``` 30 | LONDON NNP B-NP B-LOC 31 | 1996-08-30 CD I-NP O 32 | 33 | West NNP B-NP B-MISC 34 | Indian NNP I-NP I-MISC 35 | all-rounder NN I-NP O 36 | Phil NNP I-NP B-PER 37 | 38 | ``` 39 | 40 | ##### OntoNotes 41 | 42 | We suggest using the following code to prepare your data: 43 | [OntoNotes-5.0-NER](https://github.com/yhcc/OntoNotes-5.0-NER). 44 | Alternatively, you can prepare the data in the Conll2003 style and then replace the 45 | OntoNotesNERPipe with Conll2003NERPipe in the code. 46 | 47 | For the English datasets, we use the pretrained GloVe 100d embedding. fastNLP will 48 | download it automatically. 49 | 50 | You can use the following commands to run the training (make sure you have changed the 51 | data path): 52 | 53 | ``` 54 | python train_tener_en.py --dataset conll2003 55 | ``` 56 | or 57 | ``` 58 | python train_tener_en.py --dataset en-ontonotes 59 | ``` 60 | 61 | Although we tried hard to make sure you can reproduce our results, 62 | the results may still disappoint you. This is usually because 63 | the best dev performance does not correlate well with the test performance. 64 | Several runs should be helpful. 65 | 66 | For the ELMo version, run the following (fastNLP will download the ELMo weights automatically; you just need 67 | to change the data path in train_elmo_en.py): 68 | 69 | ``` 70 | python train_elmo_en.py --dataset en-ontonotes 71 | ``` 72 | 73 | 74 | 75 | ##### MSRA, OntoNotes4.0, Weibo, Resume 76 | Your data should have only two columns: the first is the character, 77 | the second is the tag, like the following: 78 | ``` 79 | 口 O 80 | 腔 O 81 | 溃 O 82 | 疡 O 83 | 加 O 84 | 上 O 85 | ``` 86 | 87 | For the Chinese datasets, you can download the pretrained unigram and 88 | bigram embeddings from [Baidu Cloud](https://pan.baidu.com/s/1pLO6T9D#list/path=%2Fsharelink808087924-1080546002081577%2FNeuralSegmentation&parentPath=%2Fsharelink808087924-1080546002081577). 89 | Download 'gigaword_chn.all.a2b.uni.iter50.vec' and 'gigaword_chn.all.a2b.bi.iter50.vec'; a minimal loading sketch is shown below.
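These two files are consumed by fastNLP's `StaticEmbedding` in train_tener_cn.py. The sketch below mirrors that script; the `../data/...` paths and the ResumeNER file names are placeholder assumptions, so point them at wherever you saved your data and the downloaded `.vec` files.

```python
from fastNLP.embeddings import StaticEmbedding
from modules.pipe import CNNERPipe

# Build the DataBundle from the two-column (character, tag) files.
# The dataset paths are placeholders -- adjust them to your own data.
paths = {'train': '../data/ResumeNER/train.char.bmes',
         'dev': '../data/ResumeNER/dev.char.bmes',
         'test': '../data/ResumeNER/test.char.bmes'}
data_bundle = CNNERPipe(bigrams=True, encoding_type='bmeso').process_from_file(paths)

# Unigram (character) embedding: point model_dir_or_name at the downloaded unigram .vec file.
embed = StaticEmbedding(data_bundle.get_vocab('chars'),
                        model_dir_or_name='../data/gigaword_chn.all.a2b.uni.iter50.vec',
                        only_norm_found_vector=True, word_dropout=0.01, dropout=0.3)

# Bigram embedding: same idea for the bigram .vec file.
bi_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'),
                           model_dir_or_name='../data/gigaword_chn.all.a2b.bi.iter50.vec',
                           word_dropout=0.02, dropout=0.3,
                           only_norm_found_vector=True, only_train_min_freq=True)
```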
90 | Then replace the embedding paths in train_tener_cn.py with your local paths. 91 | 92 | You can run the code with the following command: 93 | 94 | ``` 95 | python train_tener_cn.py --dataset ontonotes 96 | ``` 97 | 98 | 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /models/TENER.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from fastNLP.modules import ConditionalRandomField, allowed_transitions 4 | from modules.transformer import TransformerEncoder 5 | 6 | from torch import nn 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | 11 | class TENER(nn.Module): 12 | def __init__(self, tag_vocab, embed, num_layers, d_model, n_head, feedforward_dim, dropout, 13 | after_norm=True, attn_type='adatrans', bi_embed=None, 14 | fc_dropout=0.3, pos_embed=None, scale=False, dropout_attn=None): 15 | """ 16 | 17 | :param tag_vocab: fastNLP Vocabulary 18 | :param embed: fastNLP TokenEmbedding 19 | :param num_layers: number of self-attention layers 20 | :param d_model: input size 21 | :param n_head: number of heads 22 | :param feedforward_dim: the dimension of the FFN 23 | :param dropout: dropout in self-attention 24 | :param after_norm: where to apply LayerNorm; True places it after the residual connection (post-norm) 25 | :param attn_type: 'adatrans' (relative attention) or 'transformer' (naive attention) 26 | :param pos_embed: type of absolute position embedding; supports 'sin', 'fix' or None. None is fine with relative attention 27 | :param bi_embed: used in the Chinese scenario (bigram embedding) 28 | :param fc_dropout: dropout rate before the fc layer 29 | """ 30 | super().__init__() 31 | 32 | self.embed = embed 33 | embed_size = self.embed.embed_size 34 | self.bi_embed = None 35 | if bi_embed is not None: 36 | self.bi_embed = bi_embed 37 | embed_size += self.bi_embed.embed_size 38 | 39 | self.in_fc = nn.Linear(embed_size, d_model) 40 | 41 | self.transformer = TransformerEncoder(num_layers, d_model, n_head, feedforward_dim, dropout, 42 | after_norm=after_norm, attn_type=attn_type, 43 | scale=scale, dropout_attn=dropout_attn, 44 | pos_embed=pos_embed) 45 | self.fc_dropout = nn.Dropout(fc_dropout) 46 | self.out_fc = nn.Linear(d_model, len(tag_vocab)) 47 | 48 | trans = allowed_transitions(tag_vocab, include_start_end=True) 49 | self.crf = ConditionalRandomField(len(tag_vocab), include_start_end_trans=True, allowed_transitions=trans) 50 | 51 | def _forward(self, chars, target, bigrams=None): 52 | mask = chars.ne(0) 53 | chars = self.embed(chars) 54 | if self.bi_embed is not None: 55 | bigrams = self.bi_embed(bigrams) 56 | chars = torch.cat([chars, bigrams], dim=-1) 57 | 58 | chars = self.in_fc(chars) 59 | chars = self.transformer(chars, mask) 60 | chars = self.fc_dropout(chars) 61 | chars = self.out_fc(chars) 62 | logits = F.log_softmax(chars, dim=-1) 63 | if target is None: 64 | paths, _ = self.crf.viterbi_decode(logits, mask) 65 | return {'pred': paths} 66 | else: 67 | loss = self.crf(logits, target, mask) 68 | return {'loss': loss} 69 | 70 | def forward(self, chars, target, bigrams=None): 71 | return self._forward(chars, target, bigrams) 72 | 73 | def predict(self, chars, bigrams=None): 74 | return self._forward(chars, target=None, bigrams=bigrams) 75 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fastnlp/TENER/d2614d509dffb9b30636e3523a2f8f0dc4876708/models/__init__.py -------------------------------------------------------------------------------- /modules/TransformerEmbedding.py: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | from fastNLP.embeddings import TokenEmbedding 4 | import torch 5 | from fastNLP import Vocabulary 6 | import torch.nn.functional as F 7 | from fastNLP import logger 8 | from fastNLP.embeddings.utils import _construct_char_vocab_from_vocab, get_embeddings 9 | from torch import nn 10 | from .transformer import TransformerEncoder 11 | 12 | 13 | class TransformerCharEmbed(TokenEmbedding): 14 | def __init__(self, vocab: Vocabulary, embed_size: int = 30, char_emb_size: int = 30, word_dropout: float = 0, 15 | dropout: float = 0, pool_method: str = 'max', activation='relu', 16 | min_char_freq: int = 2, requires_grad=True, include_word_start_end=True, 17 | char_attn_type='adatrans', char_n_head=3, char_dim_ffn=60, char_scale=False, char_pos_embed=None, 18 | char_dropout=0.15, char_after_norm=False): 19 | """ 20 | :param vocab: 词表 21 | :param embed_size: TransformerCharEmbed的输出维度。默认值为50. 22 | :param char_emb_size: character的embedding的维度。默认值为50. 同时也是Transformer的d_model大小 23 | :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 24 | :param dropout: 以多大概率drop character embedding的输出以及最终的word的输出。 25 | :param pool_method: 支持'max', 'avg'。 26 | :param activation: 激活函数,支持'relu', 'sigmoid', 'tanh', 或者自定义函数. 27 | :param min_char_freq: character的最小出现次数。默认值为2. 28 | :param requires_grad: 29 | :param include_word_start_end: 是否使用特殊的tag标记word的开始与结束 30 | :param char_attn_type: adatrans or naive. 31 | :param char_n_head: 多少个head 32 | :param char_dim_ffn: transformer中ffn中间层的大小 33 | :param char_scale: 是否使用scale 34 | :param char_pos_embed: None, 'fix', 'sin'. What kind of position embedding. When char_attn_type=relative, None is 35 | ok 36 | :param char_dropout: Dropout in Transformer encoder 37 | :param char_after_norm: the normalization place. 38 | """ 39 | super(TransformerCharEmbed, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) 40 | 41 | assert char_emb_size%char_n_head == 0, "d_model should divide n_head." 42 | 43 | assert pool_method in ('max', 'avg') 44 | self.pool_method = pool_method 45 | # activation function 46 | if isinstance(activation, str): 47 | if activation.lower() == 'relu': 48 | self.activation = F.relu 49 | elif activation.lower() == 'sigmoid': 50 | self.activation = F.sigmoid 51 | elif activation.lower() == 'tanh': 52 | self.activation = F.tanh 53 | elif activation is None: 54 | self.activation = lambda x: x 55 | elif callable(activation): 56 | self.activation = activation 57 | else: 58 | raise Exception( 59 | "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]") 60 | 61 | logger.info("Start constructing character vocabulary.") 62 | # 建立char的词表 63 | self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq, 64 | include_word_start_end=include_word_start_end) 65 | self.char_pad_index = self.char_vocab.padding_idx 66 | logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.") 67 | # 对vocab进行index 68 | max_word_len = max(map(lambda x: len(x[0]), vocab)) 69 | if include_word_start_end: 70 | max_word_len += 2 71 | self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len), 72 | fill_value=self.char_pad_index, dtype=torch.long)) 73 | self.register_buffer('word_lengths', torch.zeros(len(vocab)).long()) 74 | for word, index in vocab: 75 | # if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了. 
修改为不区分pad与否 76 | if include_word_start_end: 77 | word = [''] + list(word) + [''] 78 | self.words_to_chars_embedding[index, :len(word)] = \ 79 | torch.LongTensor([self.char_vocab.to_index(c) for c in word]) 80 | self.word_lengths[index] = len(word) 81 | 82 | self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size)) 83 | self.transformer = TransformerEncoder(1, char_emb_size, char_n_head, char_dim_ffn, dropout=char_dropout, after_norm=char_after_norm, 84 | attn_type=char_attn_type, pos_embed=char_pos_embed, scale=char_scale) 85 | self.fc = nn.Linear(char_emb_size, embed_size) 86 | 87 | self._embed_size = embed_size 88 | 89 | self.requires_grad = requires_grad 90 | 91 | def forward(self, words): 92 | """ 93 | 输入words的index后,生成对应的words的表示。 94 | 95 | :param words: [batch_size, max_len] 96 | :return: [batch_size, max_len, embed_size] 97 | """ 98 | words = self.drop_word(words) 99 | batch_size, max_len = words.size() 100 | chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len 101 | word_lengths = self.word_lengths[words] # batch_size x max_len 102 | max_word_len = word_lengths.max() 103 | chars = chars[:, :, :max_word_len] 104 | # 为mask的地方为1 105 | chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了 106 | char_embeds = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size 107 | char_embeds = self.dropout(char_embeds) 108 | reshaped_chars = char_embeds.reshape(batch_size * max_len, max_word_len, -1) 109 | 110 | trans_chars = self.transformer(reshaped_chars, chars_masks.eq(0).reshape(-1, max_word_len)) 111 | trans_chars = trans_chars.reshape(batch_size, max_len, max_word_len, -1) 112 | trans_chars = self.activation(trans_chars) 113 | if self.pool_method == 'max': 114 | trans_chars = trans_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf')) 115 | chars, _ = torch.max(trans_chars, dim=-2) # batch_size x max_len x H 116 | else: 117 | trans_chars = trans_chars.masked_fill(chars_masks.unsqueeze(-1), 0) 118 | chars = torch.sum(trans_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float() 119 | 120 | chars = self.fc(chars) 121 | 122 | return self.dropout(chars) -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fastnlp/TENER/d2614d509dffb9b30636e3523a2f8f0dc4876708/modules/__init__.py -------------------------------------------------------------------------------- /modules/callbacks.py: -------------------------------------------------------------------------------- 1 | 2 | from fastNLP import Callback, Tester, DataSet 3 | 4 | 5 | class EvaluateCallback(Callback): 6 | """ 7 | 通过使用该Callback可以使得Trainer在evaluate dev之外还可以evaluate其它数据集,比如测试集。每一次验证dev之前都会先验证EvaluateCallback 8 | 中的数据。 9 | """ 10 | 11 | def __init__(self, data=None, tester=None): 12 | """ 13 | :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用Trainer中的metric对数据进行验证。如果需要传入多个 14 | DataSet请通过dict的方式传入。 15 | :param ~fastNLP.Tester,Dict[~fastNLP.DataSet] tester: Tester对象, 通过使用Tester对象,可以使得验证的metric与Trainer中 16 | 的metric不一样。 17 | """ 18 | super().__init__() 19 | self.datasets = {} 20 | self.testers = {} 21 | self.best_test_metric_sofar = 0 22 | self.best_test_sofar = None 23 | self.best_test_epoch = 0 24 | self.best_dev_test = None 25 | self.best_dev_epoch = 0 26 | if tester is not None: 27 | if isinstance(tester, dict): 28 | for 
name, test in tester.items(): 29 | if not isinstance(test, Tester): 30 | raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.") 31 | self.testers['tester-' + name] = test 32 | if isinstance(tester, Tester): 33 | self.testers['tester-test'] = tester 34 | for tester in self.testers.values(): 35 | setattr(tester, 'verbose', 0) 36 | 37 | if isinstance(data, dict): 38 | for key, value in data.items(): 39 | assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." 40 | for key, value in data.items(): 41 | self.datasets['data-' + key] = value 42 | elif isinstance(data, DataSet): 43 | self.datasets['data-test'] = data 44 | elif data is not None: 45 | raise TypeError("data receives dict[DataSet] or DataSet object.") 46 | 47 | def on_train_begin(self): 48 | if len(self.datasets) > 0 and self.trainer.dev_data is None: 49 | raise RuntimeError("Trainer has no dev data, you cannot pass extra DataSet to do evaluation.") 50 | 51 | if len(self.datasets) > 0: 52 | for key, data in self.datasets.items(): 53 | tester = Tester(data=data, model=self.model, 54 | batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), 55 | metrics=self.trainer.metrics, verbose=0, 56 | use_tqdm=self.trainer.test_use_tqdm) 57 | self.testers[key] = tester 58 | 59 | def on_valid_end(self, eval_result, metric_key, optimizer, better_result): 60 | if len(self.testers) > 0: 61 | for idx, (key, tester) in enumerate(self.testers.items()): 62 | try: 63 | eval_result = tester.test() 64 | if idx == 0: 65 | indicator, indicator_val = _check_eval_results(eval_result) 66 | if indicator_val>self.best_test_metric_sofar: 67 | self.best_test_metric_sofar = indicator_val 68 | self.best_test_epoch = self.epoch 69 | self.best_test_sofar = eval_result 70 | if better_result: 71 | self.best_dev_test = eval_result 72 | self.best_dev_epoch = self.epoch 73 | self.logger.info("EvaluateCallback evaluation on {}:".format(key)) 74 | self.logger.info(tester._format_eval_results(eval_result)) 75 | except Exception as e: 76 | self.logger.error("Exception happens when evaluate on DataSet named `{}`.".format(key)) 77 | raise e 78 | 79 | def on_train_end(self): 80 | if self.best_test_sofar: 81 | self.logger.info("Best test performance(may not correspond to the best dev performance):{} achieved at Epoch:{}.".format(self.best_test_sofar, self.best_test_epoch)) 82 | if self.best_dev_test: 83 | self.logger.info("Best test performance(correspond to the best dev performance):{} achieved at Epoch:{}.".format(self.best_dev_test, self.best_dev_epoch)) 84 | 85 | 86 | def _check_eval_results(metrics, metric_key=None): 87 | # metrics: tester返回的结果 88 | # metric_key: 一个用来做筛选的指标,来自Trainer的初始化 89 | if isinstance(metrics, tuple): 90 | loss, metrics = metrics 91 | 92 | if isinstance(metrics, dict): 93 | metric_dict = list(metrics.values())[0] # 取第一个metric 94 | 95 | if metric_key is None: 96 | indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] 97 | else: 98 | # metric_key is set 99 | if metric_key not in metric_dict: 100 | raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}") 101 | indicator_val = metric_dict[metric_key] 102 | indicator = metric_key 103 | else: 104 | raise RuntimeError("Invalid metrics type. 
Expect {}, got {}".format((tuple, dict), type(metrics))) 105 | return indicator, indicator_val -------------------------------------------------------------------------------- /modules/pipe.py: -------------------------------------------------------------------------------- 1 | 2 | from fastNLP.io import Pipe, ConllLoader 3 | from fastNLP.io import DataBundle 4 | from fastNLP.io.pipe.utils import _add_words_field, _indexize 5 | from fastNLP.io.pipe.utils import iob2, iob2bioes 6 | from fastNLP.io.pipe.utils import _add_chars_field 7 | from fastNLP.io.utils import check_loader_paths 8 | 9 | from fastNLP.io import Conll2003NERLoader 10 | from fastNLP import Const 11 | 12 | def word_shape(words): 13 | shapes = [] 14 | for word in words: 15 | caps = [] 16 | for char in word: 17 | caps.append(char.isupper()) 18 | if all(caps): 19 | shapes.append(0) 20 | elif any(caps) is False: 21 | shapes.append(1) 22 | elif caps[0]: 23 | shapes.append(2) 24 | elif any(caps): 25 | shapes.append(3) 26 | else: 27 | shapes.append(4) 28 | return shapes 29 | 30 | 31 | class Conll2003NERPipe(Pipe): 32 | """ 33 | Conll2003的NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 34 | (创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 35 | Vocabulary转换为index。 36 | 经过该Pipe过后,DataSet中的内容如下所示 37 | 38 | .. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader 39 | :header: "raw_words", "target", "words", "seq_len" 40 | 41 | "[Nadim, Ladki]", "[1, 2]", "[2, 3]", 2 42 | "[AL-AIN, United, Arab, ...]", "[3, 4,...]", "[4, 5, 6,...]", 6 43 | "[...]", "[...]", "[...]", . 44 | 45 | raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 46 | target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。 47 | 48 | dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: 49 | 50 | +-------------+-----------+--------+-------+---------+ 51 | | field_names | raw_words | target | words | seq_len | 52 | +-------------+-----------+--------+-------+---------+ 53 | | is_input | False | True | True | True | 54 | | is_target | False | True | False | True | 55 | | ignore_type | | False | False | False | 56 | | pad_value | | 0 | 0 | 0 | 57 | +-------------+-----------+--------+-------+---------+ 58 | 59 | """ 60 | 61 | def __init__(self, encoding_type: str = 'bio', lower: bool = False, word_shape: bool=False): 62 | """ 63 | 64 | :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 65 | :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 66 | :param boll word_shape: 是否新增一列word shape,5维 67 | """ 68 | if encoding_type == 'bio': 69 | self.convert_tag = iob2 70 | elif encoding_type == 'bioes': 71 | self.convert_tag = lambda words: iob2bioes(iob2(words)) 72 | else: 73 | raise ValueError("encoding_type only supports `bio` and `bioes`.") 74 | self.lower = lower 75 | self.word_shape = word_shape 76 | 77 | def process(self, data_bundle: DataBundle) -> DataBundle: 78 | """ 79 | 支持的DataSet的field为 80 | 81 | .. 
csv-table:: 82 | :header: "raw_words", "target" 83 | 84 | "[Nadim, Ladki]", "[B-PER, I-PER]" 85 | "[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" 86 | "[...]", "[...]" 87 | 88 | :param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]在传入DataBundle基础上原位修改。 89 | :return DataBundle: 90 | """ 91 | # 转换tag 92 | for name, dataset in data_bundle.datasets.items(): 93 | dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) 94 | 95 | _add_words_field(data_bundle, lower=self.lower) 96 | 97 | if self.word_shape: 98 | data_bundle.apply_field(word_shape, field_name='raw_words', new_field_name='word_shapes') 99 | data_bundle.set_input('word_shapes') 100 | 101 | # 将所有digit转为0 102 | data_bundle.apply_field(lambda chars:[''.join(['0' if c.isdigit() else c for c in char]) for char in chars], 103 | field_name=Const.INPUT, new_field_name=Const.INPUT) 104 | 105 | # index 106 | _indexize(data_bundle) 107 | 108 | input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] 109 | target_fields = [Const.TARGET, Const.INPUT_LEN] 110 | 111 | for name, dataset in data_bundle.datasets.items(): 112 | dataset.add_seq_len(Const.INPUT) 113 | 114 | data_bundle.set_input(*input_fields) 115 | data_bundle.set_target(*target_fields) 116 | 117 | return data_bundle 118 | 119 | def process_from_file(self, paths) -> DataBundle: 120 | """ 121 | 122 | :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。 123 | :return: DataBundle 124 | """ 125 | # 读取数据 126 | data_bundle = Conll2003NERLoader().load(paths) 127 | data_bundle = self.process(data_bundle) 128 | 129 | return data_bundle 130 | 131 | 132 | from fastNLP.io import OntoNotesNERLoader 133 | 134 | class OntoNotesNERPipe(Pipe): 135 | """ 136 | 处理OntoNotes的NER数据,处理之后DataSet中的field情况为 137 | 138 | .. csv-table:: 139 | :header: "raw_words", "target", "words", "seq_len" 140 | 141 | "[Nadim, Ladki]", "[1, 2]", "[2, 3]", 2 142 | "[AL-AIN, United, Arab, ...]", "[3, 4]", "[4, 5, 6,...]", 6 143 | "[...]", "[...]", "[...]", . 144 | 145 | raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 146 | target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。 147 | 148 | dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: 149 | 150 | +-------------+-----------+--------+-------+---------+ 151 | | field_names | raw_words | target | words | seq_len | 152 | +-------------+-----------+--------+-------+---------+ 153 | | is_input | False | True | True | True | 154 | | is_target | False | True | False | True | 155 | | ignore_type | | False | False | False | 156 | | pad_value | | 0 | 0 | 0 | 157 | +-------------+-----------+--------+-------+---------+ 158 | 159 | """ 160 | def __init__(self, encoding_type: str = 'bio', lower: bool = False, word_shape: bool=False): 161 | """ 162 | 163 | :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 164 | :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 165 | :param boll word_shape: 是否新增一列word shape,5维 166 | """ 167 | if encoding_type == 'bio': 168 | self.convert_tag = iob2 169 | elif encoding_type == 'bioes': 170 | self.convert_tag = lambda words: iob2bioes(iob2(words)) 171 | else: 172 | raise ValueError("encoding_type only supports `bio` and `bioes`.") 173 | self.lower = lower 174 | self.word_shape = word_shape 175 | 176 | def process(self, data_bundle: DataBundle) -> DataBundle: 177 | """ 178 | 支持的DataSet的field为 179 | 180 | .. 
csv-table:: 181 | :header: "raw_words", "target" 182 | 183 | "[Nadim, Ladki]", "[B-PER, I-PER]" 184 | "[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" 185 | "[...]", "[...]" 186 | 187 | :param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]在传入DataBundle基础上原位修改。 188 | :return DataBundle: 189 | """ 190 | # 转换tag 191 | for name, dataset in data_bundle.datasets.items(): 192 | dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) 193 | 194 | _add_words_field(data_bundle, lower=self.lower) 195 | 196 | if self.word_shape: 197 | data_bundle.apply_field(word_shape, field_name='raw_words', new_field_name='word_shapes') 198 | data_bundle.set_input('word_shapes') 199 | 200 | # 将所有digit转为0 201 | data_bundle.apply_field(lambda chars:[''.join(['0' if c.isdigit() else c for c in char]) for char in chars], 202 | field_name=Const.INPUT, new_field_name=Const.INPUT) 203 | 204 | # index 205 | _indexize(data_bundle) 206 | 207 | input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] 208 | target_fields = [Const.TARGET, Const.INPUT_LEN] 209 | 210 | for name, dataset in data_bundle.datasets.items(): 211 | dataset.add_seq_len(Const.INPUT) 212 | 213 | data_bundle.set_input(*input_fields) 214 | data_bundle.set_target(*target_fields) 215 | 216 | return data_bundle 217 | 218 | def process_from_file(self, paths): 219 | data_bundle = OntoNotesNERLoader().load(paths) 220 | return self.process(data_bundle) 221 | 222 | 223 | def bmeso2bio(tags): 224 | new_tags = [] 225 | for tag in tags: 226 | tag = tag.lower() 227 | if tag.startswith('m') or tag.startswith('e'): 228 | tag = 'i' + tag[1:] 229 | if tag.startswith('s'): 230 | tag = 'b' + tag[1:] 231 | new_tags.append(tag) 232 | return new_tags 233 | 234 | 235 | def bmeso2bioes(tags): 236 | new_tags = [] 237 | for tag in tags: 238 | lowered_tag = tag.lower() 239 | if lowered_tag.startswith('m'): 240 | tag = 'i' + tag[1:] 241 | new_tags.append(tag) 242 | return new_tags 243 | 244 | 245 | class CNNERPipe(Pipe): 246 | def __init__(self, bigrams=False, encoding_type='bmeso'): 247 | super().__init__() 248 | self.bigrams = bigrams 249 | if encoding_type=='bmeso': 250 | self.encoding_func = lambda x:x 251 | elif encoding_type=='bio': 252 | self.encoding_func = bmeso2bio 253 | elif encoding_type == 'bioes': 254 | self.encoding_func = bmeso2bioes 255 | else: 256 | raise RuntimeError("Only support bio, bmeso, bioes") 257 | 258 | def process(self, data_bundle: DataBundle): 259 | _add_chars_field(data_bundle, lower=False) 260 | 261 | data_bundle.apply_field(self.encoding_func, field_name=Const.TARGET, new_field_name=Const.TARGET) 262 | 263 | # 将所有digit转为0 264 | data_bundle.apply_field(lambda chars:[''.join(['0' if c.isdigit() else c for c in char]) for char in chars], 265 | field_name=Const.CHAR_INPUT, new_field_name=Const.CHAR_INPUT) 266 | 267 | # 268 | input_field_names = [Const.CHAR_INPUT] 269 | if self.bigrams: 270 | data_bundle.apply_field(lambda chars:[c1+c2 for c1,c2 in zip(chars, chars[1:]+[''])], 271 | field_name=Const.CHAR_INPUT, new_field_name='bigrams') 272 | input_field_names.append('bigrams') 273 | 274 | # index 275 | _indexize(data_bundle, input_field_names=input_field_names, target_field_names=Const.TARGET) 276 | 277 | input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names 278 | target_fields = [Const.TARGET, Const.INPUT_LEN] 279 | 280 | for name, dataset in data_bundle.datasets.items(): 281 | dataset.add_seq_len(Const.CHAR_INPUT) 282 | 283 | 
data_bundle.set_input(*input_fields) 284 | data_bundle.set_target(*target_fields) 285 | 286 | return data_bundle 287 | 288 | def process_from_file(self, paths): 289 | paths = check_loader_paths(paths) 290 | loader = ConllLoader(headers=['raw_chars', 'target']) 291 | data_bundle = loader.load(paths) 292 | return self.process(data_bundle) -------------------------------------------------------------------------------- /modules/relative_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class RelativeEmbedding(nn.Module): 8 | def forward(self, input): 9 | """Input is expected to be of size [bsz x seqlen]. 10 | """ 11 | bsz, seq_len = input.size() 12 | max_pos = self.padding_idx + seq_len 13 | if max_pos > self.origin_shift: 14 | # recompute/expand embeddings if needed 15 | weights = self.get_embedding( 16 | max_pos*2, 17 | self.embedding_dim, 18 | self.padding_idx, 19 | ) 20 | weights = weights.to(self._float_tensor) 21 | del self.weights 22 | self.origin_shift = weights.size(0)//2 23 | self.register_buffer('weights', weights) 24 | 25 | positions = torch.arange(-seq_len, seq_len).to(input.device).long() + self.origin_shift # 2*seq_len 26 | embed = self.weights.index_select(0, positions.long()).detach() 27 | return embed 28 | 29 | 30 | class RelativeSinusoidalPositionalEmbedding(RelativeEmbedding): 31 | """This module produces sinusoidal positional embeddings of any length. 32 | Padding symbols are ignored. 33 | """ 34 | 35 | def __init__(self, embedding_dim, padding_idx, init_size=1568): 36 | """ 37 | 38 | :param embedding_dim: 每个位置的dimension 39 | :param padding_idx: 40 | :param init_size: 41 | """ 42 | super().__init__() 43 | self.embedding_dim = embedding_dim 44 | self.padding_idx = padding_idx 45 | assert init_size%2==0 46 | weights = self.get_embedding( 47 | init_size+1, 48 | embedding_dim, 49 | padding_idx, 50 | ) 51 | self.register_buffer('weights', weights) 52 | self.register_buffer('_float_tensor', torch.FloatTensor(1)) 53 | 54 | def get_embedding(self, num_embeddings, embedding_dim, padding_idx=None): 55 | """Build sinusoidal embeddings. 56 | This matches the implementation in tensor2tensor, but differs slightly 57 | from the description in Section 3.5 of "Attention Is All You Need". 
58 | """ 59 | half_dim = embedding_dim // 2 60 | emb = math.log(10000) / (half_dim - 1) 61 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) 62 | emb = torch.arange(-num_embeddings//2, num_embeddings//2, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) 63 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) 64 | if embedding_dim % 2 == 1: 65 | # zero pad 66 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) 67 | if padding_idx is not None: 68 | emb[padding_idx, :] = 0 69 | self.origin_shift = num_embeddings//2 + 1 70 | return emb 71 | 72 | 73 | class RelativeMultiHeadAttn(nn.Module): 74 | def __init__(self, d_model, n_head, dropout, r_w_bias=None, r_r_bias=None, scale=False): 75 | """ 76 | 77 | :param int d_model: 78 | :param int n_head: 79 | :param dropout: 对attention map的dropout 80 | :param r_w_bias: n_head x head_dim or None, 如果为dim 81 | :param r_r_bias: n_head x head_dim or None, 82 | :param scale: 83 | :param rel_pos_embed: 84 | """ 85 | super().__init__() 86 | self.qkv_linear = nn.Linear(d_model, d_model * 3, bias=False) 87 | self.n_head = n_head 88 | self.head_dim = d_model // n_head 89 | self.dropout_layer = nn.Dropout(dropout) 90 | 91 | self.pos_embed = RelativeSinusoidalPositionalEmbedding(d_model//n_head, 0, 1200) 92 | 93 | if scale: 94 | self.scale = math.sqrt(d_model // n_head) 95 | else: 96 | self.scale = 1 97 | 98 | if r_r_bias is None or r_w_bias is None: # Biases are not shared 99 | self.r_r_bias = nn.Parameter(nn.init.xavier_normal_(torch.zeros(n_head, d_model // n_head))) 100 | self.r_w_bias = nn.Parameter(nn.init.xavier_normal_(torch.zeros(n_head, d_model // n_head))) 101 | else: 102 | self.r_r_bias = r_r_bias # r_r_bias就是v 103 | self.r_w_bias = r_w_bias # r_w_bias就是u 104 | 105 | def forward(self, x, mask): 106 | """ 107 | 108 | :param x: batch_size x max_len x d_model 109 | :param mask: batch_size x max_len 110 | :return: 111 | """ 112 | 113 | batch_size, max_len, d_model = x.size() 114 | pos_embed = self.pos_embed(mask) # l x head_dim 115 | 116 | qkv = self.qkv_linear(x) # batch_size x max_len x d_model3 117 | q, k, v = torch.chunk(qkv, chunks=3, dim=-1) 118 | q = q.view(batch_size, max_len, self.n_head, -1).transpose(1, 2) 119 | k = k.view(batch_size, max_len, self.n_head, -1).transpose(1, 2) 120 | v = v.view(batch_size, max_len, self.n_head, -1).transpose(1, 2) # b x n x l x d 121 | 122 | rw_head_q = q + self.r_r_bias[:, None] 123 | AC = torch.einsum('bnqd,bnkd->bnqk', [rw_head_q, k]) # b x n x l x d, n是head 124 | 125 | D_ = torch.einsum('nd,ld->nl', self.r_w_bias, pos_embed)[None, :, None] # head x 2max_len, 每个head对位置的bias 126 | B_ = torch.einsum('bnqd,ld->bnql', q, pos_embed) # bsz x head x max_len x 2max_len,每个query对每个shift的偏移 127 | E_ = torch.einsum('bnqd,ld->bnql', k, pos_embed) # bsz x head x max_len x 2max_len, key对relative的bias 128 | BD = B_ + D_ # bsz x head x max_len x 2max_len, 要转换为bsz x head x max_len x max_len 129 | BDE = self._shift(BD) + self._transpose_shift(E_) 130 | attn = AC + BDE 131 | 132 | attn = attn / self.scale 133 | 134 | attn = attn.masked_fill(mask[:, None, None, :].eq(0), float('-inf')) 135 | 136 | attn = F.softmax(attn, dim=-1) 137 | attn = self.dropout_layer(attn) 138 | v = torch.matmul(attn, v).transpose(1, 2).reshape(batch_size, max_len, d_model) # b x n x l x d 139 | 140 | return v 141 | 142 | def _shift(self, BD): 143 | """ 144 | 类似 145 | -3 -2 -1 0 1 2 146 | -3 -2 -1 0 1 2 147 | -3 -2 -1 0 1 2 148 | 149 | 转换为 150 | 0 1 2 151 | -1 0 1 152 | -2 -1 0 153 | 154 | :param 
BD: batch_size x n_head x max_len x 2max_len 155 | :return: batch_size x n_head x max_len x max_len 156 | """ 157 | bsz, n_head, max_len, _ = BD.size() 158 | zero_pad = BD.new_zeros(bsz, n_head, max_len, 1) 159 | BD = torch.cat([BD, zero_pad], dim=-1).view(bsz, n_head, -1, max_len) # bsz x n_head x (2max_len+1) x max_len 160 | BD = BD[:, :, :-1].view(bsz, n_head, max_len, -1) # bsz x n_head x 2max_len x max_len 161 | BD = BD[:, :, :, max_len:] 162 | return BD 163 | 164 | def _transpose_shift(self, E): 165 | """ 166 | 类似 167 | -3 -2 -1 0 1 2 168 | -30 -20 -10 00 10 20 169 | -300 -200 -100 000 100 200 170 | 171 | 转换为 172 | 0 -10 -200 173 | 1 00 -100 174 | 2 10 000 175 | 176 | 177 | :param E: batch_size x n_head x max_len x 2max_len 178 | :return: batch_size x n_head x max_len x max_len 179 | """ 180 | bsz, n_head, max_len, _ = E.size() 181 | zero_pad = E.new_zeros(bsz, n_head, max_len, 1) 182 | # bsz x n_head x -1 x (max_len+1) 183 | E = torch.cat([E, zero_pad], dim=-1).view(bsz, n_head, -1, max_len) 184 | indice = (torch.arange(max_len)*2+1).to(E.device) 185 | E = E.index_select(index=indice, dim=-2).transpose(-1,-2) # bsz x n_head x max_len x max_len 186 | 187 | return E 188 | -------------------------------------------------------------------------------- /modules/transformer.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from torch import nn 7 | import math 8 | from copy import deepcopy 9 | 10 | from .relative_transformer import RelativeMultiHeadAttn 11 | 12 | 13 | class MultiHeadAttn(nn.Module): 14 | def __init__(self, d_model, n_head, dropout=0.1, scale=False): 15 | """ 16 | 17 | :param d_model: 18 | :param n_head: 19 | :param scale: 是否scale输出 20 | """ 21 | super().__init__() 22 | assert d_model%n_head==0 23 | 24 | self.n_head = n_head 25 | self.qkv_linear = nn.Linear(d_model, 3*d_model, bias=False) 26 | self.fc = nn.Linear(d_model, d_model) 27 | self.dropout_layer = nn.Dropout(dropout) 28 | 29 | if scale: 30 | self.scale = math.sqrt(d_model//n_head) 31 | else: 32 | self.scale = 1 33 | 34 | def forward(self, x, mask): 35 | """ 36 | 37 | :param x: bsz x max_len x d_model 38 | :param mask: bsz x max_len 39 | :return: 40 | """ 41 | batch_size, max_len, d_model = x.size() 42 | x = self.qkv_linear(x) 43 | q, k, v = torch.chunk(x, 3, dim=-1) 44 | q = q.view(batch_size, max_len, self.n_head, -1).transpose(1, 2) 45 | k = k.view(batch_size, max_len, self.n_head, -1).permute(0, 2, 3, 1) 46 | v = v.view(batch_size, max_len, self.n_head, -1).transpose(1, 2) 47 | 48 | attn = torch.matmul(q, k) # batch_size x n_head x max_len x max_len 49 | attn = attn/self.scale 50 | attn.masked_fill_(mask=mask[:, None, None].eq(0), value=float('-inf')) 51 | 52 | attn = F.softmax(attn, dim=-1) # batch_size x n_head x max_len x max_len 53 | attn = self.dropout_layer(attn) 54 | v = torch.matmul(attn, v) # batch_size x n_head x max_len x d_model//n_head 55 | v = v.transpose(1, 2).reshape(batch_size, max_len, -1) 56 | v = self.fc(v) 57 | 58 | return v 59 | 60 | 61 | class TransformerLayer(nn.Module): 62 | def __init__(self, d_model, self_attn, feedforward_dim, after_norm, dropout): 63 | """ 64 | 65 | :param int d_model: 一般512之类的 66 | :param self_attn: self attention模块,输入为x:batch_size x max_len x d_model, mask:batch_size x max_len, 输出为 67 | batch_size x max_len x d_model 68 | :param int feedforward_dim: FFN中间层的dimension的大小 69 | :param bool after_norm: norm的位置不一样,如果为False,则embedding可以直接连到输出 70 | :param float 
dropout: 一共三个位置的dropout的大小 71 | """ 72 | super().__init__() 73 | 74 | self.norm1 = nn.LayerNorm(d_model) 75 | self.norm2 = nn.LayerNorm(d_model) 76 | 77 | self.self_attn = self_attn 78 | 79 | self.after_norm = after_norm 80 | 81 | self.ffn = nn.Sequential(nn.Linear(d_model, feedforward_dim), 82 | nn.ReLU(), 83 | nn.Dropout(dropout), 84 | nn.Linear(feedforward_dim, d_model), 85 | nn.Dropout(dropout)) 86 | 87 | def forward(self, x, mask): 88 | """ 89 | 90 | :param x: batch_size x max_len x hidden_size 91 | :param mask: batch_size x max_len, 为0的地方为pad 92 | :return: batch_size x max_len x hidden_size 93 | """ 94 | residual = x 95 | if not self.after_norm: 96 | x = self.norm1(x) 97 | 98 | x = self.self_attn(x, mask) 99 | x = x + residual 100 | if self.after_norm: 101 | x = self.norm1(x) 102 | residual = x 103 | if not self.after_norm: 104 | x = self.norm2(x) 105 | x = self.ffn(x) 106 | x = residual + x 107 | if self.after_norm: 108 | x = self.norm2(x) 109 | return x 110 | 111 | 112 | class TransformerEncoder(nn.Module): 113 | def __init__(self, num_layers, d_model, n_head, feedforward_dim, dropout, after_norm=True, attn_type='naive', 114 | scale=False, dropout_attn=None, pos_embed=None): 115 | super().__init__() 116 | if dropout_attn is None: 117 | dropout_attn = dropout 118 | self.d_model = d_model 119 | 120 | if pos_embed is None: 121 | self.pos_embed = None 122 | elif pos_embed == 'sin': 123 | self.pos_embed = SinusoidalPositionalEmbedding(d_model, 0, init_size=1024) 124 | elif pos_embed == 'fix': 125 | self.pos_embed = LearnedPositionalEmbedding(1024, d_model, 0) 126 | 127 | if attn_type == 'transformer': 128 | self_attn = MultiHeadAttn(d_model, n_head, dropout_attn, scale=scale) 129 | elif attn_type == 'adatrans': 130 | self_attn = RelativeMultiHeadAttn(d_model, n_head, dropout_attn, scale=scale) 131 | 132 | self.layers = nn.ModuleList([TransformerLayer(d_model, deepcopy(self_attn), feedforward_dim, after_norm, dropout) 133 | for _ in range(num_layers)]) 134 | 135 | def forward(self, x, mask): 136 | """ 137 | 138 | :param x: batch_size x max_len 139 | :param mask: batch_size x max_len. 有value的地方为1 140 | :return: 141 | """ 142 | if self.pos_embed is not None: 143 | x = x + self.pos_embed(mask) 144 | 145 | for layer in self.layers: 146 | x = layer(x, mask) 147 | return x 148 | 149 | 150 | def make_positions(tensor, padding_idx): 151 | """Replace non-padding symbols with their position numbers. 152 | Position numbers begin at padding_idx+1. Padding symbols are ignored. 153 | """ 154 | # The series of casts and type-conversions here are carefully 155 | # balanced to both work with ONNX export and XLA. In particular XLA 156 | # prefers ints, cumsum defaults to output longs, and ONNX doesn't know 157 | # how to handle the dtype kwarg in cumsum. 158 | mask = tensor.ne(padding_idx).int() 159 | return ( 160 | torch.cumsum(mask, dim=1).type_as(mask) * mask 161 | ).long() + padding_idx 162 | 163 | 164 | class SinusoidalPositionalEmbedding(nn.Module): 165 | """This module produces sinusoidal positional embeddings of any length. 166 | Padding symbols are ignored. 
167 | """ 168 | 169 | def __init__(self, embedding_dim, padding_idx, init_size=1568): 170 | super().__init__() 171 | self.embedding_dim = embedding_dim 172 | self.padding_idx = padding_idx 173 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 174 | init_size, 175 | embedding_dim, 176 | padding_idx, 177 | ) 178 | self.register_buffer('_float_tensor', torch.FloatTensor(1)) 179 | 180 | @staticmethod 181 | def get_embedding(num_embeddings, embedding_dim, padding_idx=None): 182 | """Build sinusoidal embeddings. 183 | This matches the implementation in tensor2tensor, but differs slightly 184 | from the description in Section 3.5 of "Attention Is All You Need". 185 | """ 186 | half_dim = embedding_dim // 2 187 | emb = math.log(10000) / (half_dim - 1) 188 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) 189 | emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) 190 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) 191 | if embedding_dim % 2 == 1: 192 | # zero pad 193 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) 194 | if padding_idx is not None: 195 | emb[padding_idx, :] = 0 196 | return emb 197 | 198 | def forward(self, input): 199 | """Input is expected to be of size [bsz x seqlen].""" 200 | bsz, seq_len = input.size() 201 | max_pos = self.padding_idx + 1 + seq_len 202 | if max_pos > self.weights.size(0): 203 | # recompute/expand embeddings if needed 204 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 205 | max_pos, 206 | self.embedding_dim, 207 | self.padding_idx, 208 | ) 209 | self.weights = self.weights.to(self._float_tensor) 210 | 211 | positions = make_positions(input, self.padding_idx) 212 | return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() 213 | 214 | def max_positions(self): 215 | """Maximum number of supported positions.""" 216 | return int(1e5) # an arbitrary large number 217 | 218 | 219 | class LearnedPositionalEmbedding(nn.Embedding): 220 | """ 221 | This module learns positional embeddings up to a fixed maximum size. 222 | Padding ids are ignored by either offsetting based on padding_idx 223 | or by setting padding_idx to None and ensuring that the appropriate 224 | position ids are passed to the forward function. 225 | """ 226 | 227 | def __init__( 228 | self, 229 | num_embeddings: int, 230 | embedding_dim: int, 231 | padding_idx: int, 232 | ): 233 | super().__init__(num_embeddings, embedding_dim, padding_idx) 234 | 235 | def forward(self, input): 236 | # positions: batch_size x max_len, 把words的index输入就好了 237 | positions = make_positions(input, self.padding_idx) 238 | return super().forward(positions) 239 | -------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | # 设置随机数种子 2 | 3 | 4 | def set_rng_seed(rng_seed:int = None, random:bool = True, numpy:bool = True, 5 | pytorch:bool=True, deterministic:bool=True): 6 | """ 7 | 设置模块的随机数种子。由于pytorch还存在cudnn导致的非deterministic的运行,所以一些情况下可能即使seed一样,结果也不一致 8 | 需要在fitlog.commit()或fitlog.set_log_dir()之后运行才会记录该rng_seed到log中 9 | :param int rng_seed: 将这些模块的随机数设置到多少,默认为随机生成一个。 10 | :param bool, random: 是否将python自带的random模块的seed设置为rng_seed. 11 | :param bool, numpy: 是否将numpy的seed设置为rng_seed. 12 | :param bool, pytorch: 是否将pytorch的seed设置为rng_seed(设置torch.manual_seed和torch.cuda.manual_seed_all). 
13 | :param bool, deterministic: 是否将pytorch的torch.backends.cudnn.deterministic设置为True 14 | """ 15 | if rng_seed is None: 16 | import time 17 | rng_seed = int(time.time()%1000000) 18 | if random: 19 | import random 20 | random.seed(rng_seed) 21 | if numpy: 22 | try: 23 | import numpy 24 | numpy.random.seed(rng_seed) 25 | except: 26 | pass 27 | if pytorch: 28 | try: 29 | import torch 30 | torch.manual_seed(rng_seed) 31 | torch.cuda.manual_seed_all(rng_seed) 32 | if deterministic: 33 | torch.backends.cudnn.deterministic = True 34 | except: 35 | pass 36 | return rng_seed -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastNLP>=0.5 -------------------------------------------------------------------------------- /train_elmo_en.py: -------------------------------------------------------------------------------- 1 | 2 | from models.TENER import TENER 3 | from fastNLP.embeddings import CNNCharEmbedding 4 | from fastNLP import cache_results 5 | from fastNLP import Trainer, GradientClipCallback, WarmupCallback 6 | from torch import optim 7 | from fastNLP import SpanFPreRecMetric, BucketSampler 8 | from fastNLP.io.pipe.conll import OntoNotesNERPipe 9 | from fastNLP.embeddings import StaticEmbedding, StackEmbedding, LSTMCharEmbedding, ElmoEmbedding 10 | from modules.TransformerEmbedding import TransformerCharEmbed 11 | from modules.pipe import Conll2003NERPipe 12 | from modules.callbacks import EvaluateCallback 13 | 14 | import argparse 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument('--dataset', type=str, default='en-ontonotes', choices=['conll2003', 'en-ontonotes']) 20 | 21 | args = parser.parse_args() 22 | 23 | dataset = args.dataset 24 | if dataset == 'en-ontonotes': 25 | n_heads = 10 26 | head_dims = 96 27 | num_layers = 2 28 | lr = 0.0009 29 | attn_type = 'adatrans' 30 | optim_type = 'sgd' 31 | trans_dropout = 0.15 32 | batch_size = 16 33 | elif dataset == 'conll2003': 34 | n_heads = 12 35 | head_dims = 128 36 | num_layers = 2 37 | lr = 0.0001 38 | attn_type = 'adatrans' 39 | optim_type = 'adam' 40 | trans_dropout = 0.45 # 有可能是0.4 41 | batch_size = 32 42 | else: 43 | raise RuntimeError("Only support conll2003, en-ontonotes") 44 | 45 | 46 | char_type = 'adatrans' 47 | 48 | pos_embed = None 49 | 50 | model_type = 'elmo' 51 | warmup_steps = 0.01 52 | after_norm = 1 53 | fc_dropout=0.4 54 | normalize_embed = True 55 | 56 | encoding_type = 'bioes' 57 | name = 'caches/elmo_{}_{}_{}_{}_{}.pkl'.format(dataset, model_type, encoding_type, char_type, normalize_embed) 58 | d_model = n_heads * head_dims 59 | dim_feedforward = int(2 * d_model) 60 | 61 | device = 0 62 | 63 | # scale为1时,同时character和模型的scale都是1 64 | 65 | @cache_results(name, _refresh=False) 66 | def load_data(): 67 | # 替换路径 68 | if dataset == 'conll2003': 69 | # conll2003的lr不能超过0.002 70 | paths = {'test': "../data/conll2003/test.txt", 71 | 'train': "../data/conll2003/train.txt", 72 | 'dev': "../data/conll2003/dev.txt"} 73 | data = Conll2003NERPipe(encoding_type=encoding_type).process_from_file(paths) 74 | elif dataset == 'en-ontonotes': 75 | paths = '../data/en-ontonotes/english' 76 | data = OntoNotesNERPipe(encoding_type=encoding_type).process_from_file(paths) 77 | char_embed = None 78 | if char_type == 'cnn': 79 | char_embed = CNNCharEmbedding(vocab=data.get_vocab('words'), embed_size=30, char_emb_size=30, filter_nums=[30], 80 | kernel_sizes=[3], word_dropout=0, dropout=0.3, 
pool_method='max' 81 | , include_word_start_end=False, min_char_freq=2) 82 | elif char_type in ['adatrans', 'naive']: 83 | char_embed = TransformerCharEmbed(vocab=data.get_vocab('words'), embed_size=30, char_emb_size=30, word_dropout=0, 84 | dropout=0.3, pool_method='max', activation='relu', 85 | min_char_freq=2, requires_grad=True, include_word_start_end=False, 86 | char_attn_type=char_type, char_n_head=3, char_dim_ffn=60, char_scale=char_type=='naive', 87 | char_dropout=0.15, char_after_norm=True) 88 | elif char_type == 'lstm': 89 | char_embed = LSTMCharEmbedding(vocab=data.get_vocab('words'), embed_size=30, char_emb_size=30, word_dropout=0, 90 | dropout=0.3, hidden_size=100, pool_method='max', activation='relu', 91 | min_char_freq=2, bidirectional=True, requires_grad=True, include_word_start_end=False) 92 | word_embed = StaticEmbedding(vocab=data.get_vocab('words'), 93 | model_dir_or_name='en-glove-6b-100d', 94 | requires_grad=True, lower=True, word_dropout=0, dropout=0.5, 95 | only_norm_found_vector=normalize_embed) 96 | data.rename_field('words', 'chars') 97 | 98 | embed = ElmoEmbedding(vocab=data.get_vocab('chars'), model_dir_or_name='en-original', layers='mix', requires_grad=False, 99 | word_dropout=0.0, dropout=0.5, cache_word_reprs=False) 100 | embed.set_mix_weights_requires_grad() 101 | 102 | embed = StackEmbedding([embed, word_embed, char_embed], dropout=0, word_dropout=0.02) 103 | 104 | return data, embed 105 | 106 | data_bundle, embed = load_data() 107 | print(data_bundle) 108 | 109 | model = TENER(tag_vocab=data_bundle.get_vocab('target'), embed=embed, num_layers=num_layers, 110 | d_model=d_model, n_head=n_heads, 111 | feedforward_dim=dim_feedforward, dropout=trans_dropout, 112 | after_norm=after_norm, attn_type=attn_type, 113 | bi_embed=None, 114 | fc_dropout=fc_dropout, 115 | pos_embed=pos_embed, 116 | scale=attn_type=='naive') 117 | 118 | if optim_type == 'sgd': 119 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9) 120 | else: 121 | optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.99)) 122 | 123 | callbacks = [] 124 | clip_callback = GradientClipCallback(clip_type='value', clip_value=5) 125 | evaluate_callback = EvaluateCallback(data_bundle.get_dataset('test')) 126 | 127 | if warmup_steps>0: 128 | warmup_callback = WarmupCallback(warmup_steps, schedule='linear') 129 | callbacks.append(warmup_callback) 130 | callbacks.extend([clip_callback, evaluate_callback]) 131 | 132 | trainer = Trainer(data_bundle.get_dataset('train'), model, optimizer, batch_size=batch_size, sampler=BucketSampler(), 133 | num_workers=0, n_epochs=100, dev_data=data_bundle.get_dataset('dev'), 134 | metrics=SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'), encoding_type=encoding_type), 135 | dev_batch_size=batch_size, callbacks=callbacks, device=device, test_use_tqdm=False, 136 | use_tqdm=True, print_every=300, save_path=None) 137 | trainer.train(load_best_model=False) 138 | -------------------------------------------------------------------------------- /train_tener_cn.py: -------------------------------------------------------------------------------- 1 | from models.TENER import TENER 2 | from fastNLP import cache_results 3 | from fastNLP import Trainer, GradientClipCallback, WarmupCallback 4 | from torch import optim 5 | from fastNLP import SpanFPreRecMetric, BucketSampler 6 | from fastNLP.embeddings import StaticEmbedding 7 | from modules.pipe import CNNERPipe 8 | 9 | import argparse 10 | from modules.callbacks import EvaluateCallback 11 | 12 | device = 0 13 | 
parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument('--dataset', type=str, default='resume', choices=['weibo', 'resume', 'ontonotes', 'msra']) 16 | 17 | args = parser.parse_args() 18 | 19 | dataset = args.dataset 20 | if dataset == 'resume': 21 | n_heads = 4 22 | head_dims = 64 23 | num_layers = 2 24 | lr = 0.0007 25 | attn_type = 'adatrans' 26 | n_epochs = 50 27 | elif dataset == 'weibo': 28 | n_heads = 4 29 | head_dims = 32 30 | num_layers = 1 31 | lr = 0.001 32 | attn_type = 'adatrans' 33 | n_epochs = 100 34 | elif dataset == 'ontonotes': 35 | n_heads = 4 36 | head_dims = 48 37 | num_layers = 2 38 | lr = 0.0007 39 | attn_type = 'adatrans' 40 | n_epochs = 100 41 | elif dataset == 'msra': 42 | n_heads = 6 43 | head_dims = 80 44 | num_layers = 2 45 | lr = 0.0007 46 | attn_type = 'adatrans' 47 | n_epochs = 100 48 | 49 | pos_embed = None 50 | 51 | batch_size = 16 52 | warmup_steps = 0.01 53 | after_norm = 1 54 | model_type = 'transformer' 55 | normalize_embed = True 56 | 57 | dropout=0.15 58 | fc_dropout=0.4 59 | 60 | encoding_type = 'bmeso' 61 | name = 'caches/{}_{}_{}_{}.pkl'.format(dataset, model_type, encoding_type, normalize_embed) 62 | d_model = n_heads * head_dims 63 | dim_feedforward = int(2 * d_model) 64 | 65 | 66 | @cache_results(name, _refresh=False) 67 | def load_data(): 68 | # 替换路径 69 | if dataset == 'ontonotes': 70 | paths = {'train':'../data/OntoNote4NER/train.char.bmes', 71 | "dev":'../data/OntoNote4NER/dev.char.bmes', 72 | "test":'../data/OntoNote4NER/test.char.bmes'} 73 | min_freq = 2 74 | elif dataset == 'weibo': 75 | paths = {'train': '../data/WeiboNER/train.all.bmes', 76 | 'dev':'../data/WeiboNER/dev.all.bmes', 77 | 'test':'../data/WeiboNER/test.all.bmes'} 78 | min_freq = 1 79 | elif dataset == 'resume': 80 | paths = {'train': '../data/ResumeNER/train.char.bmes', 81 | 'dev':'../data/ResumeNER/dev.char.bmes', 82 | 'test':'../data/ResumeNER/test.char.bmes'} 83 | min_freq = 1 84 | elif dataset == 'msra': 85 | paths = {'train': '../data/MSRANER/train_dev.char.bmes', 86 | 'dev':'../data/MSRANER/test.char.bmes', 87 | 'test':'../data/MSRANER/test.char.bmes'} 88 | min_freq = 2 89 | data_bundle = CNNERPipe(bigrams=True, encoding_type=encoding_type).process_from_file(paths) 90 | embed = StaticEmbedding(data_bundle.get_vocab('chars'), 91 | model_dir_or_name='../data/gigaword_chn.all.a2b.uni.ite50.vec', 92 | min_freq=1, only_norm_found_vector=normalize_embed, word_dropout=0.01, dropout=0.3) 93 | 94 | bi_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'), 95 | model_dir_or_name='../data/gigaword_chn.all.a2b.bi.ite50.vec', 96 | word_dropout=0.02, dropout=0.3, min_freq=min_freq, 97 | only_norm_found_vector=normalize_embed, only_train_min_freq=True) 98 | 99 | return data_bundle, embed, bi_embed 100 | 101 | data_bundle, embed, bi_embed = load_data() 102 | print(data_bundle) 103 | 104 | model = TENER(tag_vocab=data_bundle.get_vocab('target'), embed=embed, num_layers=num_layers, 105 | d_model=d_model, n_head=n_heads, 106 | feedforward_dim=dim_feedforward, dropout=dropout, 107 | after_norm=after_norm, attn_type=attn_type, 108 | bi_embed=bi_embed, 109 | fc_dropout=fc_dropout, 110 | pos_embed=pos_embed, 111 | scale=attn_type=='transformer') 112 | 113 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9) 114 | 115 | callbacks = [] 116 | clip_callback = GradientClipCallback(clip_type='value', clip_value=5) 117 | evaluate_callback = EvaluateCallback(data_bundle.get_dataset('test')) 118 | 119 | if warmup_steps>0: 120 | warmup_callback = WarmupCallback(warmup_steps, 
schedule='linear') 121 | callbacks.append(warmup_callback) 122 | callbacks.extend([clip_callback, evaluate_callback]) 123 | 124 | trainer = Trainer(data_bundle.get_dataset('train'), model, optimizer, batch_size=batch_size, sampler=BucketSampler(), 125 | num_workers=2, n_epochs=n_epochs, dev_data=data_bundle.get_dataset('dev'), 126 | metrics=SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'), encoding_type=encoding_type), 127 | dev_batch_size=batch_size, callbacks=callbacks, device=device, test_use_tqdm=False, 128 | use_tqdm=True, print_every=300, save_path=None) 129 | trainer.train(load_best_model=False) 130 | -------------------------------------------------------------------------------- /train_tener_en.py: -------------------------------------------------------------------------------- 1 | from models.TENER import TENER 2 | from fastNLP.embeddings import CNNCharEmbedding 3 | from fastNLP import cache_results 4 | from fastNLP import Trainer, GradientClipCallback, WarmupCallback 5 | from torch import optim 6 | from fastNLP import SpanFPreRecMetric, BucketSampler 7 | from fastNLP.io.pipe.conll import OntoNotesNERPipe 8 | from fastNLP.embeddings import StaticEmbedding, StackEmbedding, LSTMCharEmbedding 9 | from modules.TransformerEmbedding import TransformerCharEmbed 10 | from modules.pipe import Conll2003NERPipe 11 | 12 | import argparse 13 | from modules.callbacks import EvaluateCallback 14 | 15 | device = 0 16 | parser = argparse.ArgumentParser() 17 | 18 | parser.add_argument('--dataset', type=str, default='en-ontonotes', choices=['conll2003', 'en-ontonotes']) 19 | 20 | args = parser.parse_args() 21 | 22 | dataset = args.dataset 23 | 24 | if dataset == 'conll2003': 25 | n_heads = 14 26 | head_dims = 128 27 | num_layers = 2 28 | lr = 0.0009 29 | attn_type = 'adatrans' 30 | char_type = 'cnn' 31 | elif dataset == 'en-ontonotes': 32 | n_heads = 8 33 | head_dims = 96 34 | num_layers = 2 35 | lr = 0.0007 36 | attn_type = 'adatrans' 37 | char_type = 'adatrans' 38 | 39 | pos_embed = None 40 | 41 | #########hyper 42 | batch_size = 16 43 | warmup_steps = 0.01 44 | after_norm = 1 45 | model_type = 'transformer' 46 | normalize_embed = True 47 | #########hyper 48 | 49 | dropout=0.15 50 | fc_dropout=0.4 51 | 52 | encoding_type = 'bioes' 53 | name = 'caches/{}_{}_{}_{}_{}.pkl'.format(dataset, model_type, encoding_type, char_type, normalize_embed) 54 | d_model = n_heads * head_dims 55 | dim_feedforward = int(2 * d_model) 56 | 57 | 58 | 59 | @cache_results(name, _refresh=False) 60 | def load_data(): 61 | # 替换路径 62 | if dataset == 'conll2003': 63 | # conll2003的lr不能超过0.002 64 | paths = {'test': "../data/conll2003/test.txt", 65 | 'train': "../data/conll2003/train.txt", 66 | 'dev': "../data/conll2003/dev.txt"} 67 | data = Conll2003NERPipe(encoding_type=encoding_type).process_from_file(paths) 68 | elif dataset == 'en-ontonotes': 69 | # 会使用这个文件夹下的train.txt, test.txt, dev.txt等文件 70 | paths = '../data/en-ontonotes/english' 71 | data = OntoNotesNERPipe(encoding_type=encoding_type).process_from_file(paths) 72 | char_embed = None 73 | if char_type == 'cnn': 74 | char_embed = CNNCharEmbedding(vocab=data.get_vocab('words'), embed_size=30, char_emb_size=30, filter_nums=[30], 75 | kernel_sizes=[3], word_dropout=0, dropout=0.3, pool_method='max' 76 | , include_word_start_end=False, min_char_freq=2) 77 | elif char_type in ['adatrans', 'naive']: 78 | char_embed = TransformerCharEmbed(vocab=data.get_vocab('words'), embed_size=30, char_emb_size=30, word_dropout=0, 79 | dropout=0.3, pool_method='max', 
activation='relu', 80 | min_char_freq=2, requires_grad=True, include_word_start_end=False, 81 | char_attn_type=char_type, char_n_head=3, char_dim_ffn=60, char_scale=char_type=='naive', 82 | char_dropout=0.15, char_after_norm=True) 83 | elif char_type == 'lstm': 84 | char_embed = LSTMCharEmbedding(vocab=data.get_vocab('words'), embed_size=30, char_emb_size=30, word_dropout=0, 85 | dropout=0.3, hidden_size=100, pool_method='max', activation='relu', 86 | min_char_freq=2, bidirectional=True, requires_grad=True, include_word_start_end=False) 87 | word_embed = StaticEmbedding(vocab=data.get_vocab('words'), 88 | model_dir_or_name='en-glove-6b-100d', 89 | requires_grad=True, lower=True, word_dropout=0, dropout=0.5, 90 | only_norm_found_vector=normalize_embed) 91 | if char_embed is not None: 92 | embed = StackEmbedding([word_embed, char_embed], dropout=0, word_dropout=0.02) 93 | else: 94 | word_embed.word_drop = 0.02 95 | embed = word_embed 96 | 97 | data.rename_field('words', 'chars') 98 | return data, embed 99 | 100 | data_bundle, embed = load_data() 101 | print(data_bundle) 102 | 103 | model = TENER(tag_vocab=data_bundle.get_vocab('target'), embed=embed, num_layers=num_layers, 104 | d_model=d_model, n_head=n_heads, 105 | feedforward_dim=dim_feedforward, dropout=dropout, 106 | after_norm=after_norm, attn_type=attn_type, 107 | bi_embed=None, 108 | fc_dropout=fc_dropout, 109 | pos_embed=pos_embed, 110 | scale=attn_type=='transformer') 111 | 112 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9) 113 | 114 | callbacks = [] 115 | clip_callback = GradientClipCallback(clip_type='value', clip_value=5) 116 | evaluate_callback = EvaluateCallback(data_bundle.get_dataset('test')) 117 | 118 | if warmup_steps>0: 119 | warmup_callback = WarmupCallback(warmup_steps, schedule='linear') 120 | callbacks.append(warmup_callback) 121 | callbacks.extend([clip_callback, evaluate_callback]) 122 | 123 | trainer = Trainer(data_bundle.get_dataset('train'), model, optimizer, batch_size=batch_size, sampler=BucketSampler(), 124 | num_workers=2, n_epochs=100, dev_data=data_bundle.get_dataset('dev'), 125 | metrics=SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'), encoding_type=encoding_type), 126 | dev_batch_size=batch_size*5, callbacks=callbacks, device=device, test_use_tqdm=False, 127 | use_tqdm=True, print_every=300, save_path=None) 128 | trainer.train(load_best_model=False) 129 | --------------------------------------------------------------------------------