├── BILSTM+CRF ├── Io │ └── data_loader.py ├── config │ ├── __pycache__ │ │ └── config.cpython-36.pyc │ └── config.py ├── main │ ├── bilsm_crf.py │ └── main.py ├── net │ ├── bilstm.py │ ├── crf.py │ └── ner.py ├── preprocessing │ └── data_processor.py ├── run_bilstm_crf.py ├── test.py ├── train │ └── train.py └── util │ ├── Logginger.py │ ├── embedding_util.py │ ├── gpu_mem_track.py │ ├── lr_util.py │ ├── plot_util.py │ └── porgress_util.py ├── README.md └── pycrf └── code ├── 2014_cropus_cleaning.ipynb └── CRF.ipynb /BILSTM+CRF/Io/data_loader.py: -------------------------------------------------------------------------------- 1 | """将id格式的输入转换成dataset,并做动态padding""" 2 | 3 | import torch 4 | from torchtext.data import Field, TabularDataset 5 | from torchtext.data import BucketIterator 6 | 7 | import config.config as config 8 | 9 | 10 | def x_tokenize(x): 11 | # 如果加载进来的是已经转成id的文本 12 | # 此处必须将字符串转换成整型 13 | return [int(c) for c in x.split()] 14 | 15 | class BatchIterator(object): 16 | def __init__(self, train_path, valid_path, 17 | batch_size, fix_length=None, 18 | x_var="source", y_var="target"): 19 | self.train_path = train_path 20 | self.valid_path = valid_path 21 | self.batch_size = batch_size 22 | self.fix_length = fix_length 23 | self.x_var = x_var 24 | self.y_vars = y_var 25 | 26 | def create_dataset(self): 27 | SOURCE = Field(sequential=True, tokenize=x_tokenize, 28 | use_vocab=False, batch_first=True, 29 | fix_length=self.fix_length, # 如需静态padding,则设置fix_length, 但要注意要大于文本最大长度 30 | eos_token=None, init_token=None, 31 | include_lengths=True, pad_token=0) 32 | 33 | TARGET = Field(sequential=True, tokenize=x_tokenize, 34 | use_vocab=False, batch_first=True, 35 | fix_length=self.fix_length, # 如需静态padding,则设置fix_length, 但要注意要大于文本最大长度 36 | eos_token=None, init_token=None, 37 | include_lengths=False, pad_token=-1) 38 | 39 | fields = {'source': ('source', SOURCE), 'target': ('target', TARGET)} 40 | 41 | train, valid = TabularDataset.splits( 42 | path=config.ROOT_DIR, 43 | train=self.train_path, validation=self.valid_path, 44 | format="json", 45 | skip_header=False, 46 | fields=fields) 47 | return train, valid 48 | 49 | 50 | def get_iterator(self, train, valid): 51 | train_iter = BucketIterator(train, 52 | batch_size=self.batch_size, 53 | device = torch.device("cpu"), # cpu by -1, gpu by 0 54 | sort_key=lambda x: len(x.source), # field sorted by len 55 | sort_within_batch=True, 56 | repeat=False) 57 | val_iter = BucketIterator(valid, 58 | batch_size=self.batch_size, 59 | device=torch.device("cpu"), # cpu by -1, gpu by 0 60 | sort_key=lambda x: len(x.source), # field sorted by len 61 | sort_within_batch=True, 62 | repeat=False) 63 | 64 | train_iter = BatchWrapper(train_iter, x_var=self.x_var, y_vars=self.y_vars) 65 | val_iter = BatchWrapper(val_iter, x_var=self.x_var, y_vars=self.y_vars) 66 | ### batch = iter(train_iter) 67 | ### batch: ((text, length), y) 68 | return train_iter, val_iter 69 | 70 | 71 | 72 | class BatchWrapper(object): 73 | """对batch做个包装,方便调用,可选择性使用""" 74 | def __init__(self, dl, x_var, y_vars): 75 | self.dl, self.x_var, self.y_vars = dl, x_var, y_vars 76 | 77 | def __iter__(self): 78 | for batch in self.dl: 79 | x = getattr(batch, self.x_var) 80 | target = getattr(batch, self.y_vars) 81 | 82 | source = x[0] 83 | length = x[1] 84 | yield (source, target, length) 85 | 86 | def __len__(self): 87 | return len(self.dl) 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- 
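A minimal usage sketch of this loader (the paths and batch size are the ones defined in config/config.py; the loop body is only illustrative):

    import config.config as config
    from Io.data_loader import BatchIterator

    bi = BatchIterator(config.TRAIN_FILE, config.VALID_FILE, batch_size=config.batch_size)
    train, valid = bi.create_dataset()
    train_iter, val_iter = bi.get_iterator(train, valid)
    for source, target, length in train_iter:
        # source: (batch, max_len) word ids padded with 0; target: tag ids padded with -1; length: true lengths
        pass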
/BILSTM+CRF/config/__pycache__/config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/circlePi/NER/77eaf3c566dd92c15fbb3e2105cec6561466988b/BILSTM+CRF/config/__pycache__/config.cpython-36.pyc
--------------------------------------------------------------------------------
/BILSTM+CRF/config/config.py:
--------------------------------------------------------------------------------
1 | ROOT_DIR = '/home/daizelin/NER'
2 | TRAIN_FILE = 'output/intermediate/train.json'
3 | VALID_FILE = 'output/intermediate/valid.json'
4 | RAW_SOURCE_DATA = 'data/source_BIO_2014_cropus.txt'
5 | RAW_TARGET_DATA = 'data/target_BIO_2014_cropus.txt'
6 | 
7 | WORD2ID_FILE = 'output/intermediate/word2id.pkl'
8 | EMBEDDING_FILE = 'embedding/peopel_paper_min_count_1_window_5_300d.word2vec'
9 | LOG_PATH = 'output/logs'
10 | 
11 | checkpoint_dir = 'output/checkpoints/bilstm_ner.ckpt'
12 | plot_path = 'output/images/img'
13 | 
14 | 
15 | # -----------PARAMETERS----------------
16 | tag_to_ix = {
17 | "B_PER": 0, # person name
18 | "I_PER": 1,
19 | "B_LOC": 2, # location
20 | "I_LOC": 3,
21 | "B_ORG": 4, # organization
22 | "I_ORG": 5,
23 | "B_T": 6, # time expression
24 | "I_T": 7,
25 | "O": 8, # other
26 | "SOS": 9, # start-of-sequence tag
27 | "EOS": 10 # end-of-sequence tag
28 | }
29 | 
30 | labels = [i for i in range(0, 9)]
31 | 
32 | flag_words = ['', '']
33 | max_len = 100
34 | vocab_size = 10000
35 | is_debug = False
36 | 
37 | # ------------NET PARAMS----------------
38 | use_mem_track = False
39 | device = 0
40 | use_cuda = True
41 | word_embedding_dim = 300
42 | batch_size = 128
43 | cell_type = 'GRU'
44 | dropout = 0.5
45 | num_epoch = 4
46 | lr_decay_mode = 'custom_decay'
47 | initial_lr = 0.001
--------------------------------------------------------------------------------
/BILSTM+CRF/main/bilsm_crf.py:
--------------------------------------------------------------------------------
1 | from preprocessing.data_processor import data_helper
2 | from Io.data_loader import BatchIterator
3 | 
4 | from net.ner import BILSTM_CRF
5 | from train.train import fit
6 | 
7 | import config.config as config
8 | from util.porgress_util import ProgressBar
9 | 
10 | 
11 | def bilstm_crf():
12 | # data preprocessing: build the vocabulary and write the id-encoded train/valid JSON files
13 | word2id, epoch_size = data_helper(vocab_size=config.vocab_size, max_len=config.max_len, min_freq=1,
14 | valid_size=0.2, random_state=2018, shuffle=True, is_debug=config.is_debug)
15 | 
16 | vocab_size = len(word2id)
17 | 
18 | # initialise the progress bar
19 | pbar = ProgressBar(epoch_size=epoch_size, batch_size=config.batch_size)
20 | 
21 | # build the batch iterators; the field names must match the keys written to the JSON files
22 | bi = BatchIterator(config.TRAIN_FILE,
23 | config.VALID_FILE,
24 | config.batch_size, fix_length=config.max_len,
25 | x_var="source", y_var="target")
26 | train, valid = bi.create_dataset()
27 | train_iter, val_iter = bi.get_iterator(train, valid)
28 | 
29 | model = BILSTM_CRF(
30 | vocab_size=vocab_size,
31 | word_embedding_dim=config.word_embedding_dim,
32 | word2id=word2id,
33 | hidden_size=128, bi_flag=True,
34 | num_layer=1, input_size=config.word_embedding_dim,
35 | cell_type=config.cell_type,
36 | dropout=config.dropout,
37 | num_tag=len(config.labels),
38 | checkpoint_dir=config.checkpoint_dir
39 | )
40 | 
41 | # train
42 | fit(model, train_iter, val_iter,
43 | config.num_epoch, pbar,
44 | config.lr_decay_mode,
45 | config.initial_lr, verbose=1)
--------------------------------------------------------------------------------
/BILSTM+CRF/main/main.py:
--------------------------------------------------------------------------------
1 | from
preprocessing.data_processor import data_helper
2 | from Io.data_loader import BatchIterator
3 | 
4 | from net.ner import BILSTM_CRF
5 | from train.train import fit
6 | 
7 | import config.config as config
8 | from util.porgress_util import ProgressBar
9 | 
10 | 
11 | def bilstm_crf():
12 | # data preprocessing: build the vocabulary and write the id-encoded train/valid JSON files
13 | word2id, epoch_size = data_helper(vocab_size=config.vocab_size, max_len=config.max_len, min_freq=1,
14 | valid_size=0.2, random_state=2018, shuffle=True, is_debug=config.is_debug)
15 | 
16 | vocab_size = len(word2id)
17 | 
18 | # initialise the progress bar
19 | pbar = ProgressBar(epoch_size=epoch_size, batch_size=config.batch_size)
20 | 
21 | # build the batch iterators
22 | bi = BatchIterator(config.TRAIN_FILE,
23 | config.VALID_FILE,
24 | config.batch_size, fix_length=config.max_len,
25 | x_var="source", y_var="target")
26 | train, valid = bi.create_dataset()
27 | train_iter, val_iter = bi.get_iterator(train, valid)
28 | 
29 | model = BILSTM_CRF(
30 | vocab_size=vocab_size,
31 | word_embedding_dim=config.word_embedding_dim,
32 | word2id=word2id,
33 | hidden_size=128, bi_flag=True,
34 | num_layer=1, input_size=config.word_embedding_dim,
35 | cell_type=config.cell_type,
36 | dropout=config.dropout,
37 | num_tag=len(config.labels),
38 | checkpoint_dir=config.checkpoint_dir
39 | )
40 | 
41 | # train
42 | fit(model, train_iter, val_iter,
43 | config.num_epoch, pbar,
44 | config.lr_decay_mode,
45 | config.initial_lr, verbose=1)
--------------------------------------------------------------------------------
/BILSTM+CRF/net/bilstm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
6 | import numpy as np
7 | 
8 | import config.config as config
9 | 
10 | torch.manual_seed(2018)
11 | torch.cuda.manual_seed(2018)
12 | torch.cuda.manual_seed_all(2018)
13 | np.random.seed(2018)
14 | 
15 | os.environ["CUDA_VISIBLE_DEVICES"] = "%d"%config.device
16 | 
17 | class RNN(nn.Module):
18 | def __init__(self,
19 | hidden_size, bi_flag,
20 | num_layer,
21 | input_size,
22 | cell_type,
23 | dropout,
24 | num_tag):
25 | super(RNN, self).__init__()
26 | self.num_layer = num_layer
27 | self.hidden_size = hidden_size
28 | self.dropout = dropout
29 | 
30 | 
31 | if torch.cuda.is_available():
32 | self.device = torch.device("cuda")
33 | 
34 | 
35 | 
36 | if cell_type == "LSTM":
37 | self.rnn_cell = nn.LSTM(input_size=input_size,
38 | hidden_size=hidden_size,
39 | num_layers=num_layer,
40 | batch_first=True,
41 | dropout=dropout,
42 | bidirectional=bi_flag)
43 | elif cell_type == "GRU":
44 | self.rnn_cell = nn.GRU(input_size=input_size,
45 | hidden_size=hidden_size,
46 | num_layers=num_layer,
47 | batch_first=True,
48 | dropout=dropout,
49 | bidirectional=bi_flag)
50 | else:
51 | raise TypeError("RNN: Unknown rnn cell type")
52 | 
53 | # one direction or two
54 | self.bi_num = 2 if bi_flag else 1
55 | self.linear = nn.Linear(in_features=hidden_size*self.bi_num, out_features=num_tag)
56 | 
57 | def forward(self, embeddings, length):
58 | # pack the batch so the RNN skips the padding positions
59 | # embeddings_packed.data: (sum(length), embedding_dim)
60 | embeddings_packed = pack_padded_sequence(embeddings, length, batch_first=True)
61 | output, _ = self.rnn_cell(embeddings_packed)
62 | output, _ = pad_packed_sequence(output, batch_first=True)
63 | output = self.linear(output)
64 | output = F.dropout(output, p=self.dropout, training=self.training)
65 | # output = F.tanh(output)
66 | return output
67 | 
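# Shape walk-through for forward() (sizes are schematic):
#   embeddings:                      (batch_size, max_len, input_size)
#   pack_padded_sequence(...).data:  (sum(length), input_size)  -- padding steps removed
#   rnn_cell + pad_packed_sequence:  (batch_size, max(length), hidden_size * bi_num)
#   linear:                          (batch_size, max(length), num_tag)
# pack_padded_sequence expects `length` in descending order; the BucketIterator in
# Io/data_loader.py guarantees this via sort_key and sort_within_batch=True.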
--------------------------------------------------------------------------------
/BILSTM+CRF/net/crf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | 
5 | 
6 | class CRF(nn.Module):
7 | """Linear-chain conditional random field"""
8 | def __init__(self, num_tag, use_cuda=False):
9 | if num_tag <= 0:
10 | raise ValueError("Invalid value of num_tag: %d" % num_tag)
11 | super(CRF, self).__init__()
12 | self.num_tag = num_tag
13 | self.start_tag = num_tag
14 | self.end_tag = num_tag + 1
15 | self.use_cuda = use_cuda
16 | # transition matrix transitions: P_jk is the score of moving from tag_j to tag_k
17 | # P_j* are all edges leaving tag_j
18 | # P_*k are all edges entering tag_k
19 | self.transitions = nn.Parameter(torch.Tensor(num_tag + 2, num_tag + 2))
20 | nn.init.uniform_(self.transitions, -0.1, 0.1)
21 | self.transitions.data[self.end_tag, :] = -10000 # EOS -> any other tag is impossible; if it happens it incurs a huge penalty
22 | self.transitions.data[:, self.start_tag] = -10000 # any tag -> SOS is impossible, as above
23 | 
24 | def real_path_score(self, features, tags):
25 | """
26 | features: (time_steps, num_tag)
27 | real_path_score is the score of the real (gold) path.
28 | It is the sum of two parts: the Emission score and the Transition score.
29 | The Emission score comes from the LSTM output evaluated at the gold tags; it measures how strongly the output supports the true labels.
30 | The Transition score is the parameter the CRF layer trains; it is randomly initialised and models the constraints (transition scores) between consecutive tags.
31 | The transition matrix stores these pairwise constraints between tags.
32 | During training we want real_path_score to be the highest, because the gold path should be the most probable of all paths.
33 | """
34 | r = torch.LongTensor(range(features.size(0)))
35 | if self.use_cuda:
36 | pad_start_tags = torch.cat([torch.cuda.LongTensor([self.start_tag]), tags])
37 | pad_stop_tags = torch.cat([tags, torch.cuda.LongTensor([self.end_tag])])
38 | r = r.cuda()
39 | else:
40 | pad_start_tags = torch.cat([torch.LongTensor([self.start_tag]), tags])
41 | pad_stop_tags = torch.cat([tags, torch.LongTensor([self.end_tag])])
42 | # Transition score + Emission score
43 | score = torch.sum(self.transitions[pad_start_tags, pad_stop_tags]).cpu() + torch.sum(features[r, tags])
44 | return score
45 | 
46 | def all_possible_path_score(self, features):
47 | """
48 | Log-sum of the scores of all possible paths: the forward algorithm.
49 | step1: expand the forward column into a num_tag x num_tag matrix (3x3 in a three-tag example)
50 | step2: expand the emission row of the next token into the same num_tag x num_tag matrix
51 | step3: add step1, step2 and the matching entries of the transition matrix
52 | step4: update forward by log-summing over the rows
53 | step5: total is the log of the summed exponentials of forward
54 | """
55 | time_steps = features.size(0)
56 | # initialisation
57 | forward = Variable(torch.zeros(self.num_tag)) # the emission score of START_TAG is initialised to 0
58 | if self.use_cuda:
59 | forward = forward.cuda()
60 | for i in range(0, time_steps): # START_TAG -> 1st word -> 2nd word ->...->END_TAG
61 | emission_start = forward.expand(self.num_tag, self.num_tag).t()
62 | emission_end = features[i,:].expand(self.num_tag, self.num_tag)
63 | if i == 0:
64 | trans_score = self.transitions[self.start_tag, :self.start_tag].cpu()
65 | else:
66 | trans_score = self.transitions[:self.start_tag, :self.start_tag].cpu()
67 | sum = emission_start + emission_end + trans_score
68 | forward = log_sum(sum, dim=0)
69 | forward = forward + self.transitions[:self.start_tag, self.end_tag].cpu() # END_TAG
70 | total_score = log_sum(forward, dim=0)
71 | return total_score
72 | 
73 | def negative_log_loss(self, inputs, length, tags):
74 | """
75 | features: (batch_size, time_step, num_tag)
76 | target_function = P_real_path_score / P_all_possible_path_score
77 | = exp(S_real_path_score) / sum(exp(certain_path_score))
78 | We want P_real_path_score to be as high as possible, i.e. target_function as large as possible.
79 | The loss is therefore its negative log, which should be as small as possible:
80 | loss_function = -log(target_function)
81 | = -S_real_path_score + log(exp(S_1) + exp(S_2) + exp(S_3) + ...)
82 | = all_possible_path_score - S_real_path_score (all_possible_path_score is already the log of the summed exponentials)
83 | """ 84 | if not self.use_cuda: 85 | inputs = inputs.cpu() 86 | length = length.cpu() 87 | tags = tags.cpu() 88 | 89 | loss = Variable(torch.tensor(0.), requires_grad=True) 90 | num_chars = torch.sum(length.data).float() 91 | for ix, (features, tag) in enumerate(zip(inputs, tags)): 92 | features = features[:length[ix]] 93 | tag = tag[:length[ix]] 94 | real_score = self.real_path_score(features, tag) 95 | total_score = self.all_possible_path_score(features) 96 | cost = total_score - real_score 97 | loss = loss + cost 98 | return loss/num_chars 99 | 100 | def viterbi(self, features): 101 | time_steps = features.size(0) 102 | forward = Variable(torch.zeros(self.num_tag)) # START_TAG 103 | if self.use_cuda: 104 | forward = forward.cuda() 105 | # back_points 到该点的最大分数 last_points 前一个点的索引 106 | back_points, index_points = [self.transitions[self.start_tag, :self.start_tag].cpu()], [torch.LongTensor([-1]).expand_as(forward)] 107 | for i in range(1, time_steps): # START_TAG -> 1st word -> 2nd word ->...->END_TAG 108 | emission_start = forward.expand(self.num_tag, self.num_tag).t() 109 | emission_end = features[i,:].expand(self.num_tag, self.num_tag) 110 | trans_score = self.transitions[:self.start_tag, :self.start_tag].cpu() 111 | sum = emission_start + emission_end + trans_score 112 | forward, index = torch.max(sum.detach(), dim=0) 113 | back_points.append(forward) 114 | index_points.append(index) 115 | back_points.append(forward + self.transitions[:self.start_tag, self.end_tag].cpu()) # END_TAG 116 | return back_points, index_points 117 | 118 | def get_best_path(self, features): 119 | back_points, index_points = self.viterbi(features) 120 | # 找到线头 121 | best_last_point = argmax(back_points[-1]) 122 | index_points = torch.stack(index_points) # 堆成矩阵 123 | m = index_points.size(0) 124 | # 初始化矩阵 125 | best_path = [best_last_point] 126 | # 循着线头找到其对应的最佳路径 127 | for i in range(m-1, 0, -1): 128 | best_index_point = index_points[i][best_last_point] 129 | best_path.append(best_index_point) 130 | best_last_point = best_index_point 131 | best_path.reverse() 132 | return best_path 133 | 134 | def get_batch_best_path(self, inputs, length): 135 | if not self.use_cuda: 136 | inputs = inputs.cpu() 137 | length = length.cpu() 138 | batch_best_path = [] 139 | max_len = inputs.size(1) 140 | for ix, features in enumerate(inputs): 141 | features = features[:length[ix]] 142 | best_path = self.get_best_path(features) 143 | best_path = torch.Tensor(best_path).long() 144 | best_path = padding(best_path, max_len) 145 | batch_best_path.append(best_path) 146 | batch_best_path = torch.stack(batch_best_path, dim=0) 147 | return batch_best_path 148 | 149 | 150 | def log_sum(matrix, dim): 151 | """ 152 | 前向算法是不断累积之前的结果,这样就会有个缺点 153 | 指数和累积到一定程度后,会超过计算机浮点值的最大值,变成inf,这样取log后也是inf 154 | 为了避免这种情况,我们做了改动: 155 | 1. 
用一个合适的值clip去提指数和的公因子,这样就不会使某项变得过大而无法计算 156 | SUM = log(exp(s1)+exp(s2)+...+exp(s100)) 157 | = log{exp(clip)*[exp(s1-clip)+exp(s2-clip)+...+exp(s100-clip)]} 158 | = clip + log[exp(s1-clip)+exp(s2-clip)+...+exp(s100-clip)] 159 | where clip=max 160 | """ 161 | clip_value = torch.max(matrix) # 极大值 162 | clip_value = int(clip_value.data.tolist()) 163 | log_sum_value = clip_value + torch.log(torch.sum(torch.exp(matrix-clip_value), dim=dim)) 164 | return log_sum_value 165 | 166 | 167 | def argmax(matrix, dim=0): 168 | """(0.5, 0.4, 0.3)""" 169 | _, index = torch.max(matrix, dim=dim) 170 | return index 171 | 172 | 173 | def padding(vec, max_len, pad_token=-1): 174 | new_vec = torch.zeros(max_len).long() 175 | new_vec[:vec.size(0)] = vec 176 | new_vec[vec.size(0):] = pad_token 177 | return new_vec 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | -------------------------------------------------------------------------------- /BILSTM+CRF/net/ner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | import config.config as config 7 | from util.embedding_util import get_embedding 8 | 9 | from net.bilstm import RNN 10 | from net.crf import CRF 11 | 12 | from sklearn.metrics import f1_score, classification_report 13 | 14 | torch.manual_seed(2018) 15 | torch.cuda.manual_seed(2018) 16 | torch.cuda.manual_seed_all(2018) 17 | np.random.seed(2018) 18 | 19 | os.environ["CUDA_VISIBLE_DEVICE"] = "%d"%config.device 20 | 21 | 22 | class BILSTM_CRF(nn.Module): 23 | def __init__(self, 24 | vocab_size, 25 | word_embedding_dim, 26 | word2id, 27 | hidden_size, bi_flag, 28 | num_layer, input_size, 29 | cell_type, dropout, 30 | num_tag, 31 | checkpoint_dir): 32 | super(BILSTM_CRF, self).__init__() 33 | 34 | self.embedding = nn.Embedding(vocab_size, word_embedding_dim) 35 | for p in self.embedding.parameters(): 36 | p.requires_grad = False 37 | self.embedding.weight.data.copy_(torch.from_numpy(get_embedding(vocab_size, 38 | word_embedding_dim, 39 | word2id))) 40 | 41 | 42 | self.rnn = RNN(hidden_size, bi_flag, 43 | num_layer, input_size, 44 | cell_type, dropout, num_tag) 45 | 46 | self.crf = CRF(num_tag=num_tag) 47 | 48 | self.checkpoint_dir = checkpoint_dir 49 | 50 | def forward(self, inputs, length): 51 | embeddings = self.embedding(inputs) 52 | rnn_output = self.rnn(embeddings, length) # (batch_size, time_steps, num_tag+2) 53 | return rnn_output 54 | 55 | def loss_fn(self, rnn_output, labels, length): 56 | loss = self.crf.negative_log_loss(inputs=rnn_output, length=length, tags=labels) 57 | return loss 58 | 59 | def predict(self, rnn_output, length): 60 | best_path = self.crf.get_batch_best_path(rnn_output, length) 61 | return best_path 62 | 63 | def load(self): 64 | self.load_state_dict(torch.load(self.checkpoint_dir)) 65 | 66 | def save(self): 67 | torch.save(self.state_dict(), self.checkpoint_dir) 68 | 69 | def evaluate(self, y_pred, y_true): 70 | y_true = y_true.cpu().numpy() 71 | y_pred = y_pred.numpy() 72 | f1 = f1_score(y_true, y_pred, labels=config.labels, average="macro") 73 | correct = np.sum((y_true==y_pred).astype(int)) 74 | acc = correct/y_pred.shape[0] 75 | return (acc, f1) 76 | 77 | def class_report(self, y_pred, y_true): 78 | y_true = y_true.cpu().numpy() 79 | y_pred = y_pred.numpy() 80 | classify_report = classification_report(y_true, y_pred) 81 | print('\n\nclassify_report:\n', classify_report) 82 | 83 | 84 | 85 | 
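# How these methods are driven by train/train.py:
#   rnn_output = model(inputs, length)                # emission scores from the recurrent encoder
#   loss = model.loss_fn(rnn_output, labels, length)  # CRF negative log-likelihood, averaged over characters
#   best_path = model.predict(rnn_output, length)     # Viterbi decoding; positions past each true length are padded with -1
# evaluate() and class_report() expect flat 1-D tag sequences, so the caller strips
# the -1 padding first (see predicts[predicts != -1] in train.py).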
-------------------------------------------------------------------------------- /BILSTM+CRF/preprocessing/data_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import pickle 5 | import operator 6 | from glob import glob 7 | from tqdm import tqdm 8 | from collections import Counter 9 | 10 | import config.config as config 11 | from util.Logginger import init_logger 12 | 13 | logger = init_logger("torch", logging_path=config.LOG_PATH) 14 | 15 | 16 | def sent2char(line): 17 | """ 18 | 句子处理成单词 19 | :param line: 原始行 20 | :return: 单词, 标签 21 | """ 22 | res = line.strip('\n').split() 23 | return res 24 | 25 | def word_to_id(word, word2id): 26 | """ 27 | 单词-->ID 28 | :param word: 单词 29 | :param word2id: word2id @type: dict 30 | :return: 31 | """ 32 | return word2id[word] if word in word2id else word2id[config.flag_words[1]] 33 | 34 | 35 | def bulid_vocab(vocab_size, 36 | min_freq=3, 37 | is_debug=False): 38 | """ 39 | 建立词典 40 | :param vocab_size: 词典大小 41 | :param min_freq: 最小词频限制 42 | :param stop_list: 停用词 @type:file_path 43 | :param is_debug: 是否测试模式 @type: bool True:使用很小的数据集进行代码测试 44 | :return: word2id 45 | """ 46 | size = 0 47 | count = Counter() 48 | 49 | with open(os.path.join(config.ROOT_DIR, config.RAW_SOURCE_DATA), 'r') as fr: 50 | logger.info('Building vocab') 51 | for line in tqdm(fr, desc='Build vocab'): 52 | words = sent2char(line) 53 | count.update(words) 54 | size += 1 55 | if is_debug: 56 | limit_train_size = 5000 57 | if size > limit_train_size: 58 | break 59 | count = {k: v for k, v in count.items()} 60 | count = sorted(count.items(), key=operator.itemgetter(1)) 61 | # 词典 62 | vocab = [w[0] for w in count if w[1] >= min_freq] 63 | if vocab_size-2 < len(vocab): 64 | vocab = vocab[:vocab_size-2] 65 | vocab = config.flag_words + vocab 66 | logger.info('vocab_size is %d'%len(vocab)) 67 | # 词典到编号的映射 68 | word2id = {k: v for k, v in zip(vocab, range(0, len(vocab)))} 69 | assert word2id[''] == 0, "ValueError: '' id is not 0" 70 | # print(word2id) 71 | 72 | with open(config.WORD2ID_FILE, 'wb') as fw: 73 | pickle.dump(word2id, fw) 74 | return word2id 75 | 76 | 77 | def train_val_split(X, y, valid_size=0.3, random_state=2018, shuffle=True): 78 | """ 79 | 训练集验证集分割 80 | :param X: sentences 81 | :param y: labels 82 | :param random_state: 随机种子 83 | """ 84 | logger.info('train val split') 85 | 86 | train, valid = [], [] 87 | 88 | data = [] 89 | for data_x, data_y in tqdm(zip(X, y), desc='Merge'): 90 | data.append((data_x, data_y)) 91 | del X, y 92 | 93 | N = len(data) 94 | test_size = int(N * valid_size) 95 | 96 | if shuffle: 97 | random.seed(random_state) 98 | random.shuffle(data) 99 | 100 | valid = data[:test_size] 101 | train = data[test_size:] 102 | 103 | return train, valid 104 | 105 | 106 | def text2id(word2id, maxlen=None, valid_size=0.2, random_state=2018, shuffle=True, is_debug=False): 107 | """ 108 | 训练集文本转ID 109 | :param valid_size: 验证集大小 110 | """ 111 | file_name = os.path.join(config.ROOT_DIR, config.TRAIN_FILE) 112 | if len(glob(file_name)) > 0: 113 | logger.info('Text to id file existed') 114 | epoch_size = int(os.popen('cat %s | wc -l'%file_name).readlines()[0].strip('\n')) 115 | return epoch_size 116 | 117 | logger.info('Text to id') 118 | sentences, labels = [], [] 119 | size = 0 120 | with open(os.path.join(config.ROOT_DIR, config.RAW_SOURCE_DATA), 'r') as fr_1, \ 121 | open(os.path.join(config.ROOT_DIR, config.RAW_TARGET_DATA), 'r') as fr_2: 122 | for sent, target in tqdm(zip(fr_1, 
fr_2), desc='text_to_id'): 123 | chars = sent2char(sent) 124 | label = sent2char(target) 125 | if len(chars)==0 or len(label)==0: 126 | continue 127 | sent = [word_to_id(word=word, word2id=word2id) for word in chars] 128 | label = [config.tag_to_ix[l] for l in label] 129 | if maxlen: 130 | sent = sent[:maxlen] 131 | label = label[:maxlen] 132 | sentences.append(sent) 133 | labels.append(label) 134 | size += 1 135 | if is_debug: 136 | limit_train_size = 5000 137 | if size > limit_train_size: 138 | break 139 | 140 | train, valid = train_val_split(sentences, labels, 141 | valid_size=valid_size, 142 | random_state=random_state, 143 | shuffle=shuffle) 144 | epoch_size = len(train) 145 | 146 | del sentences, labels 147 | 148 | 149 | with open(config.TRAIN_FILE, 'w') as fw: 150 | for sent, label in train: 151 | sent = ' '.join([str(w) for w in sent]) 152 | label = ' '.join([str(l) for l in label]) 153 | df = {"source": sent, "target": label} 154 | encode_json = json.dumps(df) 155 | print(encode_json, file=fw) 156 | logger.info('Writing train to file done') 157 | 158 | with open(config.VALID_FILE, 'w') as fw: 159 | for sent, label in valid: 160 | sent = ' '.join([str(w) for w in sent]) 161 | label = ' '.join(str(l) for l in label) 162 | df = {"source": sent, "target": label} 163 | encode_json = json.dumps(df) 164 | print(encode_json, file=fw) 165 | logger.info('Writing valid to file done') 166 | return epoch_size 167 | 168 | 169 | def data_helper(vocab_size, max_len, min_freq=3, 170 | valid_size=0.2, random_state=2018, shuffle=True, is_debug=False): 171 | # 判断文件是否已存在 172 | if len(glob(os.path.join(config.ROOT_DIR, config.WORD2ID_FILE))) > 0: 173 | logger.info('Word to id file existed') 174 | with open(os.path.join(config.ROOT_DIR, config.WORD2ID_FILE), 'rb') as fr: 175 | word2id = pickle.load(fr) 176 | else: 177 | word2id= bulid_vocab(vocab_size=vocab_size, min_freq=min_freq, 178 | is_debug=is_debug) 179 | epoch_size = text2id(word2id, valid_size=valid_size, maxlen=max_len, 180 | random_state=random_state, shuffle=shuffle, is_debug=is_debug) 181 | return word2id, epoch_size 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | -------------------------------------------------------------------------------- /BILSTM+CRF/run_bilstm_crf.py: -------------------------------------------------------------------------------- 1 | from main.main import bilstm_crf 2 | 3 | if __name__ == '__main__': 4 | bilstm_crf() -------------------------------------------------------------------------------- /BILSTM+CRF/test.py: -------------------------------------------------------------------------------- 1 | from preprocessing.data_processor import data_helper 2 | from Io.data_loader import BatchIterator 3 | 4 | import config.config as config 5 | 6 | def run(): 7 | print('start...') 8 | data_helper(10000, 25, is_debug=False) 9 | 10 | bi = BatchIterator(config.TRAIN_FILE, config.VALID_FILE, config.batch_size, config.max_len) 11 | train, valid = bi.create_dataset() 12 | train_iter, valid_iter = bi.get_iterator(train, valid) 13 | batch = next(iter(train_iter)) 14 | print(batch) 15 | 16 | if __name__ == '__main__': 17 | run() -------------------------------------------------------------------------------- /BILSTM+CRF/train/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import inspect 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | 8 | import config.config as config 9 | from util.gpu_mem_track import MemTracker 10 
| from util.plot_util import loss_acc_plot 11 | from util.lr_util import lr_update 12 | from util.Logginger import init_logger 13 | 14 | logger = init_logger("torch", logging_path=config.LOG_PATH) 15 | 16 | torch.manual_seed(2018) 17 | torch.cuda.manual_seed(2018) 18 | torch.cuda.manual_seed_all(2018) 19 | 20 | 21 | import warnings 22 | 23 | warnings.filterwarnings('ignore') 24 | 25 | os.environ["CUDA_VISIBLE_DEVICES"] = "%d"%config.device 26 | 27 | 28 | frame = inspect.currentframe() 29 | gpu_tracker = MemTracker(frame) 30 | use_cuda = config.use_cuda if torch.cuda.is_available() else False 31 | 32 | 33 | def weights_init(m): 34 | if isinstance(m, nn.Conv1d): 35 | nn.init.kaiming_normal_(m.weight.data) 36 | 37 | 38 | def fit(model, training_iter, eval_iter, num_epoch, pbar, lr_decay_mode, initial_lr, verbose=1): 39 | model.apply(weights_init) 40 | 41 | if use_cuda: 42 | model.cuda() 43 | optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr) 44 | 45 | train_losses = [] 46 | eval_losses = [] 47 | train_accuracy = [] 48 | eval_accuracy = [] 49 | 50 | history = { 51 | "train_loss": train_losses, 52 | "train_acc": train_accuracy, 53 | "eval_loss": eval_losses, 54 | "eval_acc": eval_accuracy 55 | } 56 | 57 | best_f1 = 0 58 | 59 | start = time.time() 60 | for e in range(num_epoch): 61 | if e > 0: 62 | lr_update(optimizer=optimizer, epoch=e, lr_decay_mode=lr_decay_mode) 63 | 64 | model.train() 65 | for index, (inputs, label, length) in enumerate(training_iter): 66 | if config.use_mem_track: 67 | gpu_tracker.track() 68 | if use_cuda: 69 | inputs = Variable(inputs.cuda()) 70 | label = Variable(label.cuda()) 71 | length = Variable(length.cuda()) 72 | 73 | output = model(inputs, length) 74 | train_loss = model.loss_fn(output, label, length) 75 | optimizer.zero_grad() 76 | train_loss.backward() 77 | optimizer.step() 78 | 79 | with torch.no_grad(): 80 | predicts = model.predict(output, length) 81 | predicts = predicts.view(1, -1).squeeze() 82 | predicts = predicts[predicts != -1] 83 | label = label.view(1, -1).squeeze() 84 | label = label[label != -1] 85 | train_acc, _ = model.evaluate(predicts, label) 86 | pbar.show_process(train_acc, train_loss.detach(), time.time()-start, index) 87 | 88 | if config.use_mem_track: 89 | gpu_tracker.track() 90 | 91 | if use_cuda: 92 | torch.cuda.empty_cache() 93 | 94 | model.eval() 95 | eval_loss, eval_acc, eval_f1 = 0, 0, 0 96 | with torch.no_grad(): 97 | predict_set, label_set = [], [] 98 | count = 0 99 | for eval_inputs, eval_label, eval_length in eval_iter: 100 | if use_cuda: 101 | eval_inputs, eval_label, eval_length = eval_inputs.cuda(), eval_label.cuda(), eval_length.cuda() 102 | output = model(eval_inputs, eval_length) 103 | eval_loss += model.loss_fn(output, eval_label, eval_length).detach() 104 | eval_predicts = model.predict(output, eval_length) 105 | eval_predicts = eval_predicts.view(1, -1).squeeze() 106 | eval_predicts = eval_predicts[eval_predicts != -1] 107 | predict_set.append(eval_predicts) 108 | eval_label = eval_label.view(1, -1).squeeze() 109 | eval_label = eval_label[eval_label != -1] 110 | label_set.append(eval_label) 111 | count += 1 112 | predict_set = torch.cat(predict_set, dim=0) 113 | label_set = torch.cat(label_set, dim=0) 114 | 115 | eval_acc, eval_f1 = model.evaluate(predict_set, label_set) 116 | model.class_report(predict_set, label_set) 117 | 118 | logger.info( 119 | '\n\nEpoch %d - train_loss: %4f - eval_loss: %4f - train_acc:%4f - eval_acc:%4f - eval_f1:%4f\n' 120 | % (e + 1, 121 | train_loss.detach(), 122 | eval_loss/count, 
123 | train_acc, 124 | eval_acc, 125 | eval_f1)) 126 | 127 | # 保存最好的模型 128 | if eval_f1 > best_f1: 129 | best_f1 = eval_f1 130 | model.save() 131 | 132 | if e % verbose == 0: 133 | train_losses.append(train_loss.data) 134 | train_accuracy.append(train_acc) 135 | eval_losses.append(eval_loss/count) 136 | eval_accuracy.append(eval_acc/count) 137 | model.save() 138 | loss_acc_plot(history) -------------------------------------------------------------------------------- /BILSTM+CRF/util/Logginger.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | import logging 3 | from logging import Logger 4 | from logging.handlers import TimedRotatingFileHandler 5 | 6 | ''' 7 | 使用方式 8 | from you_logging_filename.py import init_logger 9 | logger = init_logger("dataset",logging_path='') 10 | def you_function(): 11 | logger.info() 12 | logger.error() 13 | 14 | ''' 15 | 16 | 17 | ''' 18 | 日志模块 19 | 1. 同时将日志打印到屏幕跟文件中 20 | 2. 默认值保留近7天日志文件 21 | ''' 22 | def init_logger(logger_name, logging_path): 23 | if logger_name not in Logger.manager.loggerDict: 24 | logger = logging.getLogger(logger_name) 25 | logger.setLevel(logging.DEBUG) 26 | handler = TimedRotatingFileHandler(filename=logging_path+"/all.log",when='D',backupCount = 7) 27 | datefmt = '%Y-%m-%d %H:%M:%S' 28 | format_str = '[%(asctime)s]: %(name)s %(filename)s[line:%(lineno)s] %(levelname)s %(message)s' 29 | formatter = logging.Formatter(format_str,datefmt) 30 | handler.setFormatter(formatter) 31 | handler.setLevel(logging.INFO) 32 | logger.addHandler(handler) 33 | console= logging.StreamHandler() 34 | console.setLevel(logging.INFO) 35 | console.setFormatter(formatter) 36 | logger.addHandler(console) 37 | 38 | handler = TimedRotatingFileHandler(filename=logging_path+"/error.log",when='D',backupCount=7) 39 | datefmt = '%Y-%m-%d %H:%M:%S' 40 | format_str = '[%(asctime)s]: %(name)s %(filename)s[line:%(lineno)s] %(levelname)s %(message)s' 41 | formatter = logging.Formatter(format_str,datefmt) 42 | handler.setFormatter(formatter) 43 | handler.setLevel(logging.ERROR) 44 | logger.addHandler(handler) 45 | logger = logging.getLogger(logger_name) 46 | return logger 47 | 48 | #if __name__ == "__main__": 49 | # logger = init_logger("datatest",logging_path="E:/neo4j-community-3.4.1") 50 | # logger.error('test_error') 51 | # logger.info("test-info") 52 | # logger.warn("test-warn") 53 | 54 | 55 | -------------------------------------------------------------------------------- /BILSTM+CRF/util/embedding_util.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import config.config as config 4 | from util.Logginger import init_logger 5 | 6 | logger = init_logger("torch", logging_path=config.LOG_PATH) 7 | 8 | 9 | def parse_word_vector(word_index): 10 | pre_trained_wordvector = {} 11 | with open(config.EMBEDDING_FILE, 'r') as fr: 12 | for line in fr: 13 | lines = line.strip('\n').split(' ') 14 | word = lines[0] 15 | if word_index.get(word) is not None: 16 | vector = lines[1:] 17 | pre_trained_wordvector[word] = vector 18 | else: 19 | continue 20 | return pre_trained_wordvector 21 | 22 | 23 | def get_embedding(vocab_size, embedding_dim, word2id): 24 | logger.info('Get embedding') 25 | pre_trained_wordector = parse_word_vector(word2id) 26 | embedding_matrix = np.zeros((vocab_size, embedding_dim), dtype=np.float32) 27 | for word, id in tqdm(word2id.items()): 28 | try: 29 | word_vector = pre_trained_wordector[word] 30 | embedding_matrix[id] 
= word_vector 31 | except: 32 | continue 33 | logger.info('Get embedding done') 34 | return embedding_matrix -------------------------------------------------------------------------------- /BILSTM+CRF/util/gpu_mem_track.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import datetime 3 | import pynvml 4 | 5 | import torch 6 | import numpy as np 7 | 8 | 9 | class MemTracker(object): 10 | """ 11 | Class used to track pytorch memory usage 12 | Arguments: 13 | frame: a frame to detect current py-file runtime 14 | detail(bool, default True): whether the function shows the detail gpu memory usage 15 | path(str): where to save log file 16 | verbose(bool, default False): whether show the trivial exception 17 | device(int): GPU number, default is 0 18 | """ 19 | def __init__(self, frame, detail=True, path='', verbose=False, device=0): 20 | self.frame = frame 21 | self.print_detail = detail 22 | self.last_tensor_sizes = set() 23 | self.gpu_profile_fn = path + f'{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_track.txt' 24 | self.verbose = verbose 25 | self.begin = True 26 | self.device = device 27 | 28 | self.func_name = frame.f_code.co_name 29 | self.filename = frame.f_globals["__file__"] 30 | if (self.filename.endswith(".pyc") or 31 | self.filename.endswith(".pyo")): 32 | self.filename = self.filename[:-1] 33 | self.module_name = self.frame.f_globals["__name__"] 34 | self.curr_line = self.frame.f_lineno 35 | 36 | def get_tensors(self): 37 | for obj in gc.get_objects(): 38 | try: 39 | if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): 40 | tensor = obj 41 | else: 42 | continue 43 | if tensor.is_cuda: 44 | yield tensor 45 | except Exception as e: 46 | if self.verbose: 47 | print('A trivial exception occured: {}'.format(e)) 48 | 49 | def track(self): 50 | """ 51 | Track the GPU memory usage 52 | """ 53 | pynvml.nvmlInit() 54 | handle = pynvml.nvmlDeviceGetHandleByIndex(self.device) 55 | meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) 56 | self.curr_line = self.frame.f_lineno 57 | where_str = self.module_name + ' ' + self.func_name + ':' + ' line ' + str(self.curr_line) 58 | 59 | with open(self.gpu_profile_fn, 'a+') as f: 60 | 61 | if self.begin: 62 | f.write(f"GPU Memory Track | {datetime.datetime.now():%d-%b-%y-%H:%M:%S} |" 63 | f" Total Used Memory:{meminfo.used/1000**2:<7.1f}Mb\n\n") 64 | self.begin = False 65 | 66 | if self.print_detail is True: 67 | ts_list = [tensor.size() for tensor in self.get_tensors()] 68 | new_tensor_sizes = {(type(x), tuple(x.size()), ts_list.count(x.size()), np.prod(np.array(x.size()))*4/1000**2) 69 | for x in self.get_tensors()} 70 | for t, s, n, m in new_tensor_sizes - self.last_tensor_sizes: 71 | f.write(f'+ | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20}\n') 72 | for t, s, n, m in self.last_tensor_sizes - new_tensor_sizes: 73 | f.write(f'- | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20} \n') 74 | self.last_tensor_sizes = new_tensor_sizes 75 | 76 | f.write(f"\nAt {where_str:<50}" 77 | f"Total Used Memory:{meminfo.used/1000**2:<7.1f}Mb\n\n") 78 | 79 | pynvml.nvmlShutdown() 80 | 81 | -------------------------------------------------------------------------------- /BILSTM+CRF/util/lr_util.py: -------------------------------------------------------------------------------- 1 | """学习率衰减策略""" 2 | 3 | def exponential_decay(optimizer, epoch): 4 | pass 5 | 6 | 7 | def custom_decay(optimizer, epoch): 8 | if epoch % 2 != 0: 9 | for 
param_group in optimizer.param_groups: 10 | param_group["lr"] = param_group['lr'] * 0.1 11 | else: 12 | for param_group in optimizer.param_groups: 13 | param_group["lr"] = 0.001 14 | 15 | 16 | def cosine_anneal_decay(optimizer, epoch): 17 | pass 18 | 19 | 20 | def lr_update(optimizer, epoch, lr_decay_mode): 21 | if lr_decay_mode == "constant": 22 | pass 23 | elif lr_decay_mode == "exponential_decay": 24 | exponential_decay(optimizer, epoch) 25 | elif lr_decay_mode == "cosine_anneal_decay": 26 | cosine_anneal_decay(optimizer, epoch) 27 | elif lr_decay_mode == "custom_decay": 28 | custom_decay(optimizer, epoch) 29 | else: 30 | raise TypeError("Unknown lr update mode") 31 | -------------------------------------------------------------------------------- /BILSTM+CRF/util/plot_util.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import config.config as config 3 | 4 | # 无图形界面需要加,否则plt报错 5 | plt.switch_backend('agg') 6 | 7 | 8 | def loss_acc_plot(history): 9 | train_loss = history['train_loss'] 10 | eval_loss = history['eval_loss'] 11 | train_accuracy = history['train_acc'] 12 | eval_accuracy = history['eval_acc'] 13 | 14 | fig = plt.figure(figsize=(16, 10)) 15 | fig.add_subplot(2, 1, 1) 16 | plt.title('loss during train') 17 | plt.xlabel('epochs') 18 | plt.ylabel('loss') 19 | epochs = range(1, len(train_loss)+1) 20 | plt.plot(epochs, train_loss) 21 | plt.plot(epochs, eval_loss) 22 | plt.legend(['train_loss', 'eval_loss']) 23 | 24 | fig.add_subplot(2, 1, 2) 25 | plt.title('accuracy during train') 26 | plt.xlabel('epochs') 27 | plt.ylabel('accuracy') 28 | epochs = range(1, len(train_loss) + 1) 29 | plt.plot(epochs, train_accuracy) 30 | plt.plot(epochs, eval_accuracy) 31 | plt.legend(['train_acc', 'eval_acc']) 32 | 33 | plt.savefig(config.plot_path) 34 | 35 | 36 | if __name__ == '__main__': 37 | history = { 38 | 'train_loss': range(100), 39 | 'eval_loss': range(100), 40 | 'train_accuracy': range(100), 41 | 'eval_accuracy': range(100) 42 | } 43 | loss_acc_plot(history) 44 | 45 | 46 | -------------------------------------------------------------------------------- /BILSTM+CRF/util/porgress_util.py: -------------------------------------------------------------------------------- 1 | """进度条""" 2 | 3 | import sys 4 | 5 | 6 | class ProgressBar(object): 7 | """ 8 | 显示处理进度的类 9 | 调用该类相关函数即可实现处理进度的显示 10 | """ 11 | # 初始化函数,需要知道总共的处理次数 12 | def __init__(self, epoch_size, batch_size, max_arrow=80): 13 | self.epoch_size = epoch_size 14 | self.batch_size = batch_size 15 | self.max_steps = round(epoch_size/batch_size) # 总共处理次数 = round(epoch/batch_size) 16 | self.max_arrow = max_arrow # 进度条的长度 17 | 18 | # 显示函数,根据当前的处理进度i显示进度 19 | # 效果为[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]100.00% 20 | def show_process(self, train_acc, train_loss, used_time, i): 21 | num_arrow = int(i * self.max_arrow / self.max_steps) # 计算显示多少个'>' 22 | num_line = self.max_arrow - num_arrow # 计算显示多少个'-' 23 | percent = i * 100.0 / self.max_steps # 计算完成进度,格式为xx.xx% 24 | num_steps = self.batch_size * i # 当前处理数据条数 25 | process_bar = '%d'%num_steps + '/' + '%d'%self.epoch_size + '[' + '>' * num_arrow + '-' * num_line + ']'\ 26 | + '%.2f' % percent + '%' + ' - train_acc ' + '%.4f'%train_acc + ' - train_loss '+ \ 27 | '%.4f' %train_loss + ' - time '+ '%.1fs'%used_time + '\r' 28 | sys.stdout.write(process_bar) #这两句打印字符到终端 29 | sys.stdout.flush() 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # NER 2 | ## DATA 3 | - 2014 people daily newspaper tagged dataset 4 | - For convience, the preprocessed data is free to download at https://pan.baidu.com/s/17sa7a-u-cDXjbW4Rok2Ntg 5 | ## pycrf 6 | - A implement of crf by feature template with pysuite 7 | ## bilstm_crf 8 | - pytorch 0.4.0 9 | - bilstm + crf 10 | 11 | Thanks to the blog of createMoMo for enlightening the implements of crf. 12 | 13 | https://createmomo.github.io/2017/09/12/CRF_Layer_on_the_Top_of_BiLSTM_1/ 14 | 15 | The chinese version of this blog is also available in https://state-of-art.top 16 | 17 | If you have any question, please feel free to contact me at circlepi@gmail.com. 18 | -------------------------------------------------------------------------------- /pycrf/code/2014_cropus_cleaning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "text = open('../data/raw_2014.txt', 'r').read()" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": { 16 | "collapsed": true 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "def q_to_b(q_str):\n", 21 | " \"\"\"\n", 22 | " 功能:非中文文字的全角转半角\n", 23 | " 输入:一个字符串\n", 24 | " 输出:半角字符串\n", 25 | " \n", 26 | " \"\"\"\n", 27 | " b_str = \"\"\n", 28 | " for uchar in q_str:\n", 29 | " inside_code = ord(uchar)\n", 30 | " if inside_code == 12288: \n", 31 | " inside_code = 32\n", 32 | " elif 65374 >= inside_code >= 65281: \n", 33 | " inside_code -= 65248\n", 34 | " b_str += chr(inside_code)\n", 35 | " return b_str\n" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "metadata": { 42 | "collapsed": true 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "text = q_to_b(text)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 5, 52 | "metadata": { 53 | "collapsed": true 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "sentences = text.split('\\n')" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 6, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "num_seqence: 286269\n" 70 | ] 71 | } 72 | ], 73 | "source": [ 74 | "print('num_seqence:', len(sentences))" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 7, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "'人民网/nz 1月1日/t 讯/ng 据/p 《/w [纽约/nsf 时报/n]/nz 》/w 报道/v ,/w 美国/nsf 华尔街/nsf 股市/n 在/p 2013年/t 的/ude1 最后/f 一天/mq 继续/v 上涨/vn ,/w 和/cc [全球/n 股市/n]/nz 一样/uyy ,/w 都/d 以/p [最高/a 纪录/n]/nz 或/c 接近/v [最高/a 纪录/n]/nz 结束/v 本/rz 年/qt 的/ude1 交易/vn 。/w '" 86 | ] 87 | }, 88 | "execution_count": 7, 89 | "metadata": {}, 90 | "output_type": "execute_result" 91 | } 92 | ], 93 | "source": [ 94 | "sentences[0]" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 8, 100 | "metadata": { 101 | "collapsed": true 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "import re\n", 106 | "\n", 107 | "def sent_sep(sent):\n", 108 | " \"\"\"将句子和标注分割\"\"\"\n", 109 | " sent = sent.strip()\n", 110 | " \n", 111 | " # 处理 '[]'\n", 112 | " if '[' in sent:\n", 113 | " sent = sent.replace(' [',' ')\n", 114 | " sent = re.sub(r']/[a-z]+', '', sent)\n", 115 | " \n", 116 | " sent = sent.split(' ')\n", 117 | " \n", 118 | " # 处理空\n", 119 | " sent = [item for item in sent if '/' in 
item]\n", 120 | " \n", 121 | " sents, tags = [], []\n", 122 | " for item in sent:\n", 123 | " tmp = item.split('/')\n", 124 | " try:\n", 125 | " sents.append(tmp[0])\n", 126 | " tags.append(tmp[1])\n", 127 | " except:\n", 128 | " print(sent)\n", 129 | " return sents, tags\n" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 9, 135 | "metadata": { 136 | "collapsed": true 137 | }, 138 | "outputs": [], 139 | "source": [ 140 | "new_sentences, new_targets = [], []\n", 141 | "for sent in sentences:\n", 142 | " sents,tags = sent_sep(sent)\n", 143 | " new_sentences.append(sents)\n", 144 | " new_targets.append(tags)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 10, 150 | "metadata": {}, 151 | "outputs": [ 152 | { 153 | "name": "stdout", 154 | "output_type": "stream", 155 | "text": [ 156 | "['人民网', '1月1日', '讯', '据', '《', '纽约', '时报', '》', '报道', ',', '美国', '华尔街', '股市', '在', '2013年', '的', '最后', '一天', '继续', '上涨', ',', '和', '全球', '股市', '一样', ',', '都', '以', '最高', '纪录', '或', '接近', '最高', '纪录', '结束', '本', '年', '的', '交易', '。']\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "print(new_sentences[0])" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 11, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "name": "stdout", 171 | "output_type": "stream", 172 | "text": [ 173 | "['nz', 't', 'ng', 'p', 'w', 'nsf', 'n', 'w', 'v', 'w', 'nsf', 'nsf', 'n', 'p', 't', 'ude1', 'f', 'mq', 'v', 'vn', 'w', 'cc', 'n', 'n', 'uyy', 'w', 'd', 'p', 'a', 'n', 'c', 'v', 'a', 'n', 'v', 'rz', 'qt', 'ude1', 'vn', 'w']\n" 174 | ] 175 | } 176 | ], 177 | "source": [ 178 | "print(new_targets[0])" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 12, 184 | "metadata": { 185 | "collapsed": true 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "def tag_process(tags):\n", 190 | " \"\"\"\n", 191 | " 给 人物 地点 机构 时间 打上 PER,LOC,ORG,T 标签\n", 192 | " 其他为 O\n", 193 | " 采用BIO标记\n", 194 | " \"\"\"\n", 195 | " new_tags = []\n", 196 | " for tag in tags:\n", 197 | " if tag == 't':\n", 198 | " new_tag = 'T'\n", 199 | " elif tag.startswith('nr'):\n", 200 | " new_tag = 'PER'\n", 201 | " elif tag.startswith('ns'):\n", 202 | " new_tag = 'LOC'\n", 203 | " elif tag.startswith('nt'):\n", 204 | " new_tag = 'ORG'\n", 205 | " else:\n", 206 | " new_tag = 'O'\n", 207 | " new_tags.append(new_tag)\n", 208 | " \n", 209 | " tags = []\n", 210 | " for index, tag in enumerate(new_tags):\n", 211 | " if tag in ['PER', 'LOC','T','ORG']:\n", 212 | " if index==0:\n", 213 | " new_tag = 'B_'+tag\n", 214 | " else:\n", 215 | " if new_tags[index-1]=='O':\n", 216 | " new_tag = 'B_'+tag\n", 217 | " else:\n", 218 | " new_tag = 'I_'+tag\n", 219 | " else:\n", 220 | " new_tag = tag\n", 221 | " tags.append(new_tag)\n", 222 | " return tags\n", 223 | " " 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 13, 229 | "metadata": { 230 | "collapsed": true 231 | }, 232 | "outputs": [], 233 | "source": [ 234 | "targets = []\n", 235 | "for tags in new_targets:\n", 236 | " tags = tag_process(tags)\n", 237 | " targets.append(tags)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 14, 243 | "metadata": {}, 244 | "outputs": [ 245 | { 246 | "name": "stdout", 247 | "output_type": "stream", 248 | "text": [ 249 | "['O', 'B_T', 'O', 'O', 'O', 'B_LOC', 'O', 'O', 'O', 'O', 'B_LOC', 'I_LOC', 'O', 'O', 'B_T', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n" 
250 | ] 251 | } 252 | ], 253 | "source": [ 254 | "print(targets[0])" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 15, 260 | "metadata": { 261 | "collapsed": true 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "import pickle\n", 266 | "with open('../outputs/source_word_2014_cropus.pkl','wb') as fw:\n", 267 | " fw.write(pickle.dumps(new_sentences))\n", 268 | "with open('../outputs/target_word_2014_cropus.pkl','wb') as fw:\n", 269 | " fw.write(pickle.dumps(targets))" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 16, 275 | "metadata": {}, 276 | "outputs": [ 277 | { 278 | "name": "stdout", 279 | "output_type": "stream", 280 | "text": [ 281 | "人民网/O 1月1日/B_T 讯/O 据/O 《/O 纽约/B_LOC 时报/O 》/O 报道/O ,/O 美国/B_LOC 华尔街/I_LOC 股市/O 在/O 2013年/B_T 的/O 最后/O 一天/O 继续/O 上涨/O ,/O 和/O 全球/O 股市/O 一样/O ,/O 都/O 以/O 最高/O 纪录/O 或/O 接近/O 最高/O 纪录/O 结束/O 本/O 年/O 的/O 交易/O 。/O\n" 282 | ] 283 | } 284 | ], 285 | "source": [ 286 | "# 合并词和标签, 至此,基于词的处理就结束了\n", 287 | "res = []\n", 288 | "for word, tag in zip(new_sentences[0], targets[0]):\n", 289 | " res.append(word+'/'+tag)\n", 290 | "print(' '.join(res))" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 17, 296 | "metadata": {}, 297 | "outputs": [ 298 | { 299 | "name": "stdout", 300 | "output_type": "stream", 301 | "text": [ 302 | "['人民网', '1月1日', '讯', '据', '《', '纽约', '时报', '》', '报道', ',', '美国', '华尔街', '股市', '在', '2013年', '的', '最后', '一天', '继续', '上涨', ',', '和', '全球', '股市', '一样', ',', '都', '以', '最高', '纪录', '或', '接近', '最高', '纪录', '结束', '本', '年', '的', '交易', '。']\n" 303 | ] 304 | } 305 | ], 306 | "source": [ 307 | "print(new_sentences[0])" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": {}, 313 | "source": [ 314 | "### 因为后边的模型是基于字符的,所以这里还要做进一步处理" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 18, 320 | "metadata": { 321 | "collapsed": true 322 | }, 323 | "outputs": [], 324 | "source": [ 325 | "def sep(sent,labels):\n", 326 | " new_sent = []\n", 327 | " new_labels = []\n", 328 | " for s, label in zip(sent,labels):\n", 329 | " s = list(s)\n", 330 | "# print(s)\n", 331 | " n = len(s)\n", 332 | " if n==1:\n", 333 | " new_labels.append(label)\n", 334 | " new_sent.append(s[0])\n", 335 | " else:\n", 336 | " for i in range(n):\n", 337 | " if i > 0:\n", 338 | " label = label.replace('B','I')\n", 339 | " new_labels.append(label)\n", 340 | " new_sent.append(s[i])\n", 341 | "# print('new_sent', new_sent)\n", 342 | "# print('new_labels',new_labels)\n", 343 | " return new_sent,new_labels\n", 344 | " \n", 345 | "def init_seq(sentences, targets):\n", 346 | " \"\"\"将词分割成字符,字符的标签同词的标签\"\"\"\n", 347 | " new_sentences = []\n", 348 | " new_targets = []\n", 349 | " for sent, labels in zip(sentences,targets):\n", 350 | " new_sent, new_labels = sep(sent,labels)\n", 351 | " new_sentences.append(new_sent)\n", 352 | " new_targets.append(new_labels)\n", 353 | " return new_sentences, new_targets \n", 354 | " " 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": 19, 360 | "metadata": {}, 361 | "outputs": [ 362 | { 363 | "name": "stdout", 364 | "output_type": "stream", 365 | "text": [ 366 | "人/O 民/O 网/O 1/B_T 月/I_T 1/I_T 日/I_T 讯/O 据/O 《/O 纽/B_LOC 约/I_LOC 时/O 报/O 》/O 报/O 道/O ,/O 美/B_LOC 国/I_LOC 华/I_LOC 尔/I_LOC 街/I_LOC 股/O 市/O 在/O 2/B_T 0/I_T 1/I_T 3/I_T 年/I_T 的/O 最/O 后/O 一/O 天/O 继/O 续/O 上/O 涨/O ,/O 和/O 全/O 球/O 股/O 市/O 一/O 样/O ,/O 都/O 以/O 最/O 高/O 纪/O 录/O 或/O 接/O 近/O 最/O 高/O 纪/O 录/O 结/O 束/O 本/O 年/O 的/O 交/O 易/O 。/O\n" 367 | ] 368 | } 369 | ], 
370 | "source": [ 371 | "new_sent, new_labels = sep(new_sentences[0],targets[0])\n", 372 | "res=[]\n", 373 | "for sent, lable in zip(new_sent, new_labels):\n", 374 | " res.append(sent+'/'+lable)\n", 375 | "print(' '.join(res))" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 20, 381 | "metadata": { 382 | "collapsed": true 383 | }, 384 | "outputs": [], 385 | "source": [ 386 | "new_sentences, new_targets = init_seq(new_sentences,targets)" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 21, 392 | "metadata": {}, 393 | "outputs": [ 394 | { 395 | "name": "stdout", 396 | "output_type": "stream", 397 | "text": [ 398 | "['人', '民', '网', '1', '月', '1', '日', '讯', '据', '《', '纽', '约', '时', '报', '》', '报', '道', ',', '美', '国', '华', '尔', '街', '股', '市', '在', '2', '0', '1', '3', '年', '的', '最', '后', '一', '天', '继', '续', '上', '涨', ',', '和', '全', '球', '股', '市', '一', '样', ',', '都', '以', '最', '高', '纪', '录', '或', '接', '近', '最', '高', '纪', '录', '结', '束', '本', '年', '的', '交', '易', '。']\n" 399 | ] 400 | } 401 | ], 402 | "source": [ 403 | "print(new_sentences[0])" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": 22, 409 | "metadata": {}, 410 | "outputs": [ 411 | { 412 | "name": "stdout", 413 | "output_type": "stream", 414 | "text": [ 415 | "['O', 'O', 'O', 'B_T', 'I_T', 'I_T', 'I_T', 'O', 'O', 'O', 'B_LOC', 'I_LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'B_LOC', 'I_LOC', 'I_LOC', 'I_LOC', 'I_LOC', 'O', 'O', 'O', 'B_T', 'I_T', 'I_T', 'I_T', 'I_T', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n" 416 | ] 417 | } 418 | ], 419 | "source": [ 420 | "print(new_targets[0])" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 23, 426 | "metadata": {}, 427 | "outputs": [ 428 | { 429 | "name": "stdout", 430 | "output_type": "stream", 431 | "text": [ 432 | "人/O 民/O 网/O 1/B_T 月/I_T 1/I_T 日/I_T 讯/O 据/O 《/O 纽/B_LOC 约/I_LOC 时/O 报/O 》/O 报/O 道/O ,/O 美/B_LOC 国/I_LOC 华/I_LOC 尔/I_LOC 街/I_LOC 股/O 市/O 在/O 2/B_T 0/I_T 1/I_T 3/I_T 年/I_T 的/O 最/O 后/O 一/O 天/O 继/O 续/O 上/O 涨/O ,/O 和/O 全/O 球/O 股/O 市/O 一/O 样/O ,/O 都/O 以/O 最/O 高/O 纪/O 录/O 或/O 接/O 近/O 最/O 高/O 纪/O 录/O 结/O 束/O 本/O 年/O 的/O 交/O 易/O 。/O\n" 433 | ] 434 | } 435 | ], 436 | "source": [ 437 | "res = []\n", 438 | "for word, tag in zip(new_sentences[0], new_targets[0]):\n", 439 | " res.append(word+'/'+tag)\n", 440 | "print(' '.join(res))" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 24, 446 | "metadata": { 447 | "collapsed": true 448 | }, 449 | "outputs": [], 450 | "source": [ 451 | "# 保存\n", 452 | "with open('../outputs/source_BIO_2014_cropus.txt','w') as fw:\n", 453 | " for sent in new_sentences:\n", 454 | " fw.write(\" \".join(sent)+'\\n')\n", 455 | "with open('../outputs/target_BIO_2014_cropus.txt','w') as fw:\n", 456 | " for sent in new_targets:\n", 457 | " fw.write(\" \".join(sent)+'\\n')\n", 458 | " " 459 | ] 460 | } 461 | ], 462 | "metadata": { 463 | "kernelspec": { 464 | "display_name": "Python 3", 465 | "language": "python", 466 | "name": "python3" 467 | }, 468 | "language_info": { 469 | "codemirror_mode": { 470 | "name": "ipython", 471 | "version": 3 472 | }, 473 | "file_extension": ".py", 474 | "mimetype": "text/x-python", 475 | "name": "python", 476 | "nbconvert_exporter": "python", 477 | "pygments_lexer": "ipython3", 478 | "version": "3.6.2" 479 | } 480 | }, 481 | "nbformat": 4, 482 | "nbformat_minor": 2 
483 | } 484 | -------------------------------------------------------------------------------- /pycrf/code/CRF.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pycrfsuite\n", 10 | "import numpy as np\n", 11 | "from itertools import chain\n", 12 | "from sklearn.metrics import classification_report,confusion_matrix\n", 13 | "import sklearn\n", 14 | "from sklearn.preprocessing import LabelBinarizer" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## 加载训练数据" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": { 28 | "collapsed": true 29 | }, 30 | "outputs": [], 31 | "source": [ 32 | "text = open('../outputs/source_BIO_2014_cropus.txt').read()\n", 33 | "target = open('../outputs/target_BIO_2014_cropus.txt').read()" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "data": { 43 | "text/plain": [ 44 | "'\\n[[sent0],[sent1],[sent2]]\\n'" 45 | ] 46 | }, 47 | "execution_count": 3, 48 | "metadata": {}, 49 | "output_type": "execute_result" 50 | } 51 | ], 52 | "source": [ 53 | "sentences = []\n", 54 | "sent = text.split('\\n')\n", 55 | "for s in sent:\n", 56 | " sentences.append(s.split(\" \"))\n", 57 | "\"\"\"\n", 58 | "[[sent0],[sent1],[sent2]]\n", 59 | "\"\"\"" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "metadata": { 66 | "collapsed": true 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "targets = []\n", 71 | "target = target.split('\\n')\n", 72 | "for t in target:\n", 73 | " targets.append(t.split(\" \"))" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 5, 79 | "metadata": { 80 | "collapsed": true 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "assert len(targets[0])==len(sentences[0]), 'not equal'" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 6, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "sentences: ['人', '民', '网', '1', '月', '1', '日', '讯', '据', '《', '纽', '约', '时', '报', '》', '报', '道', ',', '美', '国', '华', '尔', '街', '股', '市', '在', '2', '0', '1', '3', '年', '的', '最', '后', '一', '天', '继', '续', '上', '涨', ',', '和', '全', '球', '股', '市', '一', '样', ',', '都', '以', '最', '高', '纪', '录', '或', '接', '近', '最', '高', '纪', '录', '结', '束', '本', '年', '的', '交', '易', '。']\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "print('sentences:',sentences[0])" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 7, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "name": "stdout", 111 | "output_type": "stream", 112 | "text": [ 113 | "targets: ['O', 'O', 'O', 'B_T', 'I_T', 'I_T', 'I_T', 'O', 'O', 'O', 'B_LOC', 'I_LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'B_LOC', 'I_LOC', 'I_LOC', 'I_LOC', 'I_LOC', 'O', 'O', 'O', 'B_T', 'I_T', 'I_T', 'I_T', 'I_T', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "print('targets:',targets[0])" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "## 基于字符的CRF" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | 
"execution_count": 8, 131 | "metadata": { 132 | "collapsed": true 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "# 句子加入分割符号\n", 137 | "# :一句话的开头\n", 138 | "# : 一句话的结尾\n", 139 | "def sent_sep(array):\n", 140 | " new_array = []\n", 141 | " for sent in array:\n", 142 | " sent.insert(0,'')\n", 143 | " sent.append('')\n", 144 | " new_array.append(sent)\n", 145 | " return new_array\n", 146 | "new_sentences = sent_sep(sentences)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 9, 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "data": { 156 | "text/plain": [ 157 | "[['', '人', '民'],\n", 158 | " ['人', '民', '网'],\n", 159 | " ['民', '网', '1'],\n", 160 | " ['网', '1', '月'],\n", 161 | " ['1', '月', '1'],\n", 162 | " ['月', '1', '日'],\n", 163 | " ['1', '日', '讯'],\n", 164 | " ['日', '讯', '据'],\n", 165 | " ['讯', '据', '《'],\n", 166 | " ['据', '《', '纽'],\n", 167 | " ['《', '纽', '约'],\n", 168 | " ['纽', '约', '时'],\n", 169 | " ['约', '时', '报'],\n", 170 | " ['时', '报', '》'],\n", 171 | " ['报', '》', '报'],\n", 172 | " ['》', '报', '道'],\n", 173 | " ['报', '道', ','],\n", 174 | " ['道', ',', '美'],\n", 175 | " [',', '美', '国'],\n", 176 | " ['美', '国', '华'],\n", 177 | " ['国', '华', '尔'],\n", 178 | " ['华', '尔', '街'],\n", 179 | " ['尔', '街', '股'],\n", 180 | " ['街', '股', '市'],\n", 181 | " ['股', '市', '在'],\n", 182 | " ['市', '在', '2'],\n", 183 | " ['在', '2', '0'],\n", 184 | " ['2', '0', '1'],\n", 185 | " ['0', '1', '3'],\n", 186 | " ['1', '3', '年'],\n", 187 | " ['3', '年', '的'],\n", 188 | " ['年', '的', '最'],\n", 189 | " ['的', '最', '后'],\n", 190 | " ['最', '后', '一'],\n", 191 | " ['后', '一', '天'],\n", 192 | " ['一', '天', '继'],\n", 193 | " ['天', '继', '续'],\n", 194 | " ['继', '续', '上'],\n", 195 | " ['续', '上', '涨'],\n", 196 | " ['上', '涨', ','],\n", 197 | " ['涨', ',', '和'],\n", 198 | " [',', '和', '全'],\n", 199 | " ['和', '全', '球'],\n", 200 | " ['全', '球', '股'],\n", 201 | " ['球', '股', '市'],\n", 202 | " ['股', '市', '一'],\n", 203 | " ['市', '一', '样'],\n", 204 | " ['一', '样', ','],\n", 205 | " ['样', ',', '都'],\n", 206 | " [',', '都', '以'],\n", 207 | " ['都', '以', '最'],\n", 208 | " ['以', '最', '高'],\n", 209 | " ['最', '高', '纪'],\n", 210 | " ['高', '纪', '录'],\n", 211 | " ['纪', '录', '或'],\n", 212 | " ['录', '或', '接'],\n", 213 | " ['或', '接', '近'],\n", 214 | " ['接', '近', '最'],\n", 215 | " ['近', '最', '高'],\n", 216 | " ['最', '高', '纪'],\n", 217 | " ['高', '纪', '录'],\n", 218 | " ['纪', '录', '结'],\n", 219 | " ['录', '结', '束'],\n", 220 | " ['结', '束', '本'],\n", 221 | " ['束', '本', '年'],\n", 222 | " ['本', '年', '的'],\n", 223 | " ['年', '的', '交'],\n", 224 | " ['的', '交', '易'],\n", 225 | " ['交', '易', '。'],\n", 226 | " ['易', '。', '']]" 227 | ] 228 | }, 229 | "execution_count": 9, 230 | "metadata": {}, 231 | "output_type": "execute_result" 232 | } 233 | ], 234 | "source": [ 235 | "def seg_by_window(sent,window=3):\n", 236 | " \"\"\"采用滑动窗口截取句子,默认窗口大小为3,方便后面特征提取\"\"\"\n", 237 | " n = len(sent)\n", 238 | " flag = 0\n", 239 | " new_sent = []\n", 240 | " while flag < n-window+1:\n", 241 | " new_sent.append(sent[flag:flag+window])\n", 242 | " flag += 1\n", 243 | " return new_sent\n", 244 | "\n", 245 | "sentences = []\n", 246 | "for sent in new_sentences:\n", 247 | " item = seg_by_window(sent)\n", 248 | " sentences.append(item)\n", 249 | "sentences[0]" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 11, 255 | "metadata": { 256 | "collapsed": true 257 | }, 258 | "outputs": [], 259 | "source": [ 260 | "def feature_dict(item):\n", 261 | " \"\"\"构造特征模板\"\"\"\n", 262 | " feat = {\n", 263 | " 'w-1':item[0],\n", 264 | " 
'w':item[1],\n", 265 | " 'w+1':item[2],\n", 266 | " 'w-1:w':item[0]+item[1],\n", 267 | " 'w:w+1':item[1]+item[2],\n", 268 | " 'bias':1\n", 269 | " }\n", 270 | " return feat\n", 271 | " \n", 272 | "def extract_feature(sentences):\n", 273 | " \"\"\"提取特征\"\"\"\n", 274 | " features = []\n", 275 | " for sent in sentences:\n", 276 | " feature_of_sent = []\n", 277 | " for item in sent:\n", 278 | " feat = feature_dict(item)\n", 279 | " feature_of_sent.append(feat)\n", 280 | " features.append(feature_of_sent)\n", 281 | " return features" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "features = extract_feature(sentences)" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 289, 296 | "metadata": {}, 297 | "outputs": [ 298 | { 299 | "data": { 300 | "text/plain": [ 301 | "[{'bias': 1,\n", 302 | " 'w': '人',\n", 303 | " 'w+1': '民',\n", 304 | " 'w-1': '',\n", 305 | " 'w-1:w': '人',\n", 306 | " 'w:w+1': '人民'},\n", 307 | " {'bias': 1, 'w': '民', 'w+1': '网', 'w-1': '人', 'w-1:w': '人民', 'w:w+1': '民网'},\n", 308 | " {'bias': 1, 'w': '网', 'w+1': '1', 'w-1': '民', 'w-1:w': '民网', 'w:w+1': '网1'},\n", 309 | " {'bias': 1, 'w': '1', 'w+1': '月', 'w-1': '网', 'w-1:w': '网1', 'w:w+1': '1月'},\n", 310 | " {'bias': 1, 'w': '月', 'w+1': '1', 'w-1': '1', 'w-1:w': '1月', 'w:w+1': '月1'},\n", 311 | " {'bias': 1, 'w': '1', 'w+1': '日', 'w-1': '月', 'w-1:w': '月1', 'w:w+1': '1日'},\n", 312 | " {'bias': 1, 'w': '日', 'w+1': '讯', 'w-1': '1', 'w-1:w': '1日', 'w:w+1': '日讯'},\n", 313 | " {'bias': 1, 'w': '讯', 'w+1': '据', 'w-1': '日', 'w-1:w': '日讯', 'w:w+1': '讯据'},\n", 314 | " {'bias': 1, 'w': '据', 'w+1': '《', 'w-1': '讯', 'w-1:w': '讯据', 'w:w+1': '据《'},\n", 315 | " {'bias': 1, 'w': '《', 'w+1': '纽', 'w-1': '据', 'w-1:w': '据《', 'w:w+1': '《纽'},\n", 316 | " {'bias': 1, 'w': '纽', 'w+1': '约', 'w-1': '《', 'w-1:w': '《纽', 'w:w+1': '纽约'},\n", 317 | " {'bias': 1, 'w': '约', 'w+1': '时', 'w-1': '纽', 'w-1:w': '纽约', 'w:w+1': '约时'},\n", 318 | " {'bias': 1, 'w': '时', 'w+1': '报', 'w-1': '约', 'w-1:w': '约时', 'w:w+1': '时报'},\n", 319 | " {'bias': 1, 'w': '报', 'w+1': '》', 'w-1': '时', 'w-1:w': '时报', 'w:w+1': '报》'},\n", 320 | " {'bias': 1, 'w': '》', 'w+1': '报', 'w-1': '报', 'w-1:w': '报》', 'w:w+1': '》报'},\n", 321 | " {'bias': 1, 'w': '报', 'w+1': '道', 'w-1': '》', 'w-1:w': '》报', 'w:w+1': '报道'},\n", 322 | " {'bias': 1, 'w': '道', 'w+1': ',', 'w-1': '报', 'w-1:w': '报道', 'w:w+1': '道,'},\n", 323 | " {'bias': 1, 'w': ',', 'w+1': '美', 'w-1': '道', 'w-1:w': '道,', 'w:w+1': ',美'},\n", 324 | " {'bias': 1, 'w': '美', 'w+1': '国', 'w-1': ',', 'w-1:w': ',美', 'w:w+1': '美国'},\n", 325 | " {'bias': 1, 'w': '国', 'w+1': '华', 'w-1': '美', 'w-1:w': '美国', 'w:w+1': '国华'},\n", 326 | " {'bias': 1, 'w': '华', 'w+1': '尔', 'w-1': '国', 'w-1:w': '国华', 'w:w+1': '华尔'},\n", 327 | " {'bias': 1, 'w': '尔', 'w+1': '街', 'w-1': '华', 'w-1:w': '华尔', 'w:w+1': '尔街'},\n", 328 | " {'bias': 1, 'w': '街', 'w+1': '股', 'w-1': '尔', 'w-1:w': '尔街', 'w:w+1': '街股'},\n", 329 | " {'bias': 1, 'w': '股', 'w+1': '市', 'w-1': '街', 'w-1:w': '街股', 'w:w+1': '股市'},\n", 330 | " {'bias': 1, 'w': '市', 'w+1': '在', 'w-1': '股', 'w-1:w': '股市', 'w:w+1': '市在'},\n", 331 | " {'bias': 1, 'w': '在', 'w+1': '2', 'w-1': '市', 'w-1:w': '市在', 'w:w+1': '在2'},\n", 332 | " {'bias': 1, 'w': '2', 'w+1': '0', 'w-1': '在', 'w-1:w': '在2', 'w:w+1': '20'},\n", 333 | " {'bias': 1, 'w': '0', 'w+1': '1', 'w-1': '2', 'w-1:w': '20', 'w:w+1': '01'},\n", 334 | " {'bias': 1, 'w': '1', 'w+1': '3', 'w-1': '0', 'w-1:w': '01', 'w:w+1': '13'},\n", 335 | " {'bias': 1, 
'w': '3', 'w+1': '年', 'w-1': '1', 'w-1:w': '13', 'w:w+1': '3年'},\n", 336 | " {'bias': 1, 'w': '年', 'w+1': '的', 'w-1': '3', 'w-1:w': '3年', 'w:w+1': '年的'},\n", 337 | " {'bias': 1, 'w': '的', 'w+1': '最', 'w-1': '年', 'w-1:w': '年的', 'w:w+1': '的最'},\n", 338 | " {'bias': 1, 'w': '最', 'w+1': '后', 'w-1': '的', 'w-1:w': '的最', 'w:w+1': '最后'},\n", 339 | " {'bias': 1, 'w': '后', 'w+1': '一', 'w-1': '最', 'w-1:w': '最后', 'w:w+1': '后一'},\n", 340 | " {'bias': 1, 'w': '一', 'w+1': '天', 'w-1': '后', 'w-1:w': '后一', 'w:w+1': '一天'},\n", 341 | " {'bias': 1, 'w': '天', 'w+1': '继', 'w-1': '一', 'w-1:w': '一天', 'w:w+1': '天继'},\n", 342 | " {'bias': 1, 'w': '继', 'w+1': '续', 'w-1': '天', 'w-1:w': '天继', 'w:w+1': '继续'},\n", 343 | " {'bias': 1, 'w': '续', 'w+1': '上', 'w-1': '继', 'w-1:w': '继续', 'w:w+1': '续上'},\n", 344 | " {'bias': 1, 'w': '上', 'w+1': '涨', 'w-1': '续', 'w-1:w': '续上', 'w:w+1': '上涨'},\n", 345 | " {'bias': 1, 'w': '涨', 'w+1': ',', 'w-1': '上', 'w-1:w': '上涨', 'w:w+1': '涨,'},\n", 346 | " {'bias': 1, 'w': ',', 'w+1': '和', 'w-1': '涨', 'w-1:w': '涨,', 'w:w+1': ',和'},\n", 347 | " {'bias': 1, 'w': '和', 'w+1': '全', 'w-1': ',', 'w-1:w': ',和', 'w:w+1': '和全'},\n", 348 | " {'bias': 1, 'w': '全', 'w+1': '球', 'w-1': '和', 'w-1:w': '和全', 'w:w+1': '全球'},\n", 349 | " {'bias': 1, 'w': '球', 'w+1': '股', 'w-1': '全', 'w-1:w': '全球', 'w:w+1': '球股'},\n", 350 | " {'bias': 1, 'w': '股', 'w+1': '市', 'w-1': '球', 'w-1:w': '球股', 'w:w+1': '股市'},\n", 351 | " {'bias': 1, 'w': '市', 'w+1': '一', 'w-1': '股', 'w-1:w': '股市', 'w:w+1': '市一'},\n", 352 | " {'bias': 1, 'w': '一', 'w+1': '样', 'w-1': '市', 'w-1:w': '市一', 'w:w+1': '一样'},\n", 353 | " {'bias': 1, 'w': '样', 'w+1': ',', 'w-1': '一', 'w-1:w': '一样', 'w:w+1': '样,'},\n", 354 | " {'bias': 1, 'w': ',', 'w+1': '都', 'w-1': '样', 'w-1:w': '样,', 'w:w+1': ',都'},\n", 355 | " {'bias': 1, 'w': '都', 'w+1': '以', 'w-1': ',', 'w-1:w': ',都', 'w:w+1': '都以'},\n", 356 | " {'bias': 1, 'w': '以', 'w+1': '最', 'w-1': '都', 'w-1:w': '都以', 'w:w+1': '以最'},\n", 357 | " {'bias': 1, 'w': '最', 'w+1': '高', 'w-1': '以', 'w-1:w': '以最', 'w:w+1': '最高'},\n", 358 | " {'bias': 1, 'w': '高', 'w+1': '纪', 'w-1': '最', 'w-1:w': '最高', 'w:w+1': '高纪'},\n", 359 | " {'bias': 1, 'w': '纪', 'w+1': '录', 'w-1': '高', 'w-1:w': '高纪', 'w:w+1': '纪录'},\n", 360 | " {'bias': 1, 'w': '录', 'w+1': '或', 'w-1': '纪', 'w-1:w': '纪录', 'w:w+1': '录或'},\n", 361 | " {'bias': 1, 'w': '或', 'w+1': '接', 'w-1': '录', 'w-1:w': '录或', 'w:w+1': '或接'},\n", 362 | " {'bias': 1, 'w': '接', 'w+1': '近', 'w-1': '或', 'w-1:w': '或接', 'w:w+1': '接近'},\n", 363 | " {'bias': 1, 'w': '近', 'w+1': '最', 'w-1': '接', 'w-1:w': '接近', 'w:w+1': '近最'},\n", 364 | " {'bias': 1, 'w': '最', 'w+1': '高', 'w-1': '近', 'w-1:w': '近最', 'w:w+1': '最高'},\n", 365 | " {'bias': 1, 'w': '高', 'w+1': '纪', 'w-1': '最', 'w-1:w': '最高', 'w:w+1': '高纪'},\n", 366 | " {'bias': 1, 'w': '纪', 'w+1': '录', 'w-1': '高', 'w-1:w': '高纪', 'w:w+1': '纪录'},\n", 367 | " {'bias': 1, 'w': '录', 'w+1': '结', 'w-1': '纪', 'w-1:w': '纪录', 'w:w+1': '录结'},\n", 368 | " {'bias': 1, 'w': '结', 'w+1': '束', 'w-1': '录', 'w-1:w': '录结', 'w:w+1': '结束'},\n", 369 | " {'bias': 1, 'w': '束', 'w+1': '本', 'w-1': '结', 'w-1:w': '结束', 'w:w+1': '束本'},\n", 370 | " {'bias': 1, 'w': '本', 'w+1': '年', 'w-1': '束', 'w-1:w': '束本', 'w:w+1': '本年'},\n", 371 | " {'bias': 1, 'w': '年', 'w+1': '的', 'w-1': '本', 'w-1:w': '本年', 'w:w+1': '年的'},\n", 372 | " {'bias': 1, 'w': '的', 'w+1': '交', 'w-1': '年', 'w-1:w': '年的', 'w:w+1': '的交'},\n", 373 | " {'bias': 1, 'w': '交', 'w+1': '易', 'w-1': '的', 'w-1:w': '的交', 'w:w+1': '交易'},\n", 374 | " {'bias': 1, 'w': '易', 'w+1': '。', 'w-1': '交', 'w-1:w': '交易', 'w:w+1': '易。'},\n", 375 | " 
{'bias': 1,\n", 376 | " 'w': '。',\n", 377 | " 'w+1': '',\n", 378 | " 'w-1': '易',\n", 379 | " 'w-1:w': '易。',\n", 380 | " 'w:w+1': '。'}]" 381 | ] 382 | }, 383 | "execution_count": 289, 384 | "metadata": {}, 385 | "output_type": "execute_result" 386 | } 387 | ], 388 | "source": [ 389 | "features[0]" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 290, 395 | "metadata": {}, 396 | "outputs": [ 397 | { 398 | "data": { 399 | "text/plain": [ 400 | "229016" 401 | ] 402 | }, 403 | "execution_count": 290, 404 | "metadata": {}, 405 | "output_type": "execute_result" 406 | } 407 | ], 408 | "source": [ 409 | "train_len = int(len(features)*0.8)\n", 410 | "train_len" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": 291, 416 | "metadata": { 417 | "collapsed": true 418 | }, 419 | "outputs": [], 420 | "source": [ 421 | "X_train = features[:train_len]\n", 422 | "y_train = new_targets[:train_len]\n", 423 | "X_test = features[train_len:]\n", 424 | "y_test = new_targets[train_len:]" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": 292, 430 | "metadata": {}, 431 | "outputs": [ 432 | { 433 | "name": "stdout", 434 | "output_type": "stream", 435 | "text": [ 436 | "CPU times: user 4min 43s, sys: 1.8 s, total: 4min 44s\n", 437 | "Wall time: 4min 44s\n" 438 | ] 439 | } 440 | ], 441 | "source": [ 442 | "%%time\n", 443 | "trainer = pycrfsuite.Trainer(verbose=False)\n", 444 | "for xseq,yseq in zip(X_train,y_train):\n", 445 | " trainer.append(xseq,yseq)" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 293, 451 | "metadata": { 452 | "collapsed": true 453 | }, 454 | "outputs": [], 455 | "source": [ 456 | "# 参数设置\n", 457 | "trainer.set_params({'c1':1.0,'c2':1e-3,'max_iterations':100,'feature.possible_transitions':True})" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": 295, 463 | "metadata": {}, 464 | "outputs": [ 465 | { 466 | "name": "stdout", 467 | "output_type": "stream", 468 | "text": [ 469 | "CPU times: user 56min 50s, sys: 5.12 s, total: 56min 55s\n", 470 | "Wall time: 56min 50s\n" 471 | ] 472 | } 473 | ], 474 | "source": [ 475 | "%%time\n", 476 | "trainer.train('../outputs/ner_2014_char_based.pycrfsuite')" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": 296, 482 | "metadata": {}, 483 | "outputs": [ 484 | { 485 | "name": "stdout", 486 | "output_type": "stream", 487 | "text": [ 488 | "-rw-rw-r-- 1 daizelin daizelin 470K 9月 20 21:22 ./ner_2018_char_based.pycrfsuite\r\n" 489 | ] 490 | } 491 | ], 492 | "source": [ 493 | "!ls -lh ./ner_2018_char_based.pycrfsuite" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 297, 499 | "metadata": { 500 | "collapsed": true 501 | }, 502 | "outputs": [], 503 | "source": [ 504 | "def bio_classification_report(y_true, y_pred):\n", 505 | " lb = LabelBinarizer()\n", 506 | " y_true_combined = lb.fit_transform(list(chain.from_iterable(y_true)))\n", 507 | " y_pred_combined = lb.transform(list(chain.from_iterable(y_pred)))\n", 508 | " \n", 509 | " tagset = set(lb.classes_) - {'O'}\n", 510 | " tagset = sorted(tagset,key = lambda tag:tag.split('-',1)[::-1])\n", 511 | " class_indices = {cls:idx for idx,cls in enumerate(lb.classes_)}\n", 512 | " \n", 513 | " return classification_report(\n", 514 | " y_true_combined,\n", 515 | " y_pred_combined,\n", 516 | " labels=[class_indices[cls] for cls in tagset],\n", 517 | " target_names = tagset,\n", 518 | " )" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 
| "execution_count": 298, 524 | "metadata": {}, 525 | "outputs": [ 526 | { 527 | "name": "stdout", 528 | "output_type": "stream", 529 | "text": [ 530 | "CPU times: user 1min 4s, sys: 540 ms, total: 1min 4s\n", 531 | "Wall time: 1min 5s\n" 532 | ] 533 | } 534 | ], 535 | "source": [ 536 | "%%time\n", 537 | "tagger = pycrfsuite.Tagger()\n", 538 | "tagger.open('../outputs/ner_2014_char_based.pycrfsuite')\n", 539 | "y_pred = [tagger.tag(xseq) for xseq in X_test]" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": 299, 545 | "metadata": {}, 546 | "outputs": [ 547 | { 548 | "name": "stdout", 549 | "output_type": "stream", 550 | "text": [ 551 | " precision recall f1-score support\n", 552 | "\n", 553 | " 0.00 0.00 0.00 2\n", 554 | " B_LOC 0.97 0.97 0.97 51825\n", 555 | " B_ORG 0.98 0.97 0.98 3687\n", 556 | " B_PER 0.96 0.92 0.94 46640\n", 557 | " B_T 0.98 0.98 0.98 43415\n", 558 | " I_LOC 0.96 0.95 0.96 80188\n", 559 | " I_ORG 0.99 0.96 0.97 8266\n", 560 | " I_PER 0.96 0.91 0.93 90070\n", 561 | " I_T 0.98 0.99 0.98 115917\n", 562 | "\n", 563 | "avg / total 0.99 0.99 0.99 4758408\n", 564 | "\n" 565 | ] 566 | }, 567 | { 568 | "name": "stderr", 569 | "output_type": "stream", 570 | "text": [ 571 | "/home/daizelin/anaconda3/lib/python3.6/site-packages/sklearn/metrics/classification.py:1135: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.\n", 572 | " 'precision', 'predicted', average, warn_for)\n" 573 | ] 574 | } 575 | ], 576 | "source": [ 577 | "print(bio_classification_report(y_test,y_pred))" 578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "execution_count": 4, 583 | "metadata": { 584 | "collapsed": true 585 | }, 586 | "outputs": [], 587 | "source": [ 588 | "def predict(s):\n", 589 | " s = list(s)\n", 590 | " s.insert(0,'')\n", 591 | " s.append('')\n", 592 | " sent = seg_by_window(s)\n", 593 | "# print(sent)\n", 594 | " features = extract_feature([sent])\n", 595 | "# print(features)\n", 596 | " tagger = pycrfsuite.Tagger()\n", 597 | " tagger.open('ner_2014_char_based.pycrfsuite')\n", 598 | " y_pred = [tagger.tag(features[0])]\n", 599 | "# print(y_pred)\n", 600 | " return y_pred[0]" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": null, 606 | "metadata": { 607 | "collapsed": true 608 | }, 609 | "outputs": [], 610 | "source": [ 611 | "\n", 612 | "for c, t in zip(input,y_pred[0]):\n", 613 | " res.append(c+'/'+t)\n", 614 | "print(' '.join(res))" 615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "execution_count": 342, 620 | "metadata": { 621 | "collapsed": true 622 | }, 623 | "outputs": [], 624 | "source": [ 625 | "sent = '新华社北京9月11日电第二十二届国际检察官联合会年会暨会员代表大会11日上午在北京开幕。国家主席江泽民发来贺信, 对会议召开表示祝贺。'" 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "execution_count": 6, 631 | "metadata": { 632 | "collapsed": true 633 | }, 634 | "outputs": [], 635 | "source": [ 636 | "def run_predcit(sent):\n", 637 | " y = predict(sent)\n", 638 | " res = []\n", 639 | " for c, t in zip(list(sent.strip()),y):\n", 640 | " res.append(c+'/'+t)\n", 641 | " print(' '.join(res))" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": 7, 647 | "metadata": {}, 648 | "outputs": [ 649 | { 650 | "name": "stdout", 651 | "output_type": "stream", 652 | "text": [ 653 | "下/O 沙/O 世/O 贸/O 江/B_PER 滨/I_PER 花/O 园/O 骏/O 景/O 湾/O 5/O 幢/O 与/O 6/O 幢/O 之/O 间/O\n" 654 | ] 655 | } 656 | ], 657 | "source": [ 658 | "run_predcit(sent)" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | 
"execution_count": 350, 664 | "metadata": { 665 | "collapsed": true 666 | }, 667 | "outputs": [], 668 | "source": [ 669 | "sent2 = '1949年,她还曾到“华大”向戏剧系同志学习,也能和解放区的文艺工作者打成一片。'" 670 | ] 671 | }, 672 | { 673 | "cell_type": "code", 674 | "execution_count": 351, 675 | "metadata": {}, 676 | "outputs": [ 677 | { 678 | "name": "stdout", 679 | "output_type": "stream", 680 | "text": [ 681 | "1/B_T 9/I_T 4/I_T 9/I_T 年/I_T ,/O 她/O 还/O 曾/O 到/O “/O 华/O 大/O ”/O 向/O 戏/O 剧/O 系/O 同/O 志/O 学/O 习/O ,/O 也/O 能/O 和/O 解/B_LOC 放/I_LOC 区/I_LOC 的/O 文/O 艺/O 工/O 作/O 者/O 打/O 成/O 一/O 片/O 。/O\n" 682 | ] 683 | } 684 | ], 685 | "source": [ 686 | "run_predcit(sent2)" 687 | ] 688 | } 689 | ], 690 | "metadata": { 691 | "kernelspec": { 692 | "display_name": "Python 3", 693 | "language": "python", 694 | "name": "python3" 695 | }, 696 | "language_info": { 697 | "codemirror_mode": { 698 | "name": "ipython", 699 | "version": 3 700 | }, 701 | "file_extension": ".py", 702 | "mimetype": "text/x-python", 703 | "name": "python", 704 | "nbconvert_exporter": "python", 705 | "pygments_lexer": "ipython3", 706 | "version": "3.6.2" 707 | } 708 | }, 709 | "nbformat": 4, 710 | "nbformat_minor": 2 711 | } 712 | --------------------------------------------------------------------------------