├── data ├── dev.txt ├── train.txt ├── tag.txt └── bert │ └── bert_config.json ├── model ├── __init__.py ├── bert_lstm_crf.py └── crf.py ├── README.md ├── LICENSE ├── config.py ├── utils.py └── main.py /data/dev.txt: -------------------------------------------------------------------------------- 1 | 李 白 是 个 诗 人|||B-PER E-PER O O O O 2 | -------------------------------------------------------------------------------- /data/train.txt: -------------------------------------------------------------------------------- 1 | 李 白 是 个 诗 人|||B-PER E-PER O O O O 2 | -------------------------------------------------------------------------------- /data/tag.txt: -------------------------------------------------------------------------------- 1 | 2 | B-PER 3 | I-PER 4 | E-PER 5 | O 6 | 7 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from model.crf import CRF 3 | from model.bert_lstm_crf import BERT_LSTM_CRF 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bert-BiLSTM-CRF-pytorch 2 | 使用谷歌预训练bert做字嵌入的BiLSTM-CRF序列标注模型 3 | 4 | 本模型使用谷歌预训练bert模型(https://github.com/google-research/bert), 5 | 同时使用pytorch-pretrained-BERT(https://github.com/huggingface/pytorch-pretrained-BERT) 6 | 项目加载bert模型并转化为pytorch参数,CRF代码参考了SLTK(https://github.com/liu-nlper/SLTK) 7 | 8 | 数据准备格式参见data目录下的示例文件 9 | 10 | 模型参数可以在config中进行设置 11 | 12 | 运行代码 13 | 14 | python main.py train --use_cuda=False --batch_size=10 15 | 16 | pytorch.bin 百度网盘链接:https://pan.baidu.com/s/160cvZXyR_qdAv801bDY2mQ 提取码:q67r 17 | 18 | 作者也是新手,很希望大家能够多提意见,共同学习 19 | -------------------------------------------------------------------------------- /data/bert/bert_config.json: -------------------------------------------------------------------------------- 1 | { 
2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128 19 | } 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 chenxiaoyouyou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | 4 | class Config(object): 5 | def __init__(self): 6 | self.label_file = './data/tag.txt' 7 | self.train_file = './data/train.txt' 8 | self.dev_file = './data/dev.txt' 9 | self.test_file = './data/test.txt' 10 | self.vocab = './data/bert/vocab.txt' 11 | self.max_length = 300 12 | self.use_cuda = False 13 | self.gpu = 0 14 | self.batch_size = 50 15 | self.bert_path = './data/bert' 16 | self.rnn_hidden = 500 17 | self.bert_embedding = 768 18 | self.dropout1 = 0.5 19 | self.dropout_ratio = 0.5 20 | self.rnn_layer = 1 21 | self.lr = 0.0001 22 | self.lr_decay = 0.00001 23 | self.weight_decay = 0.00005 24 | self.checkpoint = 'result/' 25 | self.optim = 'Adam' 26 | self.load_model = False 27 | self.load_path = None 28 | self.base_epoch = 100 29 | 30 | def update(self, **kwargs): 31 | for k, v in kwargs.items(): 32 | setattr(self, k, v) 33 | 34 | def __str__(self): 35 | 36 | return '\n'.join(['%s:%s' % item for item in self.__dict__.items()]) 37 | 38 | 39 | if __name__ == '__main__': 40 | 41 | con = Config() 42 | con.update(gpu=8) 43 | print(con.gpu) 44 | print(con) 45 | -------------------------------------------------------------------------------- /model/bert_lstm_crf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # coding=utf-8 3 | import torch.nn as nn 4 | from pytorch_pretrained_bert import BertModel 5 | from model import CRF 6 | from torch.autograd import Variable 7 | import torch 8 | import ipdb 9 | 10 | 11 | class BERT_LSTM_CRF(nn.Module): 12 | """ 13 | bert_lstm_crf model 14 | """ 15 | def __init__(self, bert_config, tagset_size, embedding_dim, hidden_dim, rnn_layers, dropout_ratio, dropout1, use_cuda=False): 16 | super(BERT_LSTM_CRF, self).__init__() 17 | self.embedding_dim = embedding_dim 18 | 
self.hidden_dim = hidden_dim 19 | self.word_embeds = BertModel.from_pretrained(bert_config) 20 | self.lstm = nn.LSTM(embedding_dim, hidden_dim, 21 | num_layers=rnn_layers, bidirectional=True, dropout=dropout_ratio, batch_first=True) 22 | self.rnn_layers = rnn_layers 23 | self.dropout1 = nn.Dropout(p=dropout1) 24 | self.crf = CRF(target_size=tagset_size, average_batch=True, use_cuda=use_cuda) 25 | self.liner = nn.Linear(hidden_dim*2, tagset_size+2) 26 | self.tagset_size = tagset_size 27 | 28 | def rand_init_hidden(self, batch_size): 29 | """ 30 | random initialize hidden variable 31 | """ 32 | return Variable( 33 | torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim)), Variable( 34 | torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim)) 35 | 36 | def forward(self, sentence, attention_mask=None): 37 | ''' 38 | args: 39 | sentence (word_seq_len, batch_size) : word-level representation of sentence 40 | hidden: initial hidden state 41 | 42 | return: 43 | crf output (word_seq_len, batch_size, tag_size, tag_size), hidden 44 | ''' 45 | batch_size = sentence.size(0) 46 | seq_length = sentence.size(1) 47 | embeds, _ = self.word_embeds(sentence, attention_mask=attention_mask, output_all_encoded_layers=False) 48 | hidden = self.rand_init_hidden(batch_size) 49 | if embeds.is_cuda: 50 | hidden = (i.cuda() for i in hidden) 51 | lstm_out, hidden = self.lstm(embeds, hidden) 52 | lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim*2) 53 | d_lstm_out = self.dropout1(lstm_out) 54 | l_out = self.liner(d_lstm_out) 55 | lstm_feats = l_out.contiguous().view(batch_size, seq_length, -1) 56 | return lstm_feats 57 | 58 | def loss(self, feats, mask, tags): 59 | """ 60 | feats: size=(batch_size, seq_len, tag_size) 61 | mask: size=(batch_size, seq_len) 62 | tags: size=(batch_size, seq_len) 63 | :return: 64 | """ 65 | loss_value = self.crf.neg_log_likelihood_loss(feats, mask, tags) 66 | batch_size = feats.size(0) 67 | loss_value /= float(batch_size) 68 | return loss_value 
69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch 3 | import os 4 | import datetime 5 | import unicodedata 6 | 7 | 8 | class InputFeatures(object): 9 | def __init__(self, input_id, label_id, input_mask): 10 | self.input_id = input_id 11 | self.label_id = label_id 12 | self.input_mask = input_mask 13 | 14 | 15 | def load_vocab(vocab_file): 16 | """Loads a vocabulary file into a dictionary.""" 17 | vocab = {} 18 | index = 0 19 | with open(vocab_file, "r", encoding="utf-8") as reader: 20 | while True: 21 | token = reader.readline() 22 | if not token: 23 | break 24 | token = token.strip() 25 | vocab[token] = index 26 | index += 1 27 | return vocab 28 | 29 | 30 | def read_corpus(path, max_length, label_dic, vocab): 31 | """ 32 | :param path:数据文件路径 33 | :param max_length: 最大长度 34 | :param label_dic: 标签字典 35 | :return: 36 | """ 37 | file = open(path, encoding='utf-8') 38 | content = file.readlines() 39 | file.close() 40 | result = [] 41 | for line in content: 42 | text, label = line.strip().split('|||') 43 | tokens = text.split() 44 | label = label.split() 45 | if len(tokens) > max_length-2: 46 | tokens = tokens[0:(max_length-2)] 47 | label = label[0:(max_length-2)] 48 | tokens_f =['[CLS]'] + tokens + ['[SEP]'] 49 | label_f = [""] + label + [''] 50 | input_ids = [int(vocab[i]) if i in vocab else int(vocab['[UNK]']) for i in tokens_f] 51 | label_ids = [label_dic[i] for i in label_f] 52 | input_mask = [1] * len(input_ids) 53 | while len(input_ids) < max_length: 54 | input_ids.append(0) 55 | input_mask.append(0) 56 | label_ids.append(label_dic['']) 57 | assert len(input_ids) == max_length 58 | assert len(input_mask) == max_length 59 | assert len(label_ids) == max_length 60 | feature = InputFeatures(input_id=input_ids, input_mask=input_mask, label_id=label_ids) 61 | result.append(feature) 62 | return 
result 63 | 64 | 65 | def save_model(model, epoch, path='result', **kwargs): 66 | """ 67 | 默认保留所有模型 68 | :param model: 模型 69 | :param path: 保存路径 70 | :param loss: 校验损失 71 | :param last_loss: 最佳epoch损失 72 | :param kwargs: every_epoch or best_epoch 73 | :return: 74 | """ 75 | if not os.path.exists(path): 76 | os.mkdir(path) 77 | if kwargs.get('name', None) is None: 78 | cur_time = datetime.datetime.now().strftime('%Y-%m-%d#%H:%M:%S') 79 | name = cur_time + '--epoch:{}'.format(epoch) 80 | full_name = os.path.join(path, name) 81 | torch.save(model.state_dict(), full_name) 82 | print('Saved model at epoch {} successfully'.format(epoch)) 83 | with open('{}/checkpoint'.format(path), 'w') as file: 84 | file.write(name) 85 | print('Write to checkpoint') 86 | 87 | 88 | def load_model(model, path='result', **kwargs): 89 | if kwargs.get('name', None) is None: 90 | with open('{}/checkpoint'.format(path)) as file: 91 | content = file.read().strip() 92 | name = os.path.join(path, content) 93 | else: 94 | name=kwargs['name'] 95 | name = os.path.join(path,name) 96 | model.load_state_dict(torch.load(name, map_location=lambda storage, loc: storage)) 97 | print('load model {} successfully'.format(name)) 98 | return model 99 | 100 | 101 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | from config import Config 6 | from model import BERT_LSTM_CRF 7 | import torch.optim as optim 8 | from utils import load_vocab, read_corpus, load_model, save_model 9 | from torch.utils.data import TensorDataset 10 | from torch.utils.data import DataLoader 11 | import fire 12 | 13 | 14 | def train(**kwargs): 15 | config = Config() 16 | config.update(**kwargs) 17 | print('当前设置为:\n', config) 18 | if config.use_cuda: 19 | torch.cuda.set_device(config.gpu) 20 | print('loading 
corpus') 21 | vocab = load_vocab(config.vocab) 22 | label_dic = load_vocab(config.label_file) 23 | tagset_size = len(label_dic) 24 | train_data = read_corpus(config.train_file, max_length=config.max_length, label_dic=label_dic, vocab=vocab) 25 | dev_data = read_corpus(config.dev_file, max_length=config.max_length, label_dic=label_dic, vocab=vocab) 26 | 27 | train_ids = torch.LongTensor([temp.input_id for temp in train_data]) 28 | train_masks = torch.LongTensor([temp.input_mask for temp in train_data]) 29 | train_tags = torch.LongTensor([temp.label_id for temp in train_data]) 30 | 31 | train_dataset = TensorDataset(train_ids, train_masks, train_tags) 32 | train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size) 33 | 34 | dev_ids = torch.LongTensor([temp.input_id for temp in dev_data]) 35 | dev_masks = torch.LongTensor([temp.input_mask for temp in dev_data]) 36 | dev_tags = torch.LongTensor([temp.label_id for temp in dev_data]) 37 | 38 | dev_dataset = TensorDataset(dev_ids, dev_masks, dev_tags) 39 | dev_loader = DataLoader(dev_dataset, shuffle=True, batch_size=config.batch_size) 40 | model = BERT_LSTM_CRF(config.bert_path, tagset_size, config.bert_embedding, config.rnn_hidden, config.rnn_layer, dropout_ratio=config.dropout_ratio, dropout1=config.dropout1, use_cuda=config.use_cuda) 41 | if config.load_model: 42 | assert config.load_path is not None 43 | model = load_model(model, name=config.load_path) 44 | if config.use_cuda: 45 | model.cuda() 46 | model.train() 47 | optimizer = getattr(optim, config.optim) 48 | optimizer = optimizer(model.parameters(), lr=config.lr, weight_decay=config.weight_decay) 49 | eval_loss = 10000 50 | for epoch in range(config.base_epoch): 51 | step = 0 52 | for i, batch in enumerate(train_loader): 53 | step += 1 54 | model.zero_grad() 55 | inputs, masks, tags = batch 56 | inputs, masks, tags = Variable(inputs), Variable(masks), Variable(tags) 57 | if config.use_cuda: 58 | inputs, masks, tags = inputs.cuda(), 
masks.cuda(), tags.cuda() 59 | feats = model(inputs, masks) 60 | loss = model.loss(feats, masks,tags) 61 | loss.backward() 62 | optimizer.step() 63 | if step % 50 == 0: 64 | print('step: {} | epoch: {}| loss: {}'.format(step, epoch, loss.item())) 65 | loss_temp = dev(model, dev_loader, epoch, config) 66 | if loss_temp < eval_loss: 67 | save_model(model,epoch) 68 | 69 | 70 | def dev(model, dev_loader, epoch, config): 71 | model.eval() 72 | eval_loss = 0 73 | true = [] 74 | pred = [] 75 | length = 0 76 | for i, batch in enumerate(dev_loader): 77 | inputs, masks, tags = batch 78 | length += inputs.size(0) 79 | inputs, masks, tags = Variable(inputs), Variable(masks), Variable(tags) 80 | if config.use_cuda: 81 | inputs, masks, tags = inputs.cuda(), masks.cuda(), tags.cuda() 82 | feats = model(inputs, masks) 83 | path_score, best_path = model.crf(feats, masks.byte()) 84 | loss = model.loss(feats, masks, tags) 85 | eval_loss += loss.item() 86 | pred.extend([t for t in best_path]) 87 | true.extend([t for t in tags]) 88 | print('eval epoch: {}| loss: {}'.format(epoch, eval_loss/length)) 89 | model.train() 90 | return eval_loss 91 | 92 | 93 | if __name__ == '__main__': 94 | fire.Fire() 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /model/crf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch 3 | from torch.autograd import Variable 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import ipdb 7 | 8 | 9 | def log_sum_exp(vec, m_size): 10 | """ 11 | Args: 12 | vec: size=(batch_size, vanishing_dim, hidden_dim) 13 | m_size: hidden_dim 14 | 15 | Returns: 16 | size=(batch_size, hidden_dim) 17 | """ 18 | _, idx = torch.max(vec, 1) # B * 1 * M 19 | max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size) # B * M 20 | return max_score.view(-1, m_size) + torch.log(torch.sum( 21 | 
torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size) 22 | 23 | 24 | class CRF(nn.Module): 25 | 26 | def __init__(self, **kwargs): 27 | """ 28 | Args: 29 | target_size: int, target size 30 | use_cuda: bool, 是否使用gpu, default is True 31 | average_batch: bool, loss是否作平均, default is True 32 | """ 33 | super(CRF, self).__init__() 34 | for k in kwargs: 35 | self.__setattr__(k, kwargs[k]) 36 | self.START_TAG_IDX, self.END_TAG_IDX = -2, -1 37 | init_transitions = torch.zeros(self.target_size+2, self.target_size+2) 38 | init_transitions[:, self.START_TAG_IDX] = -1000. 39 | init_transitions[self.END_TAG_IDX, :] = -1000. 40 | if self.use_cuda: 41 | init_transitions = init_transitions.cuda() 42 | self.transitions = nn.Parameter(init_transitions) 43 | 44 | def _forward_alg(self, feats, mask=None): 45 | """ 46 | Do the forward algorithm to compute the partition function (batched). 47 | 48 | Args: 49 | feats: size=(batch_size, seq_len, self.target_size+2) 50 | mask: size=(batch_size, seq_len) 51 | 52 | Returns: 53 | xxx 54 | """ 55 | batch_size = feats.size(0) 56 | seq_len = feats.size(1) 57 | tag_size = feats.size(-1) 58 | 59 | mask = mask.transpose(1, 0).contiguous() 60 | ins_num = batch_size * seq_len 61 | feats = feats.transpose(1, 0).contiguous().view( 62 | ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) 63 | 64 | scores = feats + self.transitions.view( 65 | 1, tag_size, tag_size).expand(ins_num, tag_size, tag_size) 66 | scores = scores.view(seq_len, batch_size, tag_size, tag_size) 67 | seq_iter = enumerate(scores) 68 | try: 69 | _, inivalues = seq_iter.__next__() 70 | except: 71 | _, inivalues = seq_iter.next() 72 | 73 | partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1) 74 | for idx, cur_values in seq_iter: 75 | cur_values = cur_values + partition.contiguous().view( 76 | batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 77 | cur_partition = log_sum_exp(cur_values, tag_size) 78 | mask_idx = mask[idx, 
:].view(batch_size, 1).expand(batch_size, tag_size) 79 | masked_cur_partition = cur_partition.masked_select(mask_idx.byte()) 80 | if masked_cur_partition.dim() != 0: 81 | mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1) 82 | partition.masked_scatter_(mask_idx.byte(), masked_cur_partition) 83 | cur_values = self.transitions.view(1, tag_size, tag_size).expand( 84 | batch_size, tag_size, tag_size) + partition.contiguous().view( 85 | batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 86 | cur_partition = log_sum_exp(cur_values, tag_size) 87 | final_partition = cur_partition[:, self.END_TAG_IDX] 88 | return final_partition.sum(), scores 89 | 90 | def _viterbi_decode(self, feats, mask=None): 91 | """ 92 | Args: 93 | feats: size=(batch_size, seq_len, self.target_size+2) 94 | mask: size=(batch_size, seq_len) 95 | 96 | Returns: 97 | decode_idx: (batch_size, seq_len), viterbi decode结果 98 | path_score: size=(batch_size, 1), 每个句子的得分 99 | """ 100 | batch_size = feats.size(0) 101 | seq_len = feats.size(1) 102 | tag_size = feats.size(-1) 103 | 104 | length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long() 105 | mask = mask.transpose(1, 0).contiguous() 106 | ins_num = seq_len * batch_size 107 | feats = feats.transpose(1, 0).contiguous().view( 108 | ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) 109 | 110 | scores = feats + self.transitions.view( 111 | 1, tag_size, tag_size).expand(ins_num, tag_size, tag_size) 112 | scores = scores.view(seq_len, batch_size, tag_size, tag_size) 113 | 114 | seq_iter = enumerate(scores) 115 | # record the position of the best score 116 | back_points = list() 117 | partition_history = list() 118 | mask = (1 - mask.long()).byte() 119 | try: 120 | _, inivalues = seq_iter.__next__() 121 | except: 122 | _, inivalues = seq_iter.next() 123 | partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1) 124 | partition_history.append(partition) 125 | 126 | for idx, cur_values in seq_iter: 
127 | cur_values = cur_values + partition.contiguous().view( 128 | batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 129 | partition, cur_bp = torch.max(cur_values, 1) 130 | partition_history.append(partition.unsqueeze(-1)) 131 | 132 | cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0) 133 | back_points.append(cur_bp) 134 | 135 | partition_history = torch.cat(partition_history).view( 136 | seq_len, batch_size, -1).transpose(1, 0).contiguous() 137 | 138 | last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1 139 | last_partition = torch.gather( 140 | partition_history, 1, last_position).view(batch_size, tag_size, 1) 141 | 142 | last_values = last_partition.expand(batch_size, tag_size, tag_size) + \ 143 | self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size) 144 | _, last_bp = torch.max(last_values, 1) 145 | pad_zero = Variable(torch.zeros(batch_size, tag_size)).long() 146 | if self.use_cuda: 147 | pad_zero = pad_zero.cuda() 148 | back_points.append(pad_zero) 149 | back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size) 150 | 151 | pointer = last_bp[:, self.END_TAG_IDX] 152 | insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size) 153 | back_points = back_points.transpose(1, 0).contiguous() 154 | 155 | back_points.scatter_(1, last_position, insert_last) 156 | 157 | back_points = back_points.transpose(1, 0).contiguous() 158 | 159 | decode_idx = Variable(torch.LongTensor(seq_len, batch_size)) 160 | if self.use_cuda: 161 | decode_idx = decode_idx.cuda() 162 | decode_idx[-1] = pointer.data 163 | for idx in range(len(back_points)-2, -1, -1): 164 | pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) 165 | decode_idx[idx] = pointer.view(-1).data 166 | path_score = None 167 | decode_idx = decode_idx.transpose(1, 0) 168 | return path_score, decode_idx 169 | 170 | def forward(self, 
feats, mask=None): 171 | path_score, best_path = self._viterbi_decode(feats, mask) 172 | return path_score, best_path 173 | 174 | def _score_sentence(self, scores, mask, tags): 175 | """ 176 | Args: 177 | scores: size=(seq_len, batch_size, tag_size, tag_size) 178 | mask: size=(batch_size, seq_len) 179 | tags: size=(batch_size, seq_len) 180 | 181 | Returns: 182 | score: 183 | """ 184 | batch_size = scores.size(1) 185 | seq_len = scores.size(0) 186 | tag_size = scores.size(-1) 187 | 188 | new_tags = Variable(torch.LongTensor(batch_size, seq_len)) 189 | if self.use_cuda: 190 | new_tags = new_tags.cuda() 191 | for idx in range(seq_len): 192 | if idx == 0: 193 | new_tags[:, 0] = (tag_size - 2) * tag_size + tags[:, 0] 194 | else: 195 | new_tags[:, idx] = tags[:, idx-1] * tag_size + tags[:, idx] 196 | 197 | end_transition = self.transitions[:, self.END_TAG_IDX].contiguous().view( 198 | 1, tag_size).expand(batch_size, tag_size) 199 | length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long() 200 | end_ids = torch.gather(tags, 1, length_mask-1) 201 | 202 | end_energy = torch.gather(end_transition, 1, end_ids) 203 | 204 | new_tags = new_tags.transpose(1, 0).contiguous().view(seq_len, batch_size, 1) 205 | tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view( 206 | seq_len, batch_size) 207 | tg_energy = tg_energy.masked_select(mask.transpose(1, 0)) 208 | 209 | gold_score = tg_energy.sum() + end_energy.sum() 210 | 211 | return gold_score 212 | 213 | def neg_log_likelihood_loss(self, feats, mask, tags): 214 | """ 215 | Args: 216 | feats: size=(batch_size, seq_len, tag_size) 217 | mask: size=(batch_size, seq_len) 218 | tags: size=(batch_size, seq_len) 219 | """ 220 | batch_size = feats.size(0) 221 | mask = mask.byte() 222 | forward_score, scores = self._forward_alg(feats, mask) 223 | gold_score = self._score_sentence(scores, mask, tags) 224 | if self.average_batch: 225 | return (forward_score - gold_score) / batch_size 226 | return forward_score - 
gold_score 227 | 228 | 229 | 230 | 231 | --------------------------------------------------------------------------------