├── README.md ├── data ├── log │ ├── readme.txt │ ├── test │ │ └── readme.txt │ ├── train │ │ └── readme.txt │ └── validation │ │ └── readme.txt ├── model │ ├── msra │ │ └── readme.txt │ ├── note4 │ │ └── readme.txt │ ├── readme.txt │ ├── resume │ │ └── readme.txt │ └── weibo │ │ └── readme.txt ├── msra │ └── readme.txt ├── note4 │ └── readme.txt ├── result │ ├── msra │ │ └── readme.txt │ ├── note4 │ │ └── readme.txt │ ├── resume │ │ └── readme.txt │ └── weibo │ │ └── readme.txt ├── resume │ └── readme.txt └── weibo │ ├── dev.char.bmes │ ├── test.char.bmes │ └── train.char.bmes ├── msra.py ├── ner ├── __init__.py ├── __pycache__ │ └── __init__.cpython-36.pyc ├── functions │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── gaz_opt.cpython-36.pyc │ │ ├── iniatlize.cpython-36.pyc │ │ ├── save_res.cpython-36.pyc │ │ └── utils.cpython-36.pyc │ ├── gaz_opt.py │ ├── iniatlize.py │ ├── save_res.py │ └── utils.py ├── lw │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── cw_ner.cpython-36.pyc │ │ └── tbx_writer.cpython-36.pyc │ ├── cw_ner.py │ └── tbx_writer.py ├── model │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── bilstm.cpython-36.pyc │ │ ├── bilstmcrf.cpython-36.pyc │ │ ├── charbilstm.cpython-36.pyc │ │ ├── charcnn.cpython-36.pyc │ │ ├── crf.cpython-36.pyc │ │ └── latticelstm.cpython-36.pyc │ └── crf.py ├── modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── gaz_bilstm.cpython-36.pyc │ │ ├── gaz_embed.cpython-36.pyc │ │ └── highway.cpython-36.pyc │ ├── gaz_bilstm.py │ ├── gaz_embed.py │ └── highway.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── alphabet.cpython-36.pyc │ ├── data.cpython-36.pyc │ ├── functions.cpython-36.pyc │ ├── gazetteer.cpython-36.pyc │ ├── metric.cpython-36.pyc │ ├── msra_data.cpython-36.pyc │ ├── trie.cpython-36.pyc │ └── weibo_data.cpython-36.pyc │ ├── alphabet.py │ ├── functions.py │ ├── gazetteer.py │ ├── metric.py │ ├── msra_data.py │ ├── note4_data.py │ ├── resume_data.py │ ├── trie.py │ └── weibo_data.py ├── note4.py ├── resume.py └── weibo.py /README.md: -------------------------------------------------------------------------------- 1 | An Encoding Strategy based Word-Character LSTM for Chinese NER 2 | ============================================================= 3 | A model that takes both characters and words as input for the Chinese NER task. 4 | 5 | 6 | Models and results can be found in our NAACL 2019 paper "[An Encoding Strategy based Word-Character LSTM for Chinese NER](https://www.aclweb.org/anthology/papers/N/N19/N19-1247/)". It achieves state-of-the-art performance on most of the datasets. 7 | 8 | 9 | Most of the code is written with reference to Yang Jie's "NCRF++". To learn more about "NCRF++", please refer to the paper "NCRF++: An Open-source Neural Sequence Labeling Toolkit". 10 | 11 | 12 | Requirements: 13 | ============================ 14 | Python 3.6 15 | PyTorch: 0.4.0 16 | 17 | If you want to use the tensorboard logging in our code, you should also install the following: 18 | tensorboardX 1.2 19 | tensorflow 1.6.0 20 | 21 | 22 | Input format: 23 | ============================= 24 | CoNLL format (BIOES tag scheme preferred), with one character and its label per line. Sentences are separated by an empty line.
25 | ``` 26 | 美 B-LOC 27 | 国 E-LOC 28 | 的 O 29 | 华 B-PER 30 | 莱 I-PER 31 | 士 E-PER 32 | 33 | 我 O 34 | 跟 O 35 | 他 O 36 | 谈 O 37 | 笑 O 38 | 风 O 39 | 生 O 40 | ``` 41 | 42 | Pretrained Embeddings: 43 | =============== 44 | Character embeddings: [gigword_chn.all.a2b.uni.ite50.vec](https://pan.baidu.com/s/1pLO6T9D) 45 | Word embeddings: [ctb.50d.vec](https://pan.baidu.com/s/1pLO6T9D) 46 | 47 | 48 | Run: 49 | ============ 50 | Put each dataset in the data dir, and then simply run the corresponding .py file. 51 | For example, to run the Weibo experiment, just run: python3 weibo.py 52 | 53 | Cite: 54 | ======== 55 | @inproceedings{liu-etal-2019-encoding, \ 56 | title = "An Encoding Strategy Based Word-Character {LSTM} for {C}hinese {NER}", \ 57 | author = "Liu, Wei and 58 | Xu, Tongge and 59 | Xu, Qinghua and 60 | Song, Jiayu and 61 | Zu, Yueran", \ 62 | booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies", \ 63 | year = "2019", \ 64 | publisher = "Association for Computational Linguistics", \ 65 | url = "https://www.aclweb.org/anthology/N19-1247", \ 66 | pages = "2379--2389" \ 67 | } 68 | 69 | -------------------------------------------------------------------------------- /data/log/readme.txt: -------------------------------------------------------------------------------- 1 | The dir for tensorboard log files. -------------------------------------------------------------------------------- /data/log/test/readme.txt: -------------------------------------------------------------------------------- 1 | for test. -------------------------------------------------------------------------------- /data/log/train/readme.txt: -------------------------------------------------------------------------------- 1 | for train. -------------------------------------------------------------------------------- /data/log/validation/readme.txt: -------------------------------------------------------------------------------- 1 | for dev. -------------------------------------------------------------------------------- /data/model/msra/readme.txt: -------------------------------------------------------------------------------- 1 | for MSRA model checkpoints. -------------------------------------------------------------------------------- /data/model/note4/readme.txt: -------------------------------------------------------------------------------- 1 | for OntoNote4.0 model checkpoints. -------------------------------------------------------------------------------- /data/model/readme.txt: -------------------------------------------------------------------------------- 1 | The dir to save checkpoint files. -------------------------------------------------------------------------------- /data/model/resume/readme.txt: -------------------------------------------------------------------------------- 1 | for resume model checkpoints. -------------------------------------------------------------------------------- /data/model/weibo/readme.txt: -------------------------------------------------------------------------------- 1 | for weibo model checkpoints.
-------------------------------------------------------------------------------- /data/msra/readme.txt: -------------------------------------------------------------------------------- 1 | The dir to place MSRA dataset -------------------------------------------------------------------------------- /data/note4/readme.txt: -------------------------------------------------------------------------------- 1 | The dir to place OntoNote4 dataset -------------------------------------------------------------------------------- /data/result/msra/readme.txt: -------------------------------------------------------------------------------- 1 | The output results of MSRA -------------------------------------------------------------------------------- /data/result/note4/readme.txt: -------------------------------------------------------------------------------- 1 | The output results of OntoNote4 -------------------------------------------------------------------------------- /data/result/resume/readme.txt: -------------------------------------------------------------------------------- 1 | The output results of Resume -------------------------------------------------------------------------------- /data/result/weibo/readme.txt: -------------------------------------------------------------------------------- 1 | The output results of Weibo -------------------------------------------------------------------------------- /data/resume/readme.txt: -------------------------------------------------------------------------------- 1 | The dir to place Resume dataset -------------------------------------------------------------------------------- /ner/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/__init__.py -------------------------------------------------------------------------------- /ner/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ner/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/functions/__init__.py -------------------------------------------------------------------------------- /ner/functions/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/functions/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ner/functions/__pycache__/gaz_opt.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/functions/__pycache__/gaz_opt.cpython-36.pyc -------------------------------------------------------------------------------- /ner/functions/__pycache__/iniatlize.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/functions/__pycache__/iniatlize.cpython-36.pyc -------------------------------------------------------------------------------- /ner/functions/__pycache__/save_res.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/functions/__pycache__/save_res.cpython-36.pyc -------------------------------------------------------------------------------- /ner/functions/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/functions/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /ner/functions/gaz_opt.py: -------------------------------------------------------------------------------- 1 | __author__ = "liuwei" 2 | 3 | 4 | import numpy as np 5 | import torch 6 | import torch.autograd as autograd 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | def get_batch_gaz(gazs, batch_size, max_seq_len, gpu=False): 11 | """ 12 | rely on the gazs for batch_data, generation a batched gaz tensor for train 13 | Args: 14 | gazs: a list list list, gazs for batch_data 15 | batch_size: the size of batch 16 | max_seq_len: the max seq length 17 | """ 18 | # we need guarantee that every word has the same number gaz, that is use paddding 19 | # record the really length 20 | gaz_seq_length = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long() 21 | max_gaz_len = 1 22 | for i in range(batch_size): 23 | this_gaz_len = len(gazs[i]) 24 | gaz_lens = [len(gazs[i][j]) for j in range(this_gaz_len)] 25 | gaz_seq_length[i, :this_gaz_len] = torch.LongTensor(gaz_lens) 26 | l = max(gaz_lens) 27 | if max_gaz_len < l: 28 | max_gaz_len = l 29 | 30 | # do padding 31 | gaz_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len, max_gaz_len))).long() 32 | for i in range(batch_size): 33 | for j in range(len(gazs[i])): 34 | l = int(gaz_seq_length[i][j]) 35 | gaz_seq_tensor[i, j, :l] = torch.LongTensor(gazs[i][j][:l]) 36 | 37 | # get mask 38 | empty_tensor = (gaz_seq_length == 0).long() 39 | empty_tensor = empty_tensor * max_gaz_len 40 | gaz_seq_length = gaz_seq_length + empty_tensor 41 | 42 | gaz_mask_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len, max_gaz_len))).long() 43 | for i in range(batch_size): 44 | for j in range(max_seq_len): 45 | l = int(gaz_seq_length[i][j]) 46 | gaz_mask_tensor[i, j, :l] = 1 47 | del empty_tensor 48 | 49 | if gpu: 50 | gaz_seq_tensor = gaz_seq_tensor.cuda() 51 | gaz_seq_length = gaz_seq_length.cuda() 52 | gaz_mask_tensor = gaz_mask_tensor.cuda() 53 | 54 | return gaz_seq_tensor, gaz_seq_length, gaz_mask_tensor 55 | -------------------------------------------------------------------------------- /ner/functions/iniatlize.py: -------------------------------------------------------------------------------- 1 | __author__ = "liuwei" 2 | 3 | """ 4 | some init function 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | def init_cnn_weight(cnn_layer, seed=100): 12 | """ 13 | init the weight of cnn net 14 | Args: 15 | cnn_layer: weight.size() = [out_channel, in_channels, kernei size] 16 | seed: int 17 | """ 18 | torch.manual_seed(seed) 19 | 
nn.init.xavier_uniform_(cnn_layer.weight) 20 | cnn_layer.bias.data.zero_() 21 | 22 | def init_embedding(input_embedding, seed=100): 23 | """ 24 | Args: 25 | input_embedding: the weight of embedding need to be init 26 | seed: int 27 | """ 28 | # torch.manual_seed(seed) 29 | # scope = np.sqrt(3.0 / input_embedding.size(1)) 30 | # nn.init.uniform_(input_embedding, -scope, scope) 31 | 32 | torch.manual_seed(seed) 33 | nn.init.normal_(input_embedding, 0, 0.1) 34 | 35 | def init_linear(input_linear, seed=100): 36 | """ 37 | init the weight of linear net 38 | Args: 39 | input_linear: a linear layer 40 | """ 41 | torch.manual_seed(seed) 42 | scope = np.sqrt(6.0 / (input_linear.weight.size(0) + input_linear.weight.size(1))) 43 | nn.init.uniform_(input_linear.weight, -scope, scope) 44 | if input_linear.bias is not None: 45 | input_linear.bias.data.zero_() 46 | 47 | def init_maxtrix_weight(weights, seed=100): 48 | """ 49 | init the weight of a matrix 50 | """ 51 | torch.manual_seed(seed) 52 | scope = np.sqrt(6.0 / (weights.size(0) + weights.size(1))) 53 | nn.init.uniform_(weights, -scope, scope) 54 | 55 | def init_vector(vector, seed=100): 56 | """ 57 | init a vector, note that vector is 1-D 58 | """ 59 | torch.manual_seed(seed) 60 | v_size = vector.size(0) 61 | scale = np.sqrt(3.0 / v_size) 62 | nn.init.uniform_(vector, -scale, scale) 63 | 64 | def init_highway(highway_layer, seed=100): 65 | """ 66 | init the weight of highway net. do not init the bias 67 | Args: 68 | highway_layer: a highway layer 69 | seed: 70 | """ 71 | torch.manual_seed(seed) 72 | scope = np.sqrt(6.0 / (highway_layer.weight.size(0) + highway_layer.weight.size(1))) 73 | nn.init.uniform_(highway_layer.weight, -scope, scope) 74 | 75 | 76 | def get_embedding_weight(weight_file, vocab_size, word_dim): 77 | """ 78 | get the embedding from weight file, and then use it to init embedding 79 | Args: 80 | weight_file: embedding weight file 81 | vocab_size: the size of vocab 82 | word_dim: the dim of word embedding 83 | """ 84 | -------------------------------------------------------------------------------- /ner/functions/save_res.py: -------------------------------------------------------------------------------- 1 | __author__ = "liuwei" 2 | 3 | 4 | """ 5 | some util functions 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | import os 11 | 12 | def save_gold_pred(instances_text, preds, golds, name): 13 | """ 14 | save the gold and pred result to do compare 15 | Args: 16 | instances_text: is a list list, each list is a sentence 17 | preds: is also a list list, each list is a sentence predict tag 18 | golds: is also a list list, each list is a sentence gold tag 19 | name: train? dev? 
or test 20 | """ 21 | sent_len = len(instances_text) 22 | 23 | assert len(instances_text) == len(preds) 24 | assert len(preds) == len(golds) 25 | 26 | dir = "data/result/resume/" 27 | file_path = os.path.join(dir, name) 28 | num = 1 29 | with open(file_path, 'w') as f: 30 | f.write("wrod gold pred\n") 31 | for sent, gold, pred in zip(instances_text, golds, preds): 32 | # for each sentence 33 | for word, w_g, w_p in zip(sent[0], gold, pred): 34 | if w_g != w_p: 35 | f.write(word) 36 | f.write(" ") 37 | f.write(w_g) 38 | f.write(" ") 39 | f.write(w_p) 40 | f.write(" ") 41 | f.write(str(num)) 42 | f.write("\n") 43 | num += 1 44 | else: 45 | f.write(word) 46 | f.write(" ") 47 | f.write(w_g) 48 | f.write(" ") 49 | f.write(w_p) 50 | f.write("\n") 51 | 52 | f.write("\n") 53 | 54 | 55 | -------------------------------------------------------------------------------- /ner/functions/utils.py: -------------------------------------------------------------------------------- 1 | __author__ = "liuwei" 2 | 3 | 4 | """ 5 | some util functions 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | 11 | def reverse_padded_sequence(inputs, lengths, batch_first=True): 12 | """Reverses sequences according to their lengths. 13 | Inputs should have size ``T x B x *`` if ``batch_first`` is False, or 14 | ``B x T x *`` if True. T is the length of the longest sequence (or larger), 15 | B is the batch size, and * is any number of dimensions (including 0). 16 | Arguments: 17 | inputs (Variable): padded batch of variable length sequences. 18 | lengths (list[int]): list of sequence lengths 19 | batch_first (bool, optional): if True, inputs should be B x T x *. 20 | Returns: 21 | A Variable with the same size as inputs, but with each sequence 22 | reversed according to its length. 
23 | """ 24 | if batch_first: 25 | inputs = inputs.transpose(0, 1) 26 | max_length, batch_size = inputs.size(0), inputs.size(1) 27 | if len(lengths) != batch_size: 28 | raise ValueError("inputs is incompatible with lengths.") 29 | ind = [list(reversed(range(0, length))) + list(range(length, max_length)) 30 | for length in lengths] 31 | ind = torch.LongTensor(ind).transpose(0, 1) 32 | for dim in range(2, inputs.dim()): 33 | ind = ind.unsqueeze(dim) 34 | ind = ind.expand_as(inputs) 35 | if inputs.is_cuda: 36 | ind = ind.cuda() 37 | reversed_inputs = torch.gather(inputs, 0, ind) 38 | if batch_first: 39 | reversed_inputs = reversed_inputs.transpose(0, 1) 40 | 41 | return reversed_inputs 42 | 43 | 44 | def random_embedding(vocab_size, embedding_dim): 45 | pretrain_emb = np.empty([vocab_size, embedding_dim]) 46 | scale = np.sqrt(3.0 / embedding_dim) 47 | for index in range(vocab_size): 48 | pretrain_emb[index,:] = np.random.uniform(-scale, scale, [1, embedding_dim]) 49 | return pretrain_emb 50 | -------------------------------------------------------------------------------- /ner/lw/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/lw/__init__.py -------------------------------------------------------------------------------- /ner/lw/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/lw/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ner/lw/__pycache__/cw_ner.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/lw/__pycache__/cw_ner.cpython-36.pyc -------------------------------------------------------------------------------- /ner/lw/__pycache__/tbx_writer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/lw/__pycache__/tbx_writer.cpython-36.pyc -------------------------------------------------------------------------------- /ner/lw/cw_ner.py: -------------------------------------------------------------------------------- 1 | __author__ = "liuwei" 2 | 3 | """ 4 | A char word model 5 | """ 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | from ner.modules.gaz_embed import Gaz_Embed 11 | from ner.modules.gaz_bilstm import Gaz_BiLSTM 12 | from ner.model.crf import CRF 13 | from ner.functions.utils import random_embedding 14 | from ner.functions.gaz_opt import get_batch_gaz 15 | from ner.functions.utils import reverse_padded_sequence 16 | 17 | class CW_NER(torch.nn.Module): 18 | def __init__(self, data, type=1): 19 | print("Build char-word based NER Task...") 20 | super(CW_NER, self).__init__() 21 | 22 | self.gpu = data.HP_gpu 23 | label_size = data.label_alphabet_size 24 | self.type = type 25 | self.gaz_embed = Gaz_Embed(data, type) 26 | 27 | self.word_embedding = nn.Embedding(data.word_alphabet.size(), data.word_emb_dim) 28 | 29 | self.lstm = Gaz_BiLSTM(data, data.word_emb_dim + data.gaz_emb_dim, data.HP_hidden_dim) 30 | 31 | self.crf = CRF(data.label_alphabet_size, self.gpu) 32 | 33 | self.hidden2tag = 
nn.Linear(data.HP_hidden_dim * 2, data.label_alphabet_size + 2) 34 | 35 | if data.pretrain_word_embedding is not None: 36 | self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding)) 37 | else: 38 | self.word_embedding.weight.data.copy_( 39 | random_embedding(data.word_alphabet_size, data.word_emb_dim) 40 | ) 41 | 42 | if self.gpu: 43 | self.word_embedding = self.word_embedding.cuda() 44 | self.hidden2tag = self.hidden2tag.cuda() 45 | 46 | 47 | def neg_log_likelihood_loss(self, gaz_list, reverse_gaz_list, word_inputs, word_seq_lengths, batch_label, mask): 48 | """ 49 | get the neg_log_likelihood_loss 50 | Args: 51 | gaz_list: the batch data's gaz, for every chinese char 52 | reverse_gaz_list: the reverse list 53 | word_inputs: word input ids, [batch_size, seq_len] 54 | word_seq_lengths: [batch_size] 55 | batch_label: [batch_size, seq_len] 56 | mask: [batch_size, seq_len] 57 | """ 58 | batch_size = word_inputs.size(0) 59 | seq_len = word_inputs.size(1) 60 | lengths = list(map(int, word_seq_lengths)) 61 | 62 | # print('one ', reverse_gaz_list[0][:10]) 63 | 64 | ## get batch gaz ids 65 | batch_gaz_ids, batch_gaz_length, batch_gaz_mask = get_batch_gaz(reverse_gaz_list, batch_size, seq_len, self.gpu) 66 | 67 | # print('two ', batch_gaz_ids[0][:10]) 68 | 69 | reverse_batch_gaz_ids, reverse_batch_gaz_length, reverse_batch_gaz_mask = get_batch_gaz(gaz_list, batch_size, seq_len, self.gpu) 70 | reverse_batch_gaz_ids = reverse_padded_sequence(reverse_batch_gaz_ids, lengths) 71 | reverse_batch_gaz_length = reverse_padded_sequence(reverse_batch_gaz_length, lengths) 72 | reverse_batch_gaz_mask = reverse_padded_sequence(reverse_batch_gaz_mask, lengths) 73 | 74 | ## word embedding 75 | word_embs = self.word_embedding(word_inputs) 76 | reverse_word_embs = reverse_padded_sequence(word_embs, lengths) 77 | 78 | ## gaz embedding 79 | gaz_embs = self.gaz_embed((batch_gaz_ids, batch_gaz_length, batch_gaz_mask)) 80 | reverse_gaz_embs = self.gaz_embed((reverse_batch_gaz_ids, reverse_batch_gaz_length, reverse_batch_gaz_mask)) 81 | # print(gaz_embs[0][0][:20]) 82 | 83 | ## lstm 84 | forward_inputs = torch.cat((word_embs, gaz_embs), dim=-1) 85 | backward_inputs = torch.cat((reverse_word_embs, reverse_gaz_embs), dim=-1) 86 | 87 | lstm_outs, _ = self.lstm((forward_inputs, backward_inputs), word_seq_lengths) 88 | 89 | ## hidden2tag 90 | outs = self.hidden2tag(lstm_outs) 91 | 92 | ## crf and loss 93 | loss = self.crf.neg_log_likelihood_loss(outs, mask, batch_label) 94 | _, tag_seq = self.crf._viterbi_decode(outs, mask) 95 | 96 | return loss, tag_seq 97 | 98 | def forward(self, gaz_list, reverse_gaz_list, word_inputs, word_seq_lengths, mask): 99 | """ 100 | Args: 101 | gaz_list: the forward gaz_list 102 | reverse_gaz_list: the backward gaz list 103 | word_inputs: word ids 104 | word_seq_lengths: each sentence length 105 | mask: sentence mask 106 | """ 107 | batch_size = word_inputs.size(0) 108 | seq_len = word_inputs.size(1) 109 | lengths = list(map(int, word_seq_lengths)) 110 | 111 | ## get batch gaz ids 112 | batch_gaz_ids, batch_gaz_length, batch_gaz_mask = get_batch_gaz(reverse_gaz_list, batch_size, seq_len, self.gpu) 113 | reverse_batch_gaz_ids, reverse_batch_gaz_length, reverse_batch_gaz_mask = get_batch_gaz(gaz_list, batch_size, seq_len, self.gpu) 114 | reverse_batch_gaz_ids = reverse_padded_sequence(reverse_batch_gaz_ids, lengths) 115 | reverse_batch_gaz_length = reverse_padded_sequence(reverse_batch_gaz_length, lengths) 116 | reverse_batch_gaz_mask = 
reverse_padded_sequence(reverse_batch_gaz_mask, lengths) 117 | 118 | ## word embedding 119 | word_embs = self.word_embedding(word_inputs) 120 | reverse_word_embs = reverse_padded_sequence(word_embs, lengths) 121 | 122 | ## gaz embedding 123 | gaz_embs = self.gaz_embed((batch_gaz_ids, batch_gaz_length, batch_gaz_mask)) 124 | reverse_gaz_embs = self.gaz_embed((reverse_batch_gaz_ids, reverse_batch_gaz_length, reverse_batch_gaz_mask)) 125 | 126 | ## lstm 127 | forward_inputs = torch.cat((word_embs, gaz_embs), dim=-1) 128 | backward_inputs = torch.cat((reverse_word_embs, reverse_gaz_embs), dim=-1) 129 | 130 | lstm_outs, _ = self.lstm((forward_inputs, backward_inputs), word_seq_lengths) 131 | 132 | ## hidden2tag 133 | outs = self.hidden2tag(lstm_outs) 134 | 135 | ## crf and loss 136 | _, tag_seq = self.crf._viterbi_decode(outs, mask) 137 | 138 | return tag_seq 139 | -------------------------------------------------------------------------------- /ner/lw/tbx_writer.py: -------------------------------------------------------------------------------- 1 | __author__ = "liuwei" 2 | 3 | """ 4 | write the loss and accu to the tensorboardx 5 | """ 6 | from typing import Any 7 | from tensorboardX import SummaryWriter 8 | 9 | class TensorboardWriter: 10 | """ 11 | wrap the SummaryWriter, print the value to the tensorboard 12 | """ 13 | def __init__(self, train_log: SummaryWriter = None, validation_log: SummaryWriter = None, 14 | test_log: SummaryWriter = None): 15 | self._train_log = train_log 16 | self._validation_log = validation_log 17 | self._test_log = test_log 18 | 19 | @staticmethod 20 | def _item(value: Any): 21 | if hasattr(value, 'item'): 22 | val = value.item() 23 | else: 24 | val = value 25 | 26 | return val 27 | 28 | def add_train_scalar(self, name: str, value: float, global_step: int): 29 | """ 30 | add train scalar value to tensorboardX 31 | Args: 32 | name: the name of the value 33 | value: 34 | global_step: the steps 35 | """ 36 | if self._train_log is not None: 37 | self._train_log.add_scalar(name, self._item(value), global_step) 38 | 39 | 40 | def add_validation_scalar(self, name: str, value: float, global_step: int): 41 | """ 42 | add validation scalar value to tensorboardX 43 | Args: 44 | name: the name of the value 45 | value: 46 | global_step: 47 | """ 48 | if self._validation_log is not None: 49 | self._validation_log.add_scalar(name, self._item(value), global_step) 50 | 51 | def add_test_scalar(self, name: str, value: float, global_step: int): 52 | """ 53 | add test scalar value to tensorboardX 54 | Args: 55 | name: the name of the value 56 | value: 57 | global_step: 58 | """ 59 | if self._test_log is not None: 60 | self._test_log.add_scalar(name, self._item(value), global_step) 61 | 62 | -------------------------------------------------------------------------------- /ner/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__init__.py -------------------------------------------------------------------------------- /ner/model/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ner/model/__pycache__/bilstm.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__pycache__/bilstm.cpython-36.pyc -------------------------------------------------------------------------------- /ner/model/__pycache__/bilstmcrf.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__pycache__/bilstmcrf.cpython-36.pyc -------------------------------------------------------------------------------- /ner/model/__pycache__/charbilstm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__pycache__/charbilstm.cpython-36.pyc -------------------------------------------------------------------------------- /ner/model/__pycache__/charcnn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__pycache__/charcnn.cpython-36.pyc -------------------------------------------------------------------------------- /ner/model/__pycache__/crf.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__pycache__/crf.cpython-36.pyc -------------------------------------------------------------------------------- /ner/model/__pycache__/latticelstm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__pycache__/latticelstm.cpython-36.pyc -------------------------------------------------------------------------------- /ner/model/crf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.autograd as autograd 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | START_TAG = -2 9 | STOP_TAG = -1 10 | 11 | 12 | # Compute log sum exp in a numerically stable way for the forward algorithm 13 | def log_sum_exp(vec, m_size): 14 | """ 15 | calculate log of exp sum 16 | args: 17 | vec (batch_size, vanishing_dim, hidden_dim) : input tensor 18 | m_size : hidden_dim 19 | return: 20 | batch_size, hidden_dim 21 | """ 22 | _, idx = torch.max(vec, 1) # B * 1 * M 23 | max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size) # B * M 24 | return max_score.view(-1, m_size) + torch.log(torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size) # B * M 25 | 26 | class CRF(nn.Module): 27 | 28 | def __init__(self, tagset_size, gpu): 29 | super(CRF, self).__init__() 30 | print("build batched crf...") 31 | self.gpu = gpu 32 | # Matrix of transition parameters. Entry i,j is the score of transitioning *to* i *from* j. 
33 | self.average_batch = False 34 | self.tagset_size = tagset_size 35 | # # We add 2 here, because of START_TAG and STOP_TAG 36 | # # transitions (f_tag_size, t_tag_size), transition value from f_tag to t_tag 37 | init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2) 38 | # init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2) 39 | # init_transitions[:,START_TAG] = -1000.0 40 | # init_transitions[STOP_TAG,:] = -1000.0 41 | # init_transitions[:,0] = -1000.0 42 | # init_transitions[0,:] = -1000.0 43 | if self.gpu: 44 | init_transitions = init_transitions.cuda() 45 | self.transitions = nn.Parameter(init_transitions) 46 | 47 | # self.transitions = nn.Parameter(torch.Tensor(self.tagset_size+2, self.tagset_size+2)) 48 | # self.transitions.data.zero_() 49 | 50 | def _calculate_PZ(self, feats, mask): 51 | """ 52 | input: 53 | feats: (batch, seq_len, self.tag_size+2) 54 | masks: (batch, seq_len) 55 | """ 56 | batch_size = feats.size(0) 57 | seq_len = feats.size(1) 58 | tag_size = feats.size(2) 59 | # print feats.view(seq_len, tag_size) 60 | assert(tag_size == self.tagset_size+2) 61 | mask = mask.transpose(1,0).contiguous() 62 | ins_num = seq_len * batch_size 63 | ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1) 64 | feats = feats.transpose(1,0).contiguous().view(ins_num,1, tag_size).expand(ins_num, tag_size, tag_size) 65 | ## need to consider start 66 | scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size) 67 | scores = scores.view(seq_len, batch_size, tag_size, tag_size) 68 | # build iter 69 | seq_iter = enumerate(scores) 70 | #_, inivalues = seq_iter.next() # bat_size * from_target_size * to_target_size 71 | _, inivalues = next(seq_iter) 72 | # only need start from start_tag 73 | partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size 74 | 75 | ## add start score (from start to all tag, duplicate to batch_size) 76 | # partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1) 77 | # iter over last scores 78 | for idx, cur_values in seq_iter: 79 | # previous to_target is current from_target 80 | # partition: previous results log(exp(from_target)), #(batch_size * from_target) 81 | # cur_values: bat_size * from_target * to_target 82 | 83 | cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 84 | cur_partition = log_sum_exp(cur_values, tag_size) 85 | # print cur_partition.data 86 | 87 | # (bat_size * from_target * to_target) -> (bat_size * to_target) 88 | # partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1) 89 | mask_idx = mask[idx, :] 90 | mask_idx = mask_idx.view(batch_size, 1).expand(batch_size, tag_size) 91 | 92 | ## effective updated partition part, only keep the partition value of mask value = 1 93 | masked_cur_partition = cur_partition.masked_select(mask_idx) 94 | ## let mask_idx broadcastable, to disable warning 95 | mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1) 96 | 97 | ## replace the partition where the maskvalue=1, other partition value keeps the same 98 | partition.masked_scatter_(mask_idx, masked_cur_partition) 99 | # until the last state, add transition score for all partition (and do log_sum_exp) then select the value in STOP_TAG 100 | cur_values = self.transitions.view(1,tag_size, 
tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 101 | cur_partition = log_sum_exp(cur_values, tag_size) 102 | final_partition = cur_partition[:, STOP_TAG] 103 | return final_partition.sum(), scores 104 | 105 | 106 | def _viterbi_decode(self, feats, mask): 107 | """ 108 | input: 109 | feats: (batch, seq_len, self.tag_size+2) 110 | mask: (batch, seq_len) 111 | output: 112 | decode_idx: (batch, seq_len) decoded sequence 113 | path_score: (batch, 1) corresponding score for each sequence (to be implementated) 114 | """ 115 | batch_size = feats.size(0) 116 | seq_len = feats.size(1) 117 | tag_size = feats.size(2) 118 | assert(tag_size == self.tagset_size+2) 119 | ## calculate sentence length for each sentence 120 | length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long() 121 | ## mask to (seq_len, batch_size) 122 | mask = mask.transpose(1,0).contiguous() 123 | ins_num = seq_len * batch_size 124 | ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1) 125 | feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) 126 | ## need to consider start 127 | scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size) 128 | scores = scores.view(seq_len, batch_size, tag_size, tag_size) 129 | 130 | # build iter 131 | seq_iter = enumerate(scores) 132 | ## record the position of best score 133 | back_points = list() 134 | partition_history = list() 135 | 136 | 137 | ## reverse mask (bug for mask = 1- mask, use this as alternative choice) 138 | # mask = 1 + (-1)*mask 139 | mask = (1 - mask.long()).byte() 140 | #_, inivalues = seq_iter.next() # bat_size * from_target_size * to_target_size 141 | _, inivalues = next(seq_iter) 142 | # only need start from start_tag 143 | partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size 144 | partition_history.append(partition) 145 | # iter over last scores 146 | for idx, cur_values in seq_iter: 147 | # previous to_target is current from_target 148 | # partition: previous results log(exp(from_target)), #(batch_size * from_target) 149 | # cur_values: batch_size * from_target * to_target 150 | cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 151 | ## forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1) # do not consider START_TAG/STOP_TAG 152 | partition, cur_bp = torch.max(cur_values, 1) 153 | 154 | # add by liuwei 155 | partition = partition.view(partition.size()[0], partition.size()[1], 1) 156 | 157 | partition_history.append(partition) 158 | ## cur_bp: (batch_size, tag_size) max source score position in current tag 159 | ## set padded label as 0, which will be filtered in post processing 160 | cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0) 161 | back_points.append(cur_bp) 162 | ### add score to final STOP_TAG 163 | partition_history = torch.cat(partition_history,0).view(seq_len, batch_size,-1).transpose(1,0).contiguous() ## (batch_size, seq_len. 
tag_size) 164 | ### get the last position for each setences, and select the last partitions using gather() 165 | last_position = length_mask.view(batch_size,1,1).expand(batch_size, 1, tag_size) -1 166 | last_partition = torch.gather(partition_history, 1, last_position).view(batch_size,tag_size,1) 167 | ### calculate the score from last partition to end state (and then select the STOP_TAG from it) 168 | last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) 169 | _, last_bp = torch.max(last_values, 1) 170 | pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size)).long() 171 | if self.gpu: 172 | pad_zero = pad_zero.cuda() 173 | back_points.append(pad_zero) 174 | back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size) 175 | 176 | ## select end ids in STOP_TAG 177 | pointer = last_bp[:, STOP_TAG] 178 | insert_last = pointer.contiguous().view(batch_size,1,1).expand(batch_size,1, tag_size) 179 | back_points = back_points.transpose(1,0).contiguous() 180 | ## move the end ids(expand to tag_size) to the corresponding position of back_points to replace the 0 values 181 | # print "lp:",last_position 182 | # print "il:",insert_last 183 | back_points.scatter_(1, last_position, insert_last) 184 | # print "bp:",back_points 185 | # exit(0) 186 | back_points = back_points.transpose(1,0).contiguous() 187 | ## decode from the end, padded position ids are 0, which will be filtered if following evaluation 188 | decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size)) 189 | if self.gpu: 190 | decode_idx = decode_idx.cuda() 191 | decode_idx[-1] = pointer.data 192 | for idx in range(len(back_points)-2, -1, -1): 193 | pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) 194 | decode_idx[idx] = pointer.data.view(-1) 195 | path_score = None 196 | decode_idx = decode_idx.transpose(1,0) 197 | return path_score, decode_idx 198 | 199 | 200 | 201 | def forward(self, feats): 202 | path_score, best_path = self._viterbi_decode(feats) 203 | return path_score, best_path 204 | 205 | 206 | def _score_sentence(self, scores, mask, tags): 207 | """ 208 | input: 209 | scores: variable (seq_len, batch, tag_size, tag_size) 210 | mask: (batch, seq_len) 211 | tags: tensor (batch, seq_len) 212 | output: 213 | score: sum of score for gold sequences within whole batch 214 | """ 215 | # Gives the score of a provided tag sequence 216 | batch_size = scores.size(1) 217 | seq_len = scores.size(0) 218 | tag_size = scores.size(2) 219 | ## convert tag value into a new format, recorded label bigram information to index 220 | new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len)) 221 | if self.gpu: 222 | new_tags = new_tags.cuda() 223 | for idx in range(seq_len): 224 | if idx == 0: 225 | ## start -> first score 226 | new_tags[:,0] = (tag_size - 2)*tag_size + tags[:,0] 227 | 228 | else: 229 | new_tags[:,idx] = tags[:,idx-1]*tag_size + tags[:,idx] 230 | 231 | ## transition for label to STOP_TAG 232 | end_transition = self.transitions[:,STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size) 233 | ## length for batch, last word position = length - 1 234 | length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long() 235 | ## index the label id of last word 236 | end_ids = torch.gather(tags, 1, length_mask - 1) 237 | 238 | ## index the transition score for end_id to STOP_TAG 239 | end_energy = torch.gather(end_transition, 1, end_ids) 240 | 241 | ## 
convert tag as (seq_len, batch_size, 1) 242 | new_tags = new_tags.transpose(1,0).contiguous().view(seq_len, batch_size, 1) 243 | ### need convert tags id to search from 400 positions of scores 244 | tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size) # seq_len * bat_size 245 | ## mask transpose to (seq_len, batch_size) 246 | tg_energy = tg_energy.masked_select(mask.transpose(1,0)) 247 | 248 | # ## calculate the score from START_TAG to first label 249 | # start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size) 250 | # start_energy = torch.gather(start_transition, 1, tags[0,:]) 251 | 252 | ## add all score together 253 | # gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum() 254 | gold_score = tg_energy.sum() + end_energy.sum() 255 | return gold_score 256 | 257 | def neg_log_likelihood_loss(self, feats, mask, tags): 258 | # nonegative log likelihood 259 | batch_size = feats.size(0) 260 | forward_score, scores = self._calculate_PZ(feats, mask) 261 | gold_score = self._score_sentence(scores, mask, tags) 262 | # print "batch, f:", forward_score.data[0], " g:", gold_score.data[0], " dis:", forward_score.data[0] - gold_score.data[0] 263 | # exit(0) 264 | if self.average_batch: 265 | return (forward_score - gold_score)/batch_size 266 | else: 267 | return forward_score - gold_score 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | -------------------------------------------------------------------------------- /ner/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/modules/__init__.py -------------------------------------------------------------------------------- /ner/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ner/modules/__pycache__/gaz_bilstm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/modules/__pycache__/gaz_bilstm.cpython-36.pyc -------------------------------------------------------------------------------- /ner/modules/__pycache__/gaz_embed.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/modules/__pycache__/gaz_embed.cpython-36.pyc -------------------------------------------------------------------------------- /ner/modules/__pycache__/highway.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/modules/__pycache__/highway.cpython-36.pyc -------------------------------------------------------------------------------- /ner/modules/gaz_bilstm.py: -------------------------------------------------------------------------------- 1 | __author__ = "liuwei" 2 | 3 | """ 4 | A bilstm use for the sentences contain gaz msg 
5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | from ner.functions.utils import reverse_padded_sequence 12 | 13 | class Gaz_BiLSTM(torch.nn.Module): 14 | def __init__(self, data, input_size, hidden_size): 15 | print("Build the Gaz bilstm...") 16 | super(Gaz_BiLSTM, self).__init__() 17 | 18 | self.gpu = data.HP_gpu 19 | self.batch_size = data.HP_batch_size 20 | self.drop = nn.Dropout(data.HP_dropout) 21 | self.droplstm = nn.Dropout(data.HP_dropout) 22 | 23 | self.f_lstm = nn.LSTM(input_size, hidden_size, num_layers=data.HP_lstm_layer, batch_first=True) 24 | self.b_lstm = nn.LSTM(input_size, hidden_size, num_layers=data.HP_lstm_layer, batch_first=True) 25 | 26 | if self.gpu: 27 | self.drop = self.drop.cuda() 28 | self.droplstm = self.droplstm.cuda() 29 | self.f_lstm = self.f_lstm.cuda() 30 | self.b_lstm = self.b_lstm.cuda() 31 | 32 | 33 | def get_lstm_features(self, inputs, word_seq_length): 34 | """ 35 | get the output of bilstm. Note that inputs is forward and backward inputs 36 | Args: 37 | inputs: a tuple, each item size is [batch_size, sent_len, dim] 38 | word_seq_length: a [batch_size] tensor 39 | """ 40 | # lengths = list(map(int, word_seq_length)) 41 | f_inputs, b_inputs = inputs 42 | f_inputs = self.drop(f_inputs) 43 | b_inputs = self.drop(b_inputs) 44 | 45 | f_lstm_out, f_hidden = self.f_lstm(f_inputs) 46 | b_lstm_out, b_hidden = self.b_lstm(b_inputs) 47 | # b_lstm_out = reverse_padded_sequence(b_lstm_out, lengths) 48 | 49 | f_lstm_out = self.droplstm(f_lstm_out) 50 | b_lstm_out = self.droplstm(b_lstm_out) 51 | 52 | return f_lstm_out, b_lstm_out 53 | 54 | 55 | def forward(self, inputs, word_seq_length): 56 | """ 57 | """ 58 | f_lstm_out, b_lstm_out = self.get_lstm_features(inputs, word_seq_length) 59 | 60 | lengths = list(map(int, word_seq_length)) 61 | rb_lstm_out = reverse_padded_sequence(b_lstm_out, lengths) 62 | 63 | lstm_out = torch.cat((f_lstm_out, rb_lstm_out), dim=-1) 64 | 65 | return lstm_out, (f_lstm_out, b_lstm_out) 66 | -------------------------------------------------------------------------------- /ner/modules/gaz_embed.py: -------------------------------------------------------------------------------- 1 | __author__ = "liuwei" 2 | 3 | """ 4 | the strategy to obtain gaz embedding. 
this embedding is concat to the char embedding 5 | """ 6 | 7 | import numpy 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from ner.modules.highway import Highway 13 | from ner.functions.utils import random_embedding 14 | from ner.functions.utils import reverse_padded_sequence 15 | from ner.functions.iniatlize import init_cnn_weight 16 | from ner.functions.iniatlize import init_maxtrix_weight 17 | from ner.functions.iniatlize import init_vector 18 | 19 | class Gaz_Embed(torch.nn.Module): 20 | def __init__(self, data, type=1): 21 | """ 22 | Args: 23 | data: all the data information 24 | type: the type of strategy, 1 for avg, 2 for short first, 3 for long first 25 | """ 26 | print('build gaz embedding...') 27 | 28 | super(Gaz_Embed, self).__init__() 29 | 30 | self.gpu = data.HP_gpu 31 | self.data = data 32 | self.type = type 33 | self.gaz_dim = data.gaz_emb_dim 34 | self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(), data.gaz_emb_dim) 35 | self.dropout = nn.Dropout(p=0.5) 36 | 37 | if data.pretrain_gaz_embedding is not None: 38 | self.gaz_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_gaz_embedding)) 39 | else: 40 | self.gaz_embedding.weight.data.copy_( 41 | torch.from_numpy(random_embedding(data.gaz_alphabet.size(), data.gaz_emb_dim)) 42 | ) 43 | 44 | self.filters = [[1, 20], [2, 30]] 45 | if self.type == 4: 46 | # use conv, so we need to define some conv 47 | # here we use 20 1-d conv, and 30 2-d conv 48 | self.build_cnn(self.filters) 49 | 50 | ## also use highway, 2 layers highway 51 | # self.highway = Highway(self.gaz_dim, num_layers=2) 52 | # if self.gpu: 53 | # self.highway = self.highway.cuda() 54 | 55 | if self.type == 5: 56 | # use self-attention 57 | self.build_attention() 58 | 59 | if self.gpu: 60 | self.gaz_embedding = self.gaz_embedding.cuda() 61 | 62 | def build_cnn(self, filters): 63 | """ 64 | build cnn for convolution the gaz embeddings 65 | Args: 66 | filters: the filters definetion 67 | """ 68 | for filter in filters: 69 | k_size = filter[0] 70 | channel_size = filter[1] 71 | 72 | conv = torch.nn.Conv1d( 73 | in_channels=self.gaz_dim, 74 | out_channels=channel_size, 75 | kernel_size=k_size, 76 | bias=True 77 | ) 78 | 79 | if self.gpu: 80 | conv = conv.cuda() 81 | 82 | init_cnn_weight(conv) 83 | self.add_module('conv_{}d'.format(k_size), conv) 84 | del conv 85 | 86 | def build_attention(self): 87 | """ 88 | build a self-attention to weight add the values 89 | Args: 90 | max_gaz_num: the max gaz number 91 | """ 92 | w1 = torch.zeros(self.gaz_dim, self.gaz_dim) 93 | w2 = torch.zeros(self.gaz_dim) 94 | 95 | if self.gpu: 96 | w1 = w1.cuda() 97 | w2 = w2.cuda() 98 | init_maxtrix_weight(w1) 99 | init_vector(w2) 100 | 101 | self.W1 = nn.Parameter(w1) 102 | self.W2 = nn.Parameter(w2) 103 | 104 | def forward(self, inputs): 105 | """ 106 | the inputs is a tuple, include gaz_seq_tensor, gaz_seq_length, gaz_mask_tensor 107 | Args: 108 | gaz_seq_tensor: [batch_size, seq_len, gaz_num] 109 | gaz_seq_length: [batch_size, seq_len] 110 | gaz_mask_tensor: [batch_size, seq_len, gaz_num] 111 | """ 112 | gaz_seq_tensor, gaz_seq_lengths, gaz_mask_tensor = inputs 113 | batch_size = gaz_seq_tensor.size(0) 114 | seq_len = gaz_seq_tensor.size(1) 115 | gaz_num = gaz_seq_tensor.size(2) 116 | 117 | # type = 1, short first; type = 2, long first; type = 3, avg; type = 4, cnn 118 | if self.type == 1: 119 | # short first 120 | gaz_ids = gaz_seq_tensor[:, :, 0] 121 | gaz_ids = gaz_ids.view(batch_size, seq_len, -1) 122 | gaz_ids = 
torch.squeeze(gaz_ids, dim=-1) 123 | if self.gpu: 124 | gaz_ids = gaz_ids.cuda() 125 | 126 | gaz_embs = self.gaz_embedding(gaz_ids) 127 | 128 | return gaz_embs 129 | 130 | elif self.type == 2: 131 | # long first 132 | select_ids = gaz_seq_lengths 133 | select_ids = select_ids.view(batch_size, seq_len, -1) 134 | 135 | # the max index = len - 1 136 | select_ids = select_ids - 1 137 | # print(select_ids[0]) 138 | gaz_ids = torch.gather(gaz_seq_tensor, dim=2, index=select_ids) 139 | gaz_ids = gaz_ids.view(batch_size, seq_len, -1) 140 | gaz_ids = torch.squeeze(gaz_ids, dim=-1) 141 | if self.gpu: 142 | gaz_ids = gaz_ids.cuda() 143 | 144 | gaz_embs = self.gaz_embedding(gaz_ids) 145 | 146 | return gaz_embs 147 | 148 | elif self.type == 3: 149 | ## avg first 150 | # [batch_size, seq_len, gaz_num, gaz_dim] 151 | if self.gpu: 152 | gaz_seq_tensor = gaz_seq_tensor.cuda() 153 | 154 | gaz_embs = self.gaz_embedding(gaz_seq_tensor) 155 | 156 | # use mask to do select sum, mask: [batch_size, seq_len, gaz_num] 157 | pad_mask = gaz_mask_tensor.view(batch_size, seq_len, gaz_num, -1).float() 158 | 159 | # padding embedding transform to 0 160 | gaz_embs = gaz_embs * pad_mask 161 | 162 | # do sum at the gaz_num axis, result is [batch_size, seq_len, gaz_dim] 163 | gaz_embs = torch.sum(gaz_embs, dim=2) 164 | gaz_seq_lengths = gaz_seq_lengths.view(batch_size, seq_len, -1).float() 165 | gaz_embs = gaz_embs / gaz_seq_lengths 166 | 167 | return gaz_embs 168 | 169 | elif self.type == 4: 170 | ## use convolution 171 | # first get all the gaz embedding representation 172 | # [batch_size, seq_len, gaz_num, gaz_dim] 173 | input_embs = self.gaz_embedding(gaz_seq_tensor) 174 | 175 | # transform to [batch_size * seq_len, gaz_num, gaz_dim] to use Conv1d 176 | input_embs = input_embs.view(-1, gaz_num, self.gaz_dim) 177 | input_embs = torch.transpose(input_embs, 2, 1) 178 | input_embs = self.dropout(input_embs) 179 | 180 | gaz_embs = [] 181 | 182 | for filter in self.filters: 183 | k_size = filter[0] 184 | 185 | conv = getattr(self, 'conv_{}d'.format(k_size)) 186 | convolved = conv(input_embs) 187 | 188 | # convolved is [batch_size * seq_len, channel, width-k_size+1] 189 | # do active and max-pool, [batch_size * seq_len, channel] 190 | convolved, _ = torch.max(convolved, dim=-1) 191 | if True: 192 | convolved = F.tanh(convolved) 193 | 194 | gaz_embs.append(convolved) 195 | 196 | # transpose to [batch_size, seq_len, gaz_dim] 197 | gaz_embs = torch.cat(gaz_embs, dim=-1) 198 | 199 | # gaz_embs = self.dropout(gaz_embs) 200 | # gaz_embs = self.highway(gaz_embs) 201 | gaz_embs = gaz_embs.view(batch_size, seq_len, -1) 202 | 203 | return gaz_embs 204 | 205 | elif self.type == 5: 206 | ## self attention 207 | gaz_embs = self.gaz_embedding(gaz_seq_tensor) 208 | # print('origin: ', gaz_embs[0][0][:20]) 209 | input_embs = gaz_embs.view(-1, gaz_num, self.gaz_dim) 210 | input_embs = torch.transpose(input_embs, 2, 1) 211 | 212 | ### step1, cal alpha 213 | # cal the alpha, result is [batch * seq_len, d1, gaz_num] 214 | alpha = torch.matmul(self.W1, input_embs) 215 | alpha = F.tanh(alpha) 216 | 217 | # result is [batch * seq_len, gaz_num] 218 | alpha = torch.transpose(alpha, 2, 1) 219 | weight2 = self.W2 220 | alpha = torch.matmul(alpha, weight2) 221 | 222 | # before softmax, we need to mask 223 | # alpha = alpha * gaz_mask_tensor.contiguous().view(-1, gaz_num).float() 224 | zero_mask = (1 - gaz_mask_tensor.float()) * (-2**31 + 1) 225 | zero_mask = zero_mask.contiguous().view(-1, gaz_num) 226 | alpha = alpha + zero_mask 227 | 228 | ### step2 do 
softmax, 229 | # [batch * seq_len, gaz_num] 230 | # alpha = torch.exp(alpha) 231 | # total_alpha = torch.sum(alpha, dim=-1, keepdim=True) 232 | # alpha = torch.div(alpha, total_alpha) 233 | alpha = F.softmax(alpha, dim=-1) 234 | alpha = alpha.view(batch_size, seq_len, gaz_num, -1) 235 | 236 | ### step3, weighted add, [batch_size, seq_len, gaz_num, gaz_dim] 237 | gaz_embs = gaz_embs * alpha 238 | gaz_embs = torch.sum(gaz_embs, dim=2) 239 | 240 | return gaz_embs 241 | 242 | 243 | 244 | 245 | -------------------------------------------------------------------------------- /ner/modules/highway.py: -------------------------------------------------------------------------------- 1 | __author__ = 'liuwei' 2 | 3 | """ 4 | implements of the highway net 5 | include two gate: transform gate(G_T) and carry gate(G_C) 6 | H = w_h * x + b_h 7 | G_T = sigmoid(w_t * x + b_t) 8 | G_C = sigmoid(w_c * x + b_c) 9 | outputs: 10 | outputs = G_T * H + G_C * x 11 | 12 | for sample: 13 | G_C = (1 - G_T), then: 14 | outputs = G_T * H + (1 - G_T) * x 15 | and generally set b_c = -1 or -3, that mean set b_t = 1 or 3 16 | """ 17 | 18 | import torch 19 | import torch.nn as nn 20 | import numpy as np 21 | 22 | from ner.functions.iniatlize import init_highway 23 | 24 | class Highway(nn.Module): 25 | def __init__(self, input_dim, num_layers=1, activation=nn.functional.relu, 26 | require_grad=True): 27 | """ 28 | Args: 29 | input_dim: the dim 30 | num_layers: the numer of highway layers 31 | activation: activation function, tanh or relu 32 | """ 33 | super(Highway, self).__init__() 34 | 35 | self._input_dim = input_dim 36 | self._num_layers = num_layers 37 | 38 | # output is input_dim * 2, because one is candidate status, and another 39 | # is transform gate 40 | self._layers = torch.nn.ModuleList( 41 | [nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)] 42 | ) 43 | self._activation = activation 44 | i = 0 45 | for layer in self._layers: 46 | layer.weight.requires_grad = require_grad 47 | layer.bias.requires_grad = require_grad 48 | init_highway(layer, 100) 49 | layer.bias[input_dim:].data.fill_(1) 50 | 51 | i += 1 52 | 53 | 54 | def forward(self, inputs): 55 | """ 56 | Args: 57 | inputs: a tensor, size is [batch_size, n_tokens, input_dim] 58 | """ 59 | current_input = inputs 60 | for layer in self._layers: 61 | proj_inputs = layer(current_input) 62 | linear_part = current_input 63 | 64 | del current_input 65 | 66 | # here the gate is carry gate, if you change it to transform gate 67 | # the bias init should change too, maybe -1 or -3 even 68 | nonlinear_part, carry_gate = proj_inputs.chunk(2, dim=-1) 69 | nonlinear_part = self._activation(nonlinear_part) 70 | carry_gate = torch.nn.functional.sigmoid(carry_gate) 71 | current_input = (1 - carry_gate) * nonlinear_part + carry_gate * linear_part 72 | 73 | return current_input 74 | 75 | -------------------------------------------------------------------------------- /ner/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__init__.py -------------------------------------------------------------------------------- /ner/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/__init__.cpython-36.pyc 
-------------------------------------------------------------------------------- /ner/utils/__pycache__/alphabet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/alphabet.cpython-36.pyc -------------------------------------------------------------------------------- /ner/utils/__pycache__/data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/data.cpython-36.pyc -------------------------------------------------------------------------------- /ner/utils/__pycache__/functions.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/functions.cpython-36.pyc -------------------------------------------------------------------------------- /ner/utils/__pycache__/gazetteer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/gazetteer.cpython-36.pyc -------------------------------------------------------------------------------- /ner/utils/__pycache__/metric.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/metric.cpython-36.pyc -------------------------------------------------------------------------------- /ner/utils/__pycache__/msra_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/msra_data.cpython-36.pyc -------------------------------------------------------------------------------- /ner/utils/__pycache__/trie.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/trie.cpython-36.pyc -------------------------------------------------------------------------------- /ner/utils/__pycache__/weibo_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/weibo_data.cpython-36.pyc -------------------------------------------------------------------------------- /ner/utils/alphabet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Alphabet maps objects to integer ids. It provides two way mapping from the index to the objects. 5 | """ 6 | import json 7 | import os 8 | 9 | 10 | class Alphabet: 11 | def __init__(self, name, label=False, keep_growing=True): 12 | self.__name = name 13 | self.UNKNOWN = "" 14 | self.label = label 15 | self.instance2index = {} 16 | self.instances = [] 17 | self.keep_growing = keep_growing 18 | 19 | # Index 0 is occupied by default, all else following. 
20 | self.default_index = 0 21 | self.next_index = 1 22 | if not self.label: 23 | self.add(self.UNKNOWN) 24 | 25 | def clear(self, keep_growing=True): 26 | self.instance2index = {} 27 | self.instances = [] 28 | self.keep_growing = keep_growing 29 | 30 | # Index 0 is occupied by default, all else following. 31 | self.default_index = 0 32 | self.next_index = 1 33 | 34 | def add(self, instance): 35 | if instance not in self.instance2index: 36 | self.instances.append(instance) 37 | self.instance2index[instance] = self.next_index 38 | self.next_index += 1 39 | 40 | def get_index(self, instance): 41 | try: 42 | return self.instance2index[instance] 43 | except KeyError: 44 | if self.keep_growing: 45 | index = self.next_index 46 | self.add(instance) 47 | return index 48 | else: 49 | return self.instance2index[self.UNKNOWN] 50 | 51 | def get_instance(self, index): 52 | if index == 0: 53 | # First index is occupied by the wildcard element. 54 | return None 55 | try: 56 | return self.instances[index - 1] 57 | except IndexError: 58 | print('WARNING:Alphabet get_instance ,unknown instance, return the first label.') 59 | return self.instances[0] 60 | 61 | def size(self): 62 | # if self.label: 63 | # return len(self.instances) 64 | # else: 65 | return len(self.instances) + 1 66 | 67 | def iteritems(self): 68 | return self.instance2index.items() 69 | 70 | def enumerate_items(self, start=1): 71 | if start < 1 or start >= self.size(): 72 | raise IndexError("Enumerate is allowed between [1 : size of the alphabet)") 73 | return zip(range(start, len(self.instances) + 1), self.instances[start - 1:]) 74 | 75 | def close(self): 76 | self.keep_growing = False 77 | 78 | def open(self): 79 | self.keep_growing = True 80 | 81 | def get_content(self): 82 | return {'instance2index': self.instance2index, 'instances': self.instances} 83 | 84 | def from_json(self, data): 85 | self.instances = data["instances"] 86 | self.instance2index = data["instance2index"] 87 | 88 | def save(self, output_directory, name=None): 89 | """ 90 | Save both alhpabet records to the given directory. 91 | :param output_directory: Directory to save model and weights. 92 | :param name: The alphabet saving name, optional. 93 | :return: 94 | """ 95 | saving_name = name if name else self.__name 96 | try: 97 | json.dump(self.get_content(), open(os.path.join(output_directory, saving_name + ".json"), 'w')) 98 | except Exception as e: 99 | print("Exception: Alphabet is not saved: " % repr(e)) 100 | 101 | def load(self, input_directory, name=None): 102 | """ 103 | Load model architecture and weights from the give directory. This allow we use old models even the structure 104 | changes. 
105 | :param input_directory: Directory to save model and weights 106 | :return: 107 | """ 108 | loading_name = name if name else self.__name 109 | self.from_json(json.load(open(os.path.join(input_directory, loading_name + ".json")))) 110 | -------------------------------------------------------------------------------- /ner/utils/functions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | import numpy as np 5 | from ner.utils.alphabet import Alphabet 6 | NULLKEY = "-null-" 7 | def normalize_word(word): 8 | new_word = "" 9 | for char in word: 10 | if char.isdigit(): 11 | new_word += '0' 12 | else: 13 | new_word += char 14 | return new_word 15 | 16 | def read_word_instance(input_file, word_alphabet, label_alphabet, number_normalized, max_sent_lengths): 17 | """ 18 | read only word msg, no char. 19 | note that here word is chinese character! 20 | """ 21 | instence_texts = [] 22 | instence_ids = [] 23 | entity_num = 0 24 | # max_sent_lengths = 1000 25 | with open(input_file, 'r', errors='ignore') as f: 26 | in_lines = f.readlines() 27 | words = [] 28 | labels = [] 29 | word_ids = [] 30 | label_ids = [] 31 | 32 | for line in in_lines: 33 | if len(line) > 2: 34 | # len less than 2 mean a blank line 35 | pairs = line.strip().split() 36 | word = pairs[0] 37 | if number_normalized: 38 | word = normalize_word(word) 39 | label = pairs[-1] 40 | if "B-" in label or "S-" in label: 41 | entity_num += 1 42 | words.append(word) 43 | labels.append(label) 44 | word_ids.append(word_alphabet.get_index(word)) 45 | label_ids.append(label_alphabet.get_index(label)) 46 | else: 47 | if (max_sent_lengths < 0) or (len(words) < max_sent_lengths): 48 | instence_texts.append([words, labels]) 49 | instence_ids.append([word_ids, label_ids]) 50 | # else: 51 | # print(len(words)) 52 | # print('so long!!!') 53 | words = [] 54 | labels = [] 55 | word_ids = [] 56 | label_ids = [] 57 | # if len(words) > 0: 58 | # instence_texts.append([words, labels]) 59 | # instence_ids.append([word_ids, label_ids]) 60 | 61 | # print("entity num: ", entity_num) 62 | return instence_texts, instence_ids 63 | 64 | 65 | def read_instance(input_file, word_alphabet, char_alphabet, label_alphabet, number_normalized,max_sent_length, char_padding_size=-1, char_padding_symbol = ''): 66 | """ 67 | read the word and char msg, all read into instance_texts and instance_ids 68 | 69 | note that, in the file, every line is a single chinese character and its tag, 70 | but when read into instace_texts and instance_ids, every line is a sentence character, 71 | and sentence character ids 72 | """ 73 | in_lines = open(input_file,'r').readlines() 74 | instence_texts = [] 75 | instence_Ids = [] 76 | words = [] 77 | chars = [] 78 | labels = [] 79 | word_Ids = [] 80 | char_Ids = [] 81 | label_Ids = [] 82 | for line in in_lines: 83 | if len(line) > 2: 84 | pairs = line.strip().split() 85 | word = pairs[0] 86 | if number_normalized: 87 | word = normalize_word(word) 88 | label = pairs[-1] 89 | words.append(word) 90 | labels.append(label) 91 | word_Ids.append(word_alphabet.get_index(word)) 92 | label_Ids.append(label_alphabet.get_index(label)) 93 | char_list = [] 94 | char_Id = [] 95 | for char in word: 96 | char_list.append(char) 97 | if char_padding_size > 0: 98 | char_number = len(char_list) 99 | if char_number < char_padding_size: 100 | char_list = char_list + [char_padding_symbol]*(char_padding_size-char_number) 101 | assert(len(char_list) == char_padding_size) 102 | else: 103 | ### 
not padding 104 | pass 105 | for char in char_list: 106 | char_Id.append(char_alphabet.get_index(char)) 107 | chars.append(char_list) 108 | char_Ids.append(char_Id) 109 | else: 110 | if (max_sent_length < 0) or (len(words) < max_sent_length): 111 | instence_texts.append([words, chars, labels]) 112 | instence_Ids.append([word_Ids, char_Ids,label_Ids]) 113 | words = [] 114 | chars = [] 115 | labels = [] 116 | word_Ids = [] 117 | char_Ids = [] 118 | label_Ids = [] 119 | return instence_texts, instence_Ids 120 | 121 | 122 | def read_seg_instance(input_file, word_alphabet, biword_alphabet, char_alphabet, label_alphabet, number_normalized, max_sent_length, char_padding_size=-1, char_padding_symbol = ''): 123 | """ 124 | read the word, biword and char msg, all read into instance_texts and instance_ids 125 | 126 | note that, in the file, every line is a single chinese character and its tag, 127 | but when read into instace_texts and instance_ids, every line is a sentence character, 128 | and sentence character ids 129 | """ 130 | in_lines = open(input_file,'r').readlines() 131 | instence_texts = [] 132 | instence_Ids = [] 133 | words = [] 134 | biwords = [] 135 | chars = [] 136 | labels = [] 137 | word_Ids = [] 138 | biword_Ids = [] 139 | char_Ids = [] 140 | label_Ids = [] 141 | for idx in range(len(in_lines)): 142 | line = in_lines[idx] 143 | if len(line) > 2: 144 | pairs = line.strip().split() 145 | word = pairs[0] 146 | if number_normalized: 147 | word = normalize_word(word) 148 | label = pairs[-1] 149 | words.append(word) 150 | if idx < len(in_lines) -1 and len(in_lines[idx+1]) > 2: 151 | biword = word + in_lines[idx+1].strip().split()[0] 152 | else: 153 | biword = word + NULLKEY 154 | biwords.append(biword) 155 | labels.append(label) 156 | word_Ids.append(word_alphabet.get_index(word)) 157 | biword_Ids.append(biword_alphabet.get_index(biword)) 158 | label_Ids.append(label_alphabet.get_index(label)) 159 | char_list = [] 160 | char_Id = [] 161 | for char in word: 162 | char_list.append(char) 163 | if char_padding_size > 0: 164 | char_number = len(char_list) 165 | if char_number < char_padding_size: 166 | char_list = char_list + [char_padding_symbol]*(char_padding_size-char_number) 167 | assert(len(char_list) == char_padding_size) 168 | else: 169 | ### not padding 170 | pass 171 | for char in char_list: 172 | char_Id.append(char_alphabet.get_index(char)) 173 | chars.append(char_list) 174 | char_Ids.append(char_Id) 175 | else: 176 | if (max_sent_length < 0) or (len(words) < max_sent_length): 177 | instence_texts.append([words, biwords, chars, labels]) 178 | instence_Ids.append([word_Ids, biword_Ids, char_Ids,label_Ids]) 179 | words = [] 180 | biwords = [] 181 | chars = [] 182 | labels = [] 183 | word_Ids = [] 184 | biword_Ids = [] 185 | char_Ids = [] 186 | label_Ids = [] 187 | return instence_texts, instence_Ids 188 | 189 | def read_instance_with_gaz_no_char(input_file, gaz, word_alphabet, biword_alphabet, gaz_alphabet, label_alphabet, number_normalized, max_sent_length, use_single): 190 | """ 191 | read instance with, word, biword, gaz, lable, no char 192 | Args: 193 | input_file: the input file path 194 | gaz: the gaz obj 195 | word_alphabet: word 196 | biword_alphabet: biword 197 | gaz_alphabet: gaz 198 | label_alphabet: label 199 | number_normalized: true or false 200 | max_sent_length: the max length 201 | """ 202 | in_lines = open(input_file, 'r', encoding="utf-8").readlines() 203 | instence_texts = [] 204 | instence_Ids = [] 205 | words = [] 206 | biwords = [] 207 | labels = [] 208 | 
word_Ids = [] 209 | biword_Ids = [] 210 | label_Ids = [] 211 | for idx in range(len(in_lines)): 212 | line = in_lines[idx] 213 | if len(line) > 2: 214 | pairs = line.strip().split() 215 | word = pairs[0] 216 | if number_normalized: 217 | word = normalize_word(word) 218 | label = pairs[-1] 219 | if idx < len(in_lines) - 1 and len(in_lines[idx+1]) > 2: 220 | biword = word + in_lines[idx+1].strip().split()[0] 221 | else: 222 | biword = word + NULLKEY 223 | biwords.append(biword) 224 | words.append(word) 225 | labels.append(label) 226 | word_Ids.append(word_alphabet.get_index(word)) 227 | biword_Ids.append(biword_alphabet.get_index(biword)) 228 | label_Ids.append(label_alphabet.get_index(label)) 229 | else: 230 | if ((max_sent_length < 0) or (len(words) < max_sent_length)) and (len(words) > 0): 231 | gazs = [] 232 | gaz_Ids = [] 233 | gazs_length = [] 234 | w_length = len(words) 235 | 236 | reverse_gazs = [[] for i in range(w_length)] 237 | reverse_gaz_Ids = [[] for i in range(w_length)] 238 | flag = [0 for f in range(w_length)] 239 | # assign sub-sequence to every chinese letter 240 | for i in range(w_length): 241 | matched_list = gaz.enumerateMatchList(words[i:]) 242 | 243 | if use_single and len(matched_list) > 0: 244 | f_len = len(matched_list[0]) 245 | 246 | if (flag[i] == 1 or len(matched_list) > 1) and len(matched_list[-1]) == 1: 247 | matched_list = matched_list[:-1] 248 | 249 | for f_pos in range(i, i+f_len): 250 | flag[f_pos] = 1 251 | 252 | matched_length = [len(a) for a in matched_list] 253 | 254 | gazs.append(matched_list) 255 | matched_Id = [gaz_alphabet.get_index(entity) for entity in matched_list] 256 | if matched_Id: 257 | # gaz_Ids.append([matched_Id, matched_length]) 258 | gaz_Ids.append(matched_Id) 259 | gazs_length.append(matched_length) 260 | else: 261 | gaz_Ids.append([]) 262 | gazs_length.append([]) 263 | 264 | for i in range(w_length-1, -1, -1): 265 | now_pos_gaz = gazs[i] 266 | now_pos_gaz_Id = gaz_Ids[i] 267 | now_pos_gaz_len = gazs_length[i] 268 | 269 | ## Traversing it 270 | l = len(now_pos_gaz) 271 | assert len(now_pos_gaz) == len(now_pos_gaz_Id) 272 | for j in range(l): 273 | width = now_pos_gaz_len[j] 274 | end_char_pos = i + width - 1 275 | 276 | reverse_gazs[end_char_pos].append(now_pos_gaz[j]) 277 | reverse_gaz_Ids[end_char_pos].append(now_pos_gaz_Id[j]) 278 | 279 | 280 | instence_texts.append([words, biwords, gazs, reverse_gazs, labels]) 281 | instence_Ids.append([word_Ids, biword_Ids, gaz_Ids, reverse_gaz_Ids, label_Ids]) 282 | words = [] 283 | biwords = [] 284 | labels = [] 285 | word_Ids = [] 286 | biword_Ids = [] 287 | label_Ids = [] 288 | gazs = [] 289 | reverse_gazs = [] 290 | gaz_Ids = [] 291 | reverse_gaz_Ids = [] 292 | return instence_texts, instence_Ids 293 | 294 | def read_instance_with_gaz(input_file, gaz, word_alphabet, biword_alphabet, char_alphabet, gaz_alphabet, label_alphabet, number_normalized, max_sent_length, char_padding_size=-1, char_padding_symbol = ''): 295 | in_lines = open(input_file,'r').readlines() 296 | instence_texts = [] 297 | instence_Ids = [] 298 | words = [] 299 | biwords = [] 300 | chars = [] 301 | labels = [] 302 | word_Ids = [] 303 | biword_Ids = [] 304 | char_Ids = [] 305 | label_Ids = [] 306 | for idx in range(len(in_lines)): 307 | line = in_lines[idx] 308 | if len(line) > 2: 309 | pairs = line.strip().split() 310 | word = pairs[0] 311 | if number_normalized: 312 | word = normalize_word(word) 313 | label = pairs[-1] 314 | if idx < len(in_lines) -1 and len(in_lines[idx+1]) > 2: 315 | biword = word + 
in_lines[idx+1].strip().split()[0] 316 | else: 317 | biword = word + NULLKEY 318 | biwords.append(biword) 319 | words.append(word) 320 | labels.append(label) 321 | word_Ids.append(word_alphabet.get_index(word)) 322 | biword_Ids.append(biword_alphabet.get_index(biword)) 323 | label_Ids.append(label_alphabet.get_index(label)) 324 | char_list = [] 325 | char_Id = [] 326 | for char in word: 327 | char_list.append(char) 328 | if char_padding_size > 0: 329 | char_number = len(char_list) 330 | if char_number < char_padding_size: 331 | char_list = char_list + [char_padding_symbol]*(char_padding_size-char_number) 332 | assert(len(char_list) == char_padding_size) 333 | else: 334 | ### not padding 335 | pass 336 | for char in char_list: 337 | char_Id.append(char_alphabet.get_index(char)) 338 | chars.append(char_list) 339 | char_Ids.append(char_Id) 340 | 341 | else: 342 | if ((max_sent_length < 0) or (len(words) < max_sent_length)) and (len(words)>0): 343 | gazs = [] 344 | gaz_Ids = [] 345 | w_length = len(words) 346 | # print sentence 347 | # for w in words: 348 | # print w," ", 349 | # print 350 | for idx in range(w_length): 351 | matched_list = gaz.enumerateMatchList(words[idx:]) 352 | matched_length = [len(a) for a in matched_list] 353 | # print idx,"----------" 354 | # print "forward...feed:","".join(words[idx:]) 355 | # for a in matched_list: 356 | # print a,len(a)," ", 357 | # print 358 | 359 | # print matched_length 360 | 361 | gazs.append(matched_list) 362 | matched_Id = [gaz_alphabet.get_index(entity) for entity in matched_list] 363 | if matched_Id: 364 | gaz_Ids.append([matched_Id, matched_length]) 365 | else: 366 | gaz_Ids.append([]) 367 | 368 | instence_texts.append([words, biwords, chars, gazs, labels]) 369 | instence_Ids.append([word_Ids, biword_Ids, char_Ids, gaz_Ids, label_Ids]) 370 | words = [] 371 | biwords = [] 372 | chars = [] 373 | labels = [] 374 | word_Ids = [] 375 | biword_Ids = [] 376 | char_Ids = [] 377 | label_Ids = [] 378 | gazs = [] 379 | gaz_Ids = [] 380 | return instence_texts, instence_Ids 381 | 382 | 383 | def read_instance_with_gaz_in_sentence(input_file, gaz, word_alphabet, biword_alphabet, char_alphabet, gaz_alphabet, label_alphabet, number_normalized, max_sent_length, char_padding_size=-1, char_padding_symbol = ''): 384 | in_lines = open(input_file,'r').readlines() 385 | instence_texts = [] 386 | instence_Ids = [] 387 | for idx in range(len(in_lines)): 388 | pair = in_lines[idx].strip() 389 | orig_words = list(pair[0]) 390 | 391 | if (max_sent_length > 0) and (len(orig_words) > max_sent_length): 392 | continue 393 | biwords = [] 394 | biword_Ids = [] 395 | if number_normalized: 396 | words = [] 397 | for word in orig_words: 398 | word = normalize_word(word) 399 | words.append(word) 400 | else: 401 | words = orig_words 402 | word_num = len(words) 403 | for idy in range(word_num): 404 | if idy < word_num - 1: 405 | biword = words[idy]+words[idy+1] 406 | else: 407 | biword = words[idy]+NULLKEY 408 | biwords.append(biword) 409 | biword_Ids.append(biword_alphabet.get_index(biword)) 410 | word_Ids = [word_alphabet.get_index(word) for word in words] 411 | label = pair[-1] 412 | label_Id = label_alphabet.get_index(label) 413 | gazs = [] 414 | gaz_Ids = [] 415 | word_num = len(words) 416 | chars = [[word] for word in words] 417 | char_Ids = [[char_alphabet.get_index(word)] for word in words] 418 | ## print sentence 419 | # for w in words: 420 | # print w," ", 421 | # print 422 | for idx in range(word_num): 423 | matched_list = gaz.enumerateMatchList(words[idx:]) 424 | 
matched_length = [len(a) for a in matched_list] 425 | # print idx,"----------" 426 | # print "forward...feed:","".join(words[idx:]) 427 | # for a in matched_list: 428 | # print a,len(a)," ", 429 | # print 430 | # print matched_length 431 | gazs.append(matched_list) 432 | matched_Id = [gaz_alphabet.get_index(entity) for entity in matched_list] 433 | if matched_Id: 434 | gaz_Ids.append([matched_Id, matched_length]) 435 | else: 436 | gaz_Ids.append([]) 437 | instence_texts.append([words, biwords, chars, gazs, label]) 438 | instence_Ids.append([word_Ids, biword_Ids, char_Ids, gaz_Ids, label_Id]) 439 | return instence_texts, instence_Ids 440 | 441 | 442 | def build_pretrain_embedding(embedding_path, word_alphabet, embedd_dim=100, norm=True): 443 | embedd_dict = dict() 444 | if embedding_path != None: 445 | embedd_dict, embedd_dim = load_pretrain_emb(embedding_path) 446 | scale = np.sqrt(3.0 / embedd_dim) 447 | pretrain_emb = np.empty([word_alphabet.size(), embedd_dim]) 448 | perfect_match = 0 449 | case_match = 0 450 | not_match = 0 451 | 452 | ## we should also init the index 0 453 | pretrain_emb[0, :] = np.random.uniform(-scale, scale, [1, embedd_dim]) 454 | 455 | for word, index in word_alphabet.iteritems(): 456 | if word in embedd_dict: 457 | if norm: 458 | pretrain_emb[index,:] = norm2one(embedd_dict[word]) 459 | else: 460 | pretrain_emb[index,:] = embedd_dict[word] 461 | perfect_match += 1 462 | elif word.lower() in embedd_dict: 463 | if norm: 464 | pretrain_emb[index,:] = norm2one(embedd_dict[word.lower()]) 465 | else: 466 | pretrain_emb[index,:] = embedd_dict[word.lower()] 467 | case_match += 1 468 | else: 469 | pretrain_emb[index,:] = np.random.uniform(-scale, scale, [1, embedd_dim]) 470 | not_match += 1 471 | pretrained_size = len(embedd_dict) 472 | print("Embedding:\n pretrain word:%s, prefect match:%s, case_match:%s, oov:%s, oov%%:%s"%(pretrained_size, perfect_match, case_match, not_match, (not_match+0.)/word_alphabet.size())) 473 | return pretrain_emb, embedd_dim 474 | 475 | 476 | 477 | def norm2one(vec): 478 | root_sum_square = np.sqrt(np.sum(np.square(vec))) 479 | return vec/root_sum_square 480 | 481 | def load_pretrain_emb(embedding_path): 482 | embedd_dim = -1 483 | embedd_dict = dict() 484 | with open(embedding_path, 'r', encoding="utf-8") as file: 485 | for line in file: 486 | line = line.strip() 487 | if len(line) == 0: 488 | continue 489 | tokens = line.split() 490 | if embedd_dim < 0: 491 | embedd_dim = len(tokens) - 1 492 | else: 493 | assert (embedd_dim + 1 == len(tokens)) 494 | embedd = np.empty([1, embedd_dim]) 495 | embedd[:] = tokens[1:] 496 | embedd_dict[tokens[0]] = embedd 497 | return embedd_dict, embedd_dim 498 | 499 | if __name__ == '__main__': 500 | a = np.arange(9.0) 501 | print(a) 502 | print(norm2one(a)) 503 | -------------------------------------------------------------------------------- /ner/utils/gazetteer.py: -------------------------------------------------------------------------------- 1 | from ner.utils.trie import Trie 2 | 3 | class Gazetteer: 4 | def __init__(self, lower, use_single): 5 | self.trie = Trie(use_single) 6 | self.ent2type = {} ## word list to type 7 | self.ent2id = {"":0} ## word list to id 8 | self.lower = lower 9 | self.space = "" 10 | 11 | def enumerateMatchList(self, word_list): 12 | if self.lower: 13 | word_list = [word.lower() for word in word_list] 14 | match_list = self.trie.enumerateMatch(word_list, self.space) 15 | return match_list 16 | 17 | def insert(self, word_list, source): 18 | """ 19 | Args: 20 | word_list: this is 
the letter list of a possible word, 21 | source: source is the word ? 22 | """ 23 | if self.lower: 24 | word_list = [word.lower() for word in word_list] 25 | self.trie.insert(word_list) 26 | string = self.space.join(word_list) 27 | if string not in self.ent2type: 28 | self.ent2type[string] = source 29 | if string not in self.ent2id: 30 | self.ent2id[string] = len(self.ent2id) 31 | 32 | def searchId(self, word_list): 33 | if self.lower: 34 | word_list = [word.lower() for word in word_list] 35 | string = self.space.join(word_list) 36 | if string in self.ent2id: 37 | return self.ent2id[string] 38 | return self.ent2id[""] 39 | 40 | def searchType(self, word_list): 41 | if self.lower: 42 | word_list = [word.lower() for word in word_list] 43 | string = self.space.join(word_list) 44 | if string in self.ent2type: 45 | return self.ent2type[string] 46 | print("Error in finding entity type at gazetteer.py, exit program! String:", string) 47 | exit(0) 48 | 49 | def size(self): 50 | return len(self.ent2type) 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /ner/utils/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | # from operator import add 5 | # 6 | import numpy as np 7 | import math 8 | import sys 9 | import os 10 | 11 | 12 | 13 | ## input as sentence level labels 14 | def get_ner_fmeasure(golden_lists, predict_lists, label_type="BMES"): 15 | sent_num = len(golden_lists) 16 | golden_full = [] 17 | predict_full = [] 18 | right_full = [] 19 | right_tag = 0 20 | all_tag = 0 21 | for idx in range(0, sent_num): 22 | # word_list = sentence_lists[idx] 23 | golden_list = golden_lists[idx] 24 | predict_list = predict_lists[idx] 25 | for idy in range(len(golden_list)): 26 | if golden_list[idy] == predict_list[idy]: 27 | right_tag += 1 28 | all_tag += len(golden_list) 29 | if label_type == "BMES": 30 | gold_matrix = get_ner_BMES(golden_list) 31 | pred_matrix = get_ner_BMES(predict_list) 32 | else: 33 | gold_matrix = get_ner_BIO(golden_list) 34 | pred_matrix = get_ner_BIO(predict_list) 35 | # print "gold", gold_matrix 36 | # print "pred", pred_matrix 37 | right_ner = list(set(gold_matrix).intersection(set(pred_matrix))) 38 | golden_full += gold_matrix 39 | predict_full += pred_matrix 40 | right_full += right_ner 41 | right_num = len(right_full) 42 | golden_num = len(golden_full) 43 | predict_num = len(predict_full) 44 | 45 | if predict_num == 0: 46 | precision = -1 47 | else: 48 | precision = (right_num+0.0)/predict_num 49 | 50 | if golden_num == 0: 51 | recall = -1 52 | else: 53 | recall = (right_num+0.0)/golden_num 54 | if (precision == -1) or (recall == -1) or (precision + recall) <= 0.: 55 | f_measure = -1 56 | else: 57 | f_measure = 2 * precision * recall / (precision + recall) 58 | accuracy = (right_tag+0.0)/all_tag 59 | # print "Accuracy: ", right_tag,"/",all_tag,"=",accuracy 60 | print("gold_num = ", golden_num, " pred_num = ", predict_num, " right_num = ", right_num) 61 | return accuracy, precision, recall, f_measure 62 | 63 | 64 | def reverse_style(input_string): 65 | target_position = input_string.index('[') 66 | input_len = len(input_string) 67 | output_string = input_string[target_position:input_len] + input_string[0:target_position] 68 | return output_string 69 | 70 | 71 | def get_ner_BMES(label_list): 72 | # list_len = len(word_list) 73 | # assert(list_len == len(label_list)), "word list size unmatch with label list" 74 | list_len = len(label_list) 75 | 
begin_label = 'B-' 76 | end_label = 'E-' 77 | single_label = 'S-' 78 | whole_tag = '' 79 | index_tag = '' 80 | tag_list = [] 81 | stand_matrix = [] 82 | for i in range(0, list_len): 83 | # wordlabel = word_list[i] 84 | current_label = label_list[i].upper() 85 | if begin_label in current_label: 86 | if index_tag != '': 87 | tag_list.append(whole_tag + ',' + str(i-1)) 88 | whole_tag = current_label.replace(begin_label,"",1) +'[' +str(i) 89 | index_tag = current_label.replace(begin_label,"",1) 90 | 91 | elif single_label in current_label: 92 | if index_tag != '': 93 | tag_list.append(whole_tag + ',' + str(i-1)) 94 | whole_tag = current_label.replace(single_label,"",1) +'[' +str(i) 95 | tag_list.append(whole_tag) 96 | whole_tag = "" 97 | index_tag = "" 98 | elif end_label in current_label: 99 | if index_tag != '': 100 | tag_list.append(whole_tag +',' + str(i)) 101 | whole_tag = '' 102 | index_tag = '' 103 | else: 104 | continue 105 | if (whole_tag != '') and (index_tag != ''): 106 | tag_list.append(whole_tag) 107 | tag_list_len = len(tag_list) 108 | 109 | for i in range(0, tag_list_len): 110 | if len(tag_list[i]) > 0: 111 | tag_list[i] = tag_list[i]+ ']' 112 | insert_list = reverse_style(tag_list[i]) 113 | stand_matrix.append(insert_list) 114 | # print stand_matrix 115 | return stand_matrix 116 | 117 | 118 | def get_ner_BIO(label_list): 119 | # list_len = len(word_list) 120 | # assert(list_len == len(label_list)), "word list size unmatch with label list" 121 | list_len = len(label_list) 122 | begin_label = 'B-' 123 | inside_label = 'I-' 124 | whole_tag = '' 125 | index_tag = '' 126 | tag_list = [] 127 | stand_matrix = [] 128 | for i in range(0, list_len): 129 | # wordlabel = word_list[i] 130 | current_label = label_list[i].upper() 131 | if begin_label in current_label: 132 | if index_tag == '': 133 | whole_tag = current_label.replace(begin_label,"",1) +'[' +str(i) 134 | index_tag = current_label.replace(begin_label,"",1) 135 | else: 136 | tag_list.append(whole_tag + ',' + str(i-1)) 137 | whole_tag = current_label.replace(begin_label,"",1) + '[' + str(i) 138 | index_tag = current_label.replace(begin_label,"",1) 139 | 140 | elif inside_label in current_label: 141 | if current_label.replace(inside_label,"",1) == index_tag: 142 | whole_tag = whole_tag 143 | else: 144 | if (whole_tag != '')&(index_tag != ''): 145 | tag_list.append(whole_tag +',' + str(i-1)) 146 | whole_tag = '' 147 | index_tag = '' 148 | else: 149 | if (whole_tag != '')&(index_tag != ''): 150 | tag_list.append(whole_tag +',' + str(i-1)) 151 | whole_tag = '' 152 | index_tag = '' 153 | 154 | if (whole_tag != '')&(index_tag != ''): 155 | tag_list.append(whole_tag) 156 | tag_list_len = len(tag_list) 157 | 158 | for i in range(0, tag_list_len): 159 | if len(tag_list[i]) > 0: 160 | tag_list[i] = tag_list[i]+ ']' 161 | insert_list = reverse_style(tag_list[i]) 162 | stand_matrix.append(insert_list) 163 | 164 | return stand_matrix 165 | 166 | 167 | 168 | def readSentence(input_file): 169 | in_lines = open(input_file,'r').readlines() 170 | sentences = [] 171 | labels = [] 172 | sentence = [] 173 | label = [] 174 | for line in in_lines: 175 | if len(line) < 2: 176 | sentences.append(sentence) 177 | labels.append(label) 178 | sentence = [] 179 | label = [] 180 | else: 181 | pair = line.strip('\n').split(' ') 182 | sentence.append(pair[0]) 183 | label.append(pair[-1]) 184 | return sentences,labels 185 | 186 | 187 | def readTwoLabelSentence(input_file, pred_col=-1): 188 | in_lines = open(input_file,'r').readlines() 189 | sentences = [] 190 | 
predict_labels = [] 191 | golden_labels = [] 192 | sentence = [] 193 | predict_label = [] 194 | golden_label = [] 195 | for line in in_lines: 196 | if "##score##" in line: 197 | continue 198 | if len(line) < 2: 199 | sentences.append(sentence) 200 | golden_labels.append(golden_label) 201 | predict_labels.append(predict_label) 202 | sentence = [] 203 | golden_label = [] 204 | predict_label = [] 205 | else: 206 | pair = line.strip('\n').split(' ') 207 | sentence.append(pair[0]) 208 | golden_label.append(pair[1]) 209 | predict_label.append(pair[pred_col]) 210 | 211 | return sentences,golden_labels,predict_labels 212 | 213 | 214 | def fmeasure_from_file(golden_file, predict_file, label_type="BMES"): 215 | print("Get f measure from file:", golden_file, predict_file) 216 | print("Label format:",label_type) 217 | golden_sent,golden_labels = readSentence(golden_file) 218 | predict_sent,predict_labels = readSentence(predict_file) 219 | acc, P,R,F = get_ner_fmeasure(golden_labels, predict_labels, label_type) 220 | print ("Acc:%s, P:%s R:%s, F:%s"%(acc, P,R,F)) 221 | 222 | 223 | 224 | def fmeasure_from_singlefile(twolabel_file, label_type="BMES", pred_col=-1): 225 | sent,golden_labels,predict_labels = readTwoLabelSentence(twolabel_file, pred_col) 226 | P,R,F = get_ner_fmeasure(golden_labels, predict_labels, label_type) 227 | print ("P:%s, R:%s, F:%s"%(P,R,F)) 228 | 229 | 230 | 231 | if __name__ == '__main__': 232 | # print "sys:",len(sys.argv) 233 | if len(sys.argv) == 3: 234 | fmeasure_from_singlefile(sys.argv[1],"BMES",int(sys.argv[2])) 235 | else: 236 | fmeasure_from_singlefile(sys.argv[1],"BMES") 237 | 238 | -------------------------------------------------------------------------------- /ner/utils/msra_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | import numpy as np 5 | import pickle 6 | from ner.utils.alphabet import Alphabet 7 | from ner.utils.functions import * 8 | from ner.utils.gazetteer import Gazetteer 9 | 10 | 11 | 12 | START = "" 13 | UNKNOWN = "" 14 | PADDING = "" 15 | NULLKEY = "-null-" 16 | 17 | class Data: 18 | def __init__(self): 19 | self.MAX_SENTENCE_LENGTH = 5000 20 | self.MAX_WORD_LENGTH = -1 21 | self.number_normalized = True 22 | self.norm_word_emb = True 23 | self.norm_biword_emb = True 24 | self.norm_gaz_emb = False 25 | self.use_single = False 26 | self.word_alphabet = Alphabet('word') 27 | self.biword_alphabet = Alphabet('biword') 28 | self.char_alphabet = Alphabet('character') 29 | # self.word_alphabet.add(START) 30 | # self.word_alphabet.add(UNKNOWN) 31 | # self.char_alphabet.add(START) 32 | # self.char_alphabet.add(UNKNOWN) 33 | # self.char_alphabet.add(PADDING) 34 | self.label_alphabet = Alphabet('label', True) 35 | self.gaz_lower = False 36 | self.gaz = Gazetteer(self.gaz_lower, self.use_single) 37 | self.gaz_alphabet = Alphabet('gaz') 38 | self.HP_fix_gaz_emb = False 39 | self.HP_use_gaz = True 40 | 41 | self.tagScheme = "NoSeg" 42 | self.char_features = "LSTM" 43 | 44 | self.train_texts = [] 45 | self.dev_texts = [] 46 | self.test_texts = [] 47 | self.raw_texts = [] 48 | 49 | self.train_Ids = [] 50 | self.dev_Ids = [] 51 | self.test_Ids = [] 52 | self.raw_Ids = [] 53 | self.use_bigram = True 54 | self.word_emb_dim = 100 55 | self.biword_emb_dim = 50 56 | self.char_emb_dim = 30 57 | self.gaz_emb_dim = 50 58 | self.gaz_dropout = 0.5 59 | self.pretrain_word_embedding = None 60 | self.pretrain_biword_embedding = None 61 | self.pretrain_gaz_embedding = None 62 | 
self.label_size = 0 63 | self.word_alphabet_size = 0 64 | self.biword_alphabet_size = 0 65 | self.char_alphabet_size = 0 66 | self.label_alphabet_size = 0 67 | ### hyperparameters 68 | self.HP_iteration = 100 69 | self.HP_batch_size = 10 70 | self.HP_char_hidden_dim = 100 71 | self.HP_hidden_dim = 200 72 | self.HP_dropout = 0.5 73 | self.HP_lstm_layer = 1 74 | self.HP_bilstm = True 75 | self.HP_use_char = False 76 | self.HP_gpu = False 77 | self.HP_lr = 0.015 78 | self.HP_lr_decay = 0.05 79 | self.HP_clip = 5.0 80 | self.HP_momentum = 0 81 | 82 | 83 | def show_data_summary(self): 84 | print("DATA SUMMARY START:") 85 | print(" Tag scheme: %s"%(self.tagScheme)) 86 | print(" MAX SENTENCE LENGTH: %s"%(self.MAX_SENTENCE_LENGTH)) 87 | print(" MAX WORD LENGTH: %s"%(self.MAX_WORD_LENGTH)) 88 | print(" Number normalized: %s"%(self.number_normalized)) 89 | print(" Use bigram: %s"%(self.use_bigram)) 90 | print(" Word alphabet size: %s"%(self.word_alphabet_size)) 91 | print(" Biword alphabet size: %s"%(self.biword_alphabet_size)) 92 | print(" Char alphabet size: %s"%(self.char_alphabet_size)) 93 | print(" Gaz alphabet size: %s"%(self.gaz_alphabet.size())) 94 | print(" Label alphabet size: %s"%(self.label_alphabet_size)) 95 | print(" Word embedding size: %s"%(self.word_emb_dim)) 96 | print(" Biword embedding size: %s"%(self.biword_emb_dim)) 97 | print(" Char embedding size: %s"%(self.char_emb_dim)) 98 | print(" Gaz embedding size: %s"%(self.gaz_emb_dim)) 99 | print(" Norm word emb: %s"%(self.norm_word_emb)) 100 | print(" Norm biword emb: %s"%(self.norm_biword_emb)) 101 | print(" Norm gaz emb: %s"%(self.norm_gaz_emb)) 102 | print(" Norm gaz dropout: %s"%(self.gaz_dropout)) 103 | print(" Train instance number: %s"%(len(self.train_texts))) 104 | print(" Dev instance number: %s"%(len(self.dev_texts))) 105 | print(" Test instance number: %s"%(len(self.test_texts))) 106 | print(" Raw instance number: %s"%(len(self.raw_texts))) 107 | print(" Hyperpara iteration: %s"%(self.HP_iteration)) 108 | print(" Hyperpara batch size: %s"%(self.HP_batch_size)) 109 | print(" Hyperpara lr: %s"%(self.HP_lr)) 110 | print(" Hyperpara lr_decay: %s"%(self.HP_lr_decay)) 111 | print(" Hyperpara HP_clip: %s"%(self.HP_clip)) 112 | print(" Hyperpara momentum: %s"%(self.HP_momentum)) 113 | print(" Hyperpara hidden_dim: %s"%(self.HP_hidden_dim)) 114 | print(" Hyperpara dropout: %s"%(self.HP_dropout)) 115 | print(" Hyperpara lstm_layer: %s"%(self.HP_lstm_layer)) 116 | print(" Hyperpara bilstm: %s"%(self.HP_bilstm)) 117 | print(" Hyperpara GPU: %s"%(self.HP_gpu)) 118 | print(" Hyperpara use_gaz: %s"%(self.HP_use_gaz)) 119 | print(" Hyperpara fix gaz emb: %s"%(self.HP_fix_gaz_emb)) 120 | print(" Hyperpara use_char: %s"%(self.HP_use_char)) 121 | if self.HP_use_char: 122 | print(" Char_features: %s"%(self.char_features)) 123 | print("DATA SUMMARY END.") 124 | sys.stdout.flush() 125 | 126 | def refresh_label_alphabet(self, input_file): 127 | old_size = self.label_alphabet_size 128 | self.label_alphabet.clear(True) 129 | in_lines = open(input_file,'r', encoding="utf-8").readlines() 130 | for line in in_lines: 131 | if len(line) > 2: 132 | pairs = line.strip().split() 133 | label = pairs[-1] 134 | self.label_alphabet.add(label) 135 | self.label_alphabet_size = self.label_alphabet.size() 136 | startS = False 137 | startB = False 138 | for label,_ in self.label_alphabet.iteritems(): 139 | if "S-" in label.upper(): 140 | startS = True 141 | elif "B-" in label.upper(): 142 | startB = True 143 | if startB: 144 | if startS: 145 | self.tagScheme = 
"BMES" 146 | else: 147 | self.tagScheme = "BIO" 148 | self.fix_alphabet() 149 | print("Refresh label alphabet finished: old:%s -> new:%s"%(old_size, self.label_alphabet_size)) 150 | 151 | 152 | 153 | def build_alphabet(self, input_file): 154 | in_lines = open(input_file, 'r', encoding="utf-8").readlines() 155 | for idx in range(len(in_lines)): 156 | line = in_lines[idx] 157 | if len(line) > 2: 158 | pairs = line.strip().split() 159 | word = pairs[0] 160 | if self.number_normalized: 161 | word = normalize_word(word) 162 | label = pairs[-1] 163 | self.label_alphabet.add(label) 164 | self.word_alphabet.add(word) 165 | if idx < len(in_lines) - 1 and len(in_lines[idx+1]) > 2: 166 | biword = word + in_lines[idx+1].strip().split()[0] 167 | else: 168 | biword = word + NULLKEY 169 | self.biword_alphabet.add(biword) 170 | for char in word: 171 | self.char_alphabet.add(char) 172 | self.word_alphabet_size = self.word_alphabet.size() 173 | self.biword_alphabet_size = self.biword_alphabet.size() 174 | self.char_alphabet_size = self.char_alphabet.size() 175 | self.label_alphabet_size = self.label_alphabet.size() 176 | startS = False 177 | startB = False 178 | for label,_ in self.label_alphabet.iteritems(): 179 | if "S-" in label.upper(): 180 | startS = True 181 | elif "B-" in label.upper(): 182 | startB = True 183 | if startB: 184 | if startS: 185 | self.tagScheme = "BMES" 186 | else: 187 | self.tagScheme = "BIO" 188 | 189 | 190 | def build_gaz_file(self, gaz_file): 191 | ## build gaz file,initial read gaz embedding file 192 | ## we only get the word, do not read embedding this step 193 | if gaz_file: 194 | fins = open(gaz_file, 'r', encoding="utf-8").readlines() 195 | for fin in fins: 196 | fin = fin.strip().split()[0] 197 | if fin: 198 | self.gaz.insert(fin, "one_source") 199 | print("Load gaz file: ", gaz_file, " total size:", self.gaz.size()) 200 | else: 201 | print("Gaz file is None, load nothing") 202 | 203 | 204 | def build_gaz_alphabet(self, input_file): 205 | """ 206 | based on the train, dev, test file, we only save the seb-sequence word that my be appear 207 | """ 208 | in_lines = open(input_file,'r', encoding="utf-8").readlines() 209 | word_list = [] 210 | for line in in_lines: 211 | if len(line) > 3: 212 | word = line.split()[0] 213 | if self.number_normalized: 214 | word = normalize_word(word) 215 | word_list.append(word) 216 | else: 217 | w_length = len(word_list) 218 | ## Travser from [0: n], [1: n] to [n-1: n] 219 | for idx in range(w_length): 220 | matched_entity = self.gaz.enumerateMatchList(word_list[idx:]) 221 | for entity in matched_entity: 222 | # print entity, self.gaz.searchId(entity),self.gaz.searchType(entity) 223 | self.gaz_alphabet.add(entity) 224 | word_list = [] 225 | print("gaz alphabet size:", self.gaz_alphabet.size()) 226 | 227 | 228 | def fix_alphabet(self): 229 | self.word_alphabet.close() 230 | self.biword_alphabet.close() 231 | self.char_alphabet.close() 232 | self.label_alphabet.close() 233 | self.gaz_alphabet.close() 234 | 235 | 236 | def build_word_pretrain_emb(self, emb_path): 237 | print("build word pretrain emb...") 238 | self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(emb_path, self.word_alphabet, self.word_emb_dim, self.norm_word_emb) 239 | 240 | def build_biword_pretrain_emb(self, emb_path): 241 | print("build biword pretrain emb...") 242 | self.pretrain_biword_embedding, self.biword_emb_dim = build_pretrain_embedding(emb_path, self.biword_alphabet, self.biword_emb_dim, self.norm_biword_emb) 243 | 244 | def 
build_gaz_pretrain_emb(self, emb_path): 245 | print("build gaz pretrain emb...") 246 | self.pretrain_gaz_embedding, self.gaz_emb_dim = build_pretrain_embedding(emb_path, self.gaz_alphabet, self.gaz_emb_dim, self.norm_gaz_emb) 247 | 248 | 249 | def generate_word_instance(self, input_file, name): 250 | """ 251 | every instance include: words, labels, word_ids, label_ids 252 | """ 253 | self.fix_alphabet() 254 | if name == "train": 255 | self.train_texts, self.train_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 256 | self.number_normalized, self.MAX_SENTENCE_LENGTH) 257 | elif name == "dev": 258 | self.dev_texts, self.dev_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 259 | self.number_normalized, self.MAX_SENTENCE_LENGTH) 260 | elif name == "test": 261 | self.test_texts, self.test_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 262 | self.number_normalized, self.MAX_SENTENCE_LENGTH) 263 | elif name == "raw": 264 | self.raw_texts, self.raw_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 265 | self.number_normalized, self.MAX_SENTENCE_LENGTH) 266 | else: 267 | print("Error: you can only generate train/dev/test instance! Illegal input:%s" % (name)) 268 | 269 | 270 | def generate_instance(self, input_file, name): 271 | """ 272 | every instance include: words, biwords, chars, labels, 273 | word_ids, biword_ids, char_ids, label_ids 274 | """ 275 | self.fix_alphabet() 276 | if name == "train": 277 | self.train_texts, self.train_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 278 | elif name == "dev": 279 | self.dev_texts, self.dev_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 280 | elif name == "test": 281 | self.test_texts, self.test_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 282 | elif name == "raw": 283 | self.raw_texts, self.raw_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 284 | else: 285 | print("Error: you can only generate train/dev/test instance! 
Illegal input:%s"%(name)) 286 | 287 | 288 | def generate_instance_with_gaz(self, input_file, name): 289 | """ 290 | every instance include: words, biwords, chars, labels, gazs, 291 | word_ids, biword_ids, char_ids, label_ids, gaz_ids 292 | """ 293 | self.fix_alphabet() 294 | if name == "train": 295 | self.train_texts, self.train_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 296 | elif name == "dev": 297 | self.dev_texts, self.dev_Ids = read_instance_with_gaz(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 298 | elif name == "test": 299 | self.test_texts, self.test_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 300 | elif name == "raw": 301 | self.raw_texts, self.raw_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 302 | else: 303 | print("Error: you can only generate train/dev/test instance! Illegal input:%s"%(name)) 304 | 305 | def generate_instance_with_gaz_no_char(self, input_file, name): 306 | """ 307 | every instance include: 308 | words, biwords, gazs, labels 309 | word_Ids, biword_Ids, gazs_Ids, label_Ids 310 | """ 311 | self.fix_alphabet() 312 | if name == "train": 313 | self.train_texts, self.train_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 314 | elif name == "dev": 315 | self.dev_texts, self.dev_Ids = read_instance_with_gaz_no_char(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 316 | elif name == "test": 317 | self.test_texts, self.test_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 318 | elif name == "raw": 319 | self.raw_texts, self.raw_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 320 | else: 321 | print("Error: you can only generate train/dev/test instance! 
Illegal input:%s"%(name)) 322 | 323 | 324 | def write_decoded_results(self, output_file, predict_results, name): 325 | fout = open(output_file,'w', encoding="utf-8") 326 | sent_num = len(predict_results) 327 | content_list = [] 328 | if name == 'raw': 329 | content_list = self.raw_texts 330 | elif name == 'test': 331 | content_list = self.test_texts 332 | elif name == 'dev': 333 | content_list = self.dev_texts 334 | elif name == 'train': 335 | content_list = self.train_texts 336 | else: 337 | print("Error: illegal name during writing predict result, name should be within train/dev/test/raw !") 338 | assert(sent_num == len(content_list)) 339 | for idx in range(sent_num): 340 | sent_length = len(predict_results[idx]) 341 | for idy in range(sent_length): 342 | ## content_list[idx] is a list with [word, char, label] 343 | fout.write(content_list[idx][0][idy].encode('utf-8') + " " + predict_results[idx][idy] + '\n') 344 | 345 | fout.write('\n') 346 | fout.close() 347 | print("Predict %s result has been written into file. %s"%(name, output_file)) 348 | 349 | 350 | 351 | 352 | 353 | 354 | 355 | 356 | -------------------------------------------------------------------------------- /ner/utils/note4_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | import numpy as np 5 | import pickle 6 | from ner.utils.alphabet import Alphabet 7 | from ner.utils.functions import * 8 | from ner.utils.gazetteer import Gazetteer 9 | 10 | 11 | 12 | START = "" 13 | UNKNOWN = "" 14 | PADDING = "" 15 | NULLKEY = "-null-" 16 | 17 | class Data: 18 | def __init__(self): 19 | self.MAX_SENTENCE_LENGTH = 250 20 | self.MAX_WORD_LENGTH = -1 21 | self.number_normalized = True 22 | self.norm_word_emb = True 23 | self.norm_biword_emb = True 24 | self.norm_gaz_emb = False 25 | self.use_single = True 26 | self.word_alphabet = Alphabet('word') 27 | self.biword_alphabet = Alphabet('biword') 28 | self.char_alphabet = Alphabet('character') 29 | # self.word_alphabet.add(START) 30 | # self.word_alphabet.add(UNKNOWN) 31 | # self.char_alphabet.add(START) 32 | # self.char_alphabet.add(UNKNOWN) 33 | # self.char_alphabet.add(PADDING) 34 | self.label_alphabet = Alphabet('label', True) 35 | self.gaz_lower = False 36 | self.gaz = Gazetteer(self.gaz_lower, self.use_single) 37 | self.gaz_alphabet = Alphabet('gaz') 38 | self.HP_fix_gaz_emb = False 39 | self.HP_use_gaz = True 40 | 41 | self.tagScheme = "NoSeg" 42 | self.char_features = "LSTM" 43 | 44 | self.train_texts = [] 45 | self.dev_texts = [] 46 | self.test_texts = [] 47 | self.raw_texts = [] 48 | 49 | self.train_Ids = [] 50 | self.dev_Ids = [] 51 | self.test_Ids = [] 52 | self.raw_Ids = [] 53 | self.use_bigram = True 54 | self.word_emb_dim = 100 55 | self.biword_emb_dim = 50 56 | self.char_emb_dim = 30 57 | self.gaz_emb_dim = 50 58 | self.gaz_dropout = 0.5 59 | self.pretrain_word_embedding = None 60 | self.pretrain_biword_embedding = None 61 | self.pretrain_gaz_embedding = None 62 | self.label_size = 0 63 | self.word_alphabet_size = 0 64 | self.biword_alphabet_size = 0 65 | self.char_alphabet_size = 0 66 | self.label_alphabet_size = 0 67 | ### hyperparameters 68 | self.HP_iteration = 100 69 | self.HP_batch_size = 10 70 | self.HP_char_hidden_dim = 100 71 | self.HP_hidden_dim = 200 72 | self.HP_dropout = 0.5 73 | self.HP_lstm_layer = 1 74 | self.HP_bilstm = True 75 | self.HP_use_char = False 76 | self.HP_gpu = False 77 | self.HP_lr = 0.015 78 | self.HP_lr_decay = 0.05 79 | self.HP_clip = 5.0 80 | 
self.HP_momentum = 0 81 | 82 | 83 | def show_data_summary(self): 84 | print("DATA SUMMARY START:") 85 | print(" Tag scheme: %s"%(self.tagScheme)) 86 | print(" MAX SENTENCE LENGTH: %s"%(self.MAX_SENTENCE_LENGTH)) 87 | print(" MAX WORD LENGTH: %s"%(self.MAX_WORD_LENGTH)) 88 | print(" Number normalized: %s"%(self.number_normalized)) 89 | print(" Use bigram: %s"%(self.use_bigram)) 90 | print(" Word alphabet size: %s"%(self.word_alphabet_size)) 91 | print(" Biword alphabet size: %s"%(self.biword_alphabet_size)) 92 | print(" Char alphabet size: %s"%(self.char_alphabet_size)) 93 | print(" Gaz alphabet size: %s"%(self.gaz_alphabet.size())) 94 | print(" Label alphabet size: %s"%(self.label_alphabet_size)) 95 | print(" Word embedding size: %s"%(self.word_emb_dim)) 96 | print(" Biword embedding size: %s"%(self.biword_emb_dim)) 97 | print(" Char embedding size: %s"%(self.char_emb_dim)) 98 | print(" Gaz embedding size: %s"%(self.gaz_emb_dim)) 99 | print(" Norm word emb: %s"%(self.norm_word_emb)) 100 | print(" Norm biword emb: %s"%(self.norm_biword_emb)) 101 | print(" Norm gaz emb: %s"%(self.norm_gaz_emb)) 102 | print(" Norm gaz dropout: %s"%(self.gaz_dropout)) 103 | print(" Train instance number: %s"%(len(self.train_texts))) 104 | print(" Dev instance number: %s"%(len(self.dev_texts))) 105 | print(" Test instance number: %s"%(len(self.test_texts))) 106 | print(" Raw instance number: %s"%(len(self.raw_texts))) 107 | print(" Hyperpara iteration: %s"%(self.HP_iteration)) 108 | print(" Hyperpara batch size: %s"%(self.HP_batch_size)) 109 | print(" Hyperpara lr: %s"%(self.HP_lr)) 110 | print(" Hyperpara lr_decay: %s"%(self.HP_lr_decay)) 111 | print(" Hyperpara HP_clip: %s"%(self.HP_clip)) 112 | print(" Hyperpara momentum: %s"%(self.HP_momentum)) 113 | print(" Hyperpara hidden_dim: %s"%(self.HP_hidden_dim)) 114 | print(" Hyperpara dropout: %s"%(self.HP_dropout)) 115 | print(" Hyperpara lstm_layer: %s"%(self.HP_lstm_layer)) 116 | print(" Hyperpara bilstm: %s"%(self.HP_bilstm)) 117 | print(" Hyperpara GPU: %s"%(self.HP_gpu)) 118 | print(" Hyperpara use_gaz: %s"%(self.HP_use_gaz)) 119 | print(" Hyperpara fix gaz emb: %s"%(self.HP_fix_gaz_emb)) 120 | print(" Hyperpara use_char: %s"%(self.HP_use_char)) 121 | if self.HP_use_char: 122 | print(" Char_features: %s"%(self.char_features)) 123 | print("DATA SUMMARY END.") 124 | sys.stdout.flush() 125 | 126 | def refresh_label_alphabet(self, input_file): 127 | old_size = self.label_alphabet_size 128 | self.label_alphabet.clear(True) 129 | in_lines = open(input_file,'r', encoding="utf-8").readlines() 130 | for line in in_lines: 131 | if len(line) > 2: 132 | pairs = line.strip().split() 133 | label = pairs[-1] 134 | self.label_alphabet.add(label) 135 | self.label_alphabet_size = self.label_alphabet.size() 136 | startS = False 137 | startB = False 138 | for label,_ in self.label_alphabet.iteritems(): 139 | if "S-" in label.upper(): 140 | startS = True 141 | elif "B-" in label.upper(): 142 | startB = True 143 | if startB: 144 | if startS: 145 | self.tagScheme = "BMES" 146 | else: 147 | self.tagScheme = "BIO" 148 | self.fix_alphabet() 149 | print("Refresh label alphabet finished: old:%s -> new:%s"%(old_size, self.label_alphabet_size)) 150 | 151 | 152 | 153 | def build_alphabet(self, input_file): 154 | in_lines = open(input_file, 'r', encoding="utf-8").readlines() 155 | for idx in range(len(in_lines)): 156 | line = in_lines[idx] 157 | if len(line) > 2: 158 | pairs = line.strip().split() 159 | word = pairs[0] 160 | if self.number_normalized: 161 | word = normalize_word(word) 162 
| label = pairs[-1] 163 | self.label_alphabet.add(label) 164 | self.word_alphabet.add(word) 165 | if idx < len(in_lines) - 1 and len(in_lines[idx+1]) > 2: 166 | biword = word + in_lines[idx+1].strip().split()[0] 167 | else: 168 | biword = word + NULLKEY 169 | self.biword_alphabet.add(biword) 170 | for char in word: 171 | self.char_alphabet.add(char) 172 | self.word_alphabet_size = self.word_alphabet.size() 173 | self.biword_alphabet_size = self.biword_alphabet.size() 174 | self.char_alphabet_size = self.char_alphabet.size() 175 | self.label_alphabet_size = self.label_alphabet.size() 176 | startS = False 177 | startB = False 178 | for label,_ in self.label_alphabet.iteritems(): 179 | if "S-" in label.upper(): 180 | startS = True 181 | elif "B-" in label.upper(): 182 | startB = True 183 | if startB: 184 | if startS: 185 | self.tagScheme = "BMES" 186 | else: 187 | self.tagScheme = "BIO" 188 | 189 | 190 | def build_gaz_file(self, gaz_file): 191 | ## build gaz file,initial read gaz embedding file 192 | ## we only get the word, do not read embedding this step 193 | if gaz_file: 194 | fins = open(gaz_file, 'r', encoding="utf-8").readlines() 195 | for fin in fins: 196 | fin = fin.strip().split()[0] 197 | if fin: 198 | self.gaz.insert(fin, "one_source") 199 | print("Load gaz file: ", gaz_file, " total size:", self.gaz.size()) 200 | else: 201 | print("Gaz file is None, load nothing") 202 | 203 | 204 | def build_gaz_alphabet(self, input_file): 205 | """ 206 | based on the train, dev, test file, we only save the seb-sequence word that my be appear 207 | """ 208 | in_lines = open(input_file,'r', encoding="utf-8").readlines() 209 | word_list = [] 210 | for line in in_lines: 211 | if len(line) > 3: 212 | word = line.split()[0] 213 | if self.number_normalized: 214 | word = normalize_word(word) 215 | word_list.append(word) 216 | else: 217 | w_length = len(word_list) 218 | ## Travser from [0: n], [1: n] to [n-1: n] 219 | for idx in range(w_length): 220 | matched_entity = self.gaz.enumerateMatchList(word_list[idx:]) 221 | for entity in matched_entity: 222 | # print entity, self.gaz.searchId(entity),self.gaz.searchType(entity) 223 | self.gaz_alphabet.add(entity) 224 | word_list = [] 225 | print("gaz alphabet size:", self.gaz_alphabet.size()) 226 | 227 | 228 | def fix_alphabet(self): 229 | self.word_alphabet.close() 230 | self.biword_alphabet.close() 231 | self.char_alphabet.close() 232 | self.label_alphabet.close() 233 | self.gaz_alphabet.close() 234 | 235 | 236 | def build_word_pretrain_emb(self, emb_path): 237 | print("build word pretrain emb...") 238 | self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(emb_path, self.word_alphabet, self.word_emb_dim, self.norm_word_emb) 239 | 240 | def build_biword_pretrain_emb(self, emb_path): 241 | print("build biword pretrain emb...") 242 | self.pretrain_biword_embedding, self.biword_emb_dim = build_pretrain_embedding(emb_path, self.biword_alphabet, self.biword_emb_dim, self.norm_biword_emb) 243 | 244 | def build_gaz_pretrain_emb(self, emb_path): 245 | print("build gaz pretrain emb...") 246 | self.pretrain_gaz_embedding, self.gaz_emb_dim = build_pretrain_embedding(emb_path, self.gaz_alphabet, self.gaz_emb_dim, self.norm_gaz_emb) 247 | 248 | 249 | def generate_word_instance(self, input_file, name): 250 | """ 251 | every instance include: words, labels, word_ids, label_ids 252 | """ 253 | self.fix_alphabet() 254 | if name == "train": 255 | self.train_texts, self.train_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 256 
| self.number_normalized, self.MAX_SENTENCE_LENGTH) 257 | elif name == "dev": 258 | self.dev_texts, self.dev_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 259 | self.number_normalized, self.MAX_SENTENCE_LENGTH) 260 | elif name == "test": 261 | self.test_texts, self.test_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 262 | self.number_normalized, self.MAX_SENTENCE_LENGTH) 263 | elif name == "raw": 264 | self.raw_texts, self.raw_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 265 | self.number_normalized, self.MAX_SENTENCE_LENGTH) 266 | else: 267 | print("Error: you can only generate train/dev/test instance! Illegal input:%s" % (name)) 268 | 269 | 270 | def generate_instance(self, input_file, name): 271 | """ 272 | every instance include: words, biwords, chars, labels, 273 | word_ids, biword_ids, char_ids, label_ids 274 | """ 275 | self.fix_alphabet() 276 | if name == "train": 277 | self.train_texts, self.train_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 278 | elif name == "dev": 279 | self.dev_texts, self.dev_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 280 | elif name == "test": 281 | self.test_texts, self.test_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 282 | elif name == "raw": 283 | self.raw_texts, self.raw_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 284 | else: 285 | print("Error: you can only generate train/dev/test instance! Illegal input:%s"%(name)) 286 | 287 | 288 | def generate_instance_with_gaz(self, input_file, name): 289 | """ 290 | every instance include: words, biwords, chars, labels, gazs, 291 | word_ids, biword_ids, char_ids, label_ids, gaz_ids 292 | """ 293 | self.fix_alphabet() 294 | if name == "train": 295 | self.train_texts, self.train_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 296 | elif name == "dev": 297 | self.dev_texts, self.dev_Ids = read_instance_with_gaz(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 298 | elif name == "test": 299 | self.test_texts, self.test_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 300 | elif name == "raw": 301 | self.raw_texts, self.raw_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 302 | else: 303 | print("Error: you can only generate train/dev/test instance! 
Illegal input:%s"%(name)) 304 | 305 | def generate_instance_with_gaz_no_char(self, input_file, name): 306 | """ 307 | every instance include: 308 | words, biwords, gazs, labels 309 | word_Ids, biword_Ids, gazs_Ids, label_Ids 310 | """ 311 | self.fix_alphabet() 312 | if name == "train": 313 | self.train_texts, self.train_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 314 | elif name == "dev": 315 | self.dev_texts, self.dev_Ids = read_instance_with_gaz_no_char(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 316 | elif name == "test": 317 | self.test_texts, self.test_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 318 | elif name == "raw": 319 | self.raw_texts, self.raw_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 320 | else: 321 | print("Error: you can only generate train/dev/test instance! Illegal input:%s"%(name)) 322 | 323 | 324 | def write_decoded_results(self, output_file, predict_results, name): 325 | fout = open(output_file,'w', encoding="utf-8") 326 | sent_num = len(predict_results) 327 | content_list = [] 328 | if name == 'raw': 329 | content_list = self.raw_texts 330 | elif name == 'test': 331 | content_list = self.test_texts 332 | elif name == 'dev': 333 | content_list = self.dev_texts 334 | elif name == 'train': 335 | content_list = self.train_texts 336 | else: 337 | print("Error: illegal name during writing predict result, name should be within train/dev/test/raw !") 338 | assert(sent_num == len(content_list)) 339 | for idx in range(sent_num): 340 | sent_length = len(predict_results[idx]) 341 | for idy in range(sent_length): 342 | ## content_list[idx] is a list with [word, char, label] 343 | fout.write(content_list[idx][0][idy].encode('utf-8') + " " + predict_results[idx][idy] + '\n') 344 | 345 | fout.write('\n') 346 | fout.close() 347 | print("Predict %s result has been written into file. 
%s"%(name, output_file)) 348 | 349 | 350 | 351 | 352 | 353 | 354 | 355 | 356 | -------------------------------------------------------------------------------- /ner/utils/resume_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | import numpy as np 5 | import pickle 6 | from ner.utils.alphabet import Alphabet 7 | from ner.utils.functions import * 8 | from ner.utils.gazetteer import Gazetteer 9 | 10 | 11 | 12 | START = "" 13 | UNKNOWN = "" 14 | PADDING = "" 15 | NULLKEY = "-null-" 16 | 17 | class Data: 18 | def __init__(self): 19 | self.MAX_SENTENCE_LENGTH = 5000 20 | self.MAX_WORD_LENGTH = -1 21 | self.number_normalized = True 22 | self.norm_word_emb = True 23 | self.norm_biword_emb = True 24 | self.norm_gaz_emb = False 25 | self.use_single = False 26 | self.word_alphabet = Alphabet('word') 27 | self.biword_alphabet = Alphabet('biword') 28 | self.char_alphabet = Alphabet('character') 29 | # self.word_alphabet.add(START) 30 | # self.word_alphabet.add(UNKNOWN) 31 | # self.char_alphabet.add(START) 32 | # self.char_alphabet.add(UNKNOWN) 33 | # self.char_alphabet.add(PADDING) 34 | self.label_alphabet = Alphabet('label', True) 35 | self.gaz_lower = False 36 | self.gaz = Gazetteer(self.gaz_lower, self.use_single) 37 | self.gaz_alphabet = Alphabet('gaz') 38 | self.HP_fix_gaz_emb = False 39 | self.HP_use_gaz = True 40 | 41 | self.tagScheme = "NoSeg" 42 | self.char_features = "LSTM" 43 | 44 | self.train_texts = [] 45 | self.dev_texts = [] 46 | self.test_texts = [] 47 | self.raw_texts = [] 48 | 49 | self.train_Ids = [] 50 | self.dev_Ids = [] 51 | self.test_Ids = [] 52 | self.raw_Ids = [] 53 | self.use_bigram = True 54 | self.word_emb_dim = 100 55 | self.biword_emb_dim = 50 56 | self.char_emb_dim = 30 57 | self.gaz_emb_dim = 50 58 | self.gaz_dropout = 0.5 59 | self.pretrain_word_embedding = None 60 | self.pretrain_biword_embedding = None 61 | self.pretrain_gaz_embedding = None 62 | self.label_size = 0 63 | self.word_alphabet_size = 0 64 | self.biword_alphabet_size = 0 65 | self.char_alphabet_size = 0 66 | self.label_alphabet_size = 0 67 | ### hyperparameters 68 | self.HP_iteration = 100 69 | self.HP_batch_size = 10 70 | self.HP_char_hidden_dim = 100 71 | self.HP_hidden_dim = 100 72 | self.HP_dropout = 0.5 73 | self.HP_lstm_layer = 1 74 | self.HP_bilstm = True 75 | self.HP_use_char = False 76 | self.HP_gpu = False 77 | self.HP_lr = 0.015 78 | self.HP_lr_decay = 0.05 79 | self.HP_clip = 5.0 80 | self.HP_momentum = 0 81 | 82 | 83 | def show_data_summary(self): 84 | print("DATA SUMMARY START:") 85 | print(" Tag scheme: %s"%(self.tagScheme)) 86 | print(" MAX SENTENCE LENGTH: %s"%(self.MAX_SENTENCE_LENGTH)) 87 | print(" MAX WORD LENGTH: %s"%(self.MAX_WORD_LENGTH)) 88 | print(" Number normalized: %s"%(self.number_normalized)) 89 | print(" Use bigram: %s"%(self.use_bigram)) 90 | print(" Word alphabet size: %s"%(self.word_alphabet_size)) 91 | print(" Biword alphabet size: %s"%(self.biword_alphabet_size)) 92 | print(" Char alphabet size: %s"%(self.char_alphabet_size)) 93 | print(" Gaz alphabet size: %s"%(self.gaz_alphabet.size())) 94 | print(" Label alphabet size: %s"%(self.label_alphabet_size)) 95 | print(" Word embedding size: %s"%(self.word_emb_dim)) 96 | print(" Biword embedding size: %s"%(self.biword_emb_dim)) 97 | print(" Char embedding size: %s"%(self.char_emb_dim)) 98 | print(" Gaz embedding size: %s"%(self.gaz_emb_dim)) 99 | print(" Norm word emb: %s"%(self.norm_word_emb)) 100 | print(" Norm biword emb: 
%s"%(self.norm_biword_emb)) 101 | print(" Norm gaz emb: %s"%(self.norm_gaz_emb)) 102 | print(" Norm gaz dropout: %s"%(self.gaz_dropout)) 103 | print(" Train instance number: %s"%(len(self.train_texts))) 104 | print(" Dev instance number: %s"%(len(self.dev_texts))) 105 | print(" Test instance number: %s"%(len(self.test_texts))) 106 | print(" Raw instance number: %s"%(len(self.raw_texts))) 107 | print(" Hyperpara iteration: %s"%(self.HP_iteration)) 108 | print(" Hyperpara batch size: %s"%(self.HP_batch_size)) 109 | print(" Hyperpara lr: %s"%(self.HP_lr)) 110 | print(" Hyperpara lr_decay: %s"%(self.HP_lr_decay)) 111 | print(" Hyperpara HP_clip: %s"%(self.HP_clip)) 112 | print(" Hyperpara momentum: %s"%(self.HP_momentum)) 113 | print(" Hyperpara hidden_dim: %s"%(self.HP_hidden_dim)) 114 | print(" Hyperpara dropout: %s"%(self.HP_dropout)) 115 | print(" Hyperpara lstm_layer: %s"%(self.HP_lstm_layer)) 116 | print(" Hyperpara bilstm: %s"%(self.HP_bilstm)) 117 | print(" Hyperpara GPU: %s"%(self.HP_gpu)) 118 | print(" Hyperpara use_gaz: %s"%(self.HP_use_gaz)) 119 | print(" Hyperpara fix gaz emb: %s"%(self.HP_fix_gaz_emb)) 120 | print(" Hyperpara use_char: %s"%(self.HP_use_char)) 121 | if self.HP_use_char: 122 | print(" Char_features: %s"%(self.char_features)) 123 | print("DATA SUMMARY END.") 124 | sys.stdout.flush() 125 | 126 | def refresh_label_alphabet(self, input_file): 127 | old_size = self.label_alphabet_size 128 | self.label_alphabet.clear(True) 129 | in_lines = open(input_file,'r', encoding="utf-8").readlines() 130 | for line in in_lines: 131 | if len(line) > 2: 132 | pairs = line.strip().split() 133 | label = pairs[-1] 134 | self.label_alphabet.add(label) 135 | self.label_alphabet_size = self.label_alphabet.size() 136 | startS = False 137 | startB = False 138 | for label,_ in self.label_alphabet.iteritems(): 139 | if "S-" in label.upper(): 140 | startS = True 141 | elif "B-" in label.upper(): 142 | startB = True 143 | if startB: 144 | if startS: 145 | self.tagScheme = "BMES" 146 | else: 147 | self.tagScheme = "BIO" 148 | self.fix_alphabet() 149 | print("Refresh label alphabet finished: old:%s -> new:%s"%(old_size, self.label_alphabet_size)) 150 | 151 | 152 | 153 | def build_alphabet(self, input_file): 154 | in_lines = open(input_file, 'r', encoding="utf-8").readlines() 155 | for idx in range(len(in_lines)): 156 | line = in_lines[idx] 157 | if len(line) > 2: 158 | pairs = line.strip().split() 159 | word = pairs[0] 160 | if self.number_normalized: 161 | word = normalize_word(word) 162 | label = pairs[-1] 163 | self.label_alphabet.add(label) 164 | self.word_alphabet.add(word) 165 | if idx < len(in_lines) - 1 and len(in_lines[idx+1]) > 2: 166 | biword = word + in_lines[idx+1].strip().split()[0] 167 | else: 168 | biword = word + NULLKEY 169 | self.biword_alphabet.add(biword) 170 | for char in word: 171 | self.char_alphabet.add(char) 172 | self.word_alphabet_size = self.word_alphabet.size() 173 | self.biword_alphabet_size = self.biword_alphabet.size() 174 | self.char_alphabet_size = self.char_alphabet.size() 175 | self.label_alphabet_size = self.label_alphabet.size() 176 | startS = False 177 | startB = False 178 | for label,_ in self.label_alphabet.iteritems(): 179 | if "S-" in label.upper(): 180 | startS = True 181 | elif "B-" in label.upper(): 182 | startB = True 183 | if startB: 184 | if startS: 185 | self.tagScheme = "BMES" 186 | else: 187 | self.tagScheme = "BIO" 188 | 189 | 190 | def build_gaz_file(self, gaz_file): 191 | ## build gaz file,initial read gaz embedding file 192 | ## we only get 
the word, do not read embedding this step 193 | if gaz_file: 194 | fins = open(gaz_file, 'r', encoding="utf-8").readlines() 195 | for fin in fins: 196 | fin = fin.strip().split()[0] 197 | if fin: 198 | self.gaz.insert(fin, "one_source") 199 | print("Load gaz file: ", gaz_file, " total size:", self.gaz.size()) 200 | else: 201 | print("Gaz file is None, load nothing") 202 | 203 | 204 | def build_gaz_alphabet(self, input_file): 205 | """ 206 | based on the train, dev, test file, we only save the seb-sequence word that my be appear 207 | """ 208 | in_lines = open(input_file,'r', encoding="utf-8").readlines() 209 | word_list = [] 210 | for line in in_lines: 211 | if len(line) > 3: 212 | word = line.split()[0] 213 | if self.number_normalized: 214 | word = normalize_word(word) 215 | word_list.append(word) 216 | else: 217 | w_length = len(word_list) 218 | ## Travser from [0: n], [1: n] to [n-1: n] 219 | for idx in range(w_length): 220 | matched_entity = self.gaz.enumerateMatchList(word_list[idx:]) 221 | for entity in matched_entity: 222 | # print entity, self.gaz.searchId(entity),self.gaz.searchType(entity) 223 | self.gaz_alphabet.add(entity) 224 | word_list = [] 225 | print("gaz alphabet size:", self.gaz_alphabet.size()) 226 | 227 | 228 | def fix_alphabet(self): 229 | self.word_alphabet.close() 230 | self.biword_alphabet.close() 231 | self.char_alphabet.close() 232 | self.label_alphabet.close() 233 | self.gaz_alphabet.close() 234 | 235 | 236 | def build_word_pretrain_emb(self, emb_path): 237 | print("build word pretrain emb...") 238 | self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(emb_path, self.word_alphabet, self.word_emb_dim, self.norm_word_emb) 239 | 240 | def build_biword_pretrain_emb(self, emb_path): 241 | print("build biword pretrain emb...") 242 | self.pretrain_biword_embedding, self.biword_emb_dim = build_pretrain_embedding(emb_path, self.biword_alphabet, self.biword_emb_dim, self.norm_biword_emb) 243 | 244 | def build_gaz_pretrain_emb(self, emb_path): 245 | print("build gaz pretrain emb...") 246 | self.pretrain_gaz_embedding, self.gaz_emb_dim = build_pretrain_embedding(emb_path, self.gaz_alphabet, self.gaz_emb_dim, self.norm_gaz_emb) 247 | 248 | 249 | def generate_word_instance(self, input_file, name): 250 | """ 251 | every instance include: words, labels, word_ids, label_ids 252 | """ 253 | self.fix_alphabet() 254 | if name == "train": 255 | self.train_texts, self.train_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 256 | self.number_normalized, self.MAX_SENTENCE_LENGTH) 257 | elif name == "dev": 258 | self.dev_texts, self.dev_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 259 | self.number_normalized, self.MAX_SENTENCE_LENGTH) 260 | elif name == "test": 261 | self.test_texts, self.test_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 262 | self.number_normalized, self.MAX_SENTENCE_LENGTH) 263 | elif name == "raw": 264 | self.raw_texts, self.raw_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 265 | self.number_normalized, self.MAX_SENTENCE_LENGTH) 266 | else: 267 | print("Error: you can only generate train/dev/test instance! 
Illegal input:%s" % (name)) 268 | 269 | 270 | def generate_instance(self, input_file, name): 271 | """ 272 | every instance include: words, biwords, chars, labels, 273 | word_ids, biword_ids, char_ids, label_ids 274 | """ 275 | self.fix_alphabet() 276 | if name == "train": 277 | self.train_texts, self.train_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 278 | elif name == "dev": 279 | self.dev_texts, self.dev_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 280 | elif name == "test": 281 | self.test_texts, self.test_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 282 | elif name == "raw": 283 | self.raw_texts, self.raw_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 284 | else: 285 | print("Error: you can only generate train/dev/test instance! Illegal input:%s"%(name)) 286 | 287 | 288 | def generate_instance_with_gaz(self, input_file, name): 289 | """ 290 | every instance include: words, biwords, chars, labels, gazs, 291 | word_ids, biword_ids, char_ids, label_ids, gaz_ids 292 | """ 293 | self.fix_alphabet() 294 | if name == "train": 295 | self.train_texts, self.train_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 296 | elif name == "dev": 297 | self.dev_texts, self.dev_Ids = read_instance_with_gaz(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 298 | elif name == "test": 299 | self.test_texts, self.test_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 300 | elif name == "raw": 301 | self.raw_texts, self.raw_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 302 | else: 303 | print("Error: you can only generate train/dev/test instance! 
Illegal input:%s"%(name)) 304 | 305 | def generate_instance_with_gaz_no_char(self, input_file, name): 306 | """ 307 | every instance include: 308 | words, biwords, gazs, labels 309 | word_Ids, biword_Ids, gazs_Ids, label_Ids 310 | """ 311 | self.fix_alphabet() 312 | if name == "train": 313 | self.train_texts, self.train_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 314 | elif name == "dev": 315 | self.dev_texts, self.dev_Ids = read_instance_with_gaz_no_char(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 316 | elif name == "test": 317 | self.test_texts, self.test_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 318 | elif name == "raw": 319 | self.raw_texts, self.raw_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 320 | else: 321 | print("Error: you can only generate train/dev/test instance! Illegal input:%s"%(name)) 322 | 323 | 324 | def write_decoded_results(self, output_file, predict_results, name): 325 | fout = open(output_file,'w', encoding="utf-8") 326 | sent_num = len(predict_results) 327 | content_list = [] 328 | if name == 'raw': 329 | content_list = self.raw_texts 330 | elif name == 'test': 331 | content_list = self.test_texts 332 | elif name == 'dev': 333 | content_list = self.dev_texts 334 | elif name == 'train': 335 | content_list = self.train_texts 336 | else: 337 | print("Error: illegal name during writing predict result, name should be within train/dev/test/raw !") 338 | assert(sent_num == len(content_list)) 339 | for idx in range(sent_num): 340 | sent_length = len(predict_results[idx]) 341 | for idy in range(sent_length): 342 | ## content_list[idx] is a list with [word, char, label] 343 | fout.write(content_list[idx][0][idy].encode('utf-8') + " " + predict_results[idx][idy] + '\n') 344 | 345 | fout.write('\n') 346 | fout.close() 347 | print("Predict %s result has been written into file. %s"%(name, output_file)) 348 | 349 | 350 | 351 | 352 | 353 | 354 | 355 | 356 | -------------------------------------------------------------------------------- /ner/utils/trie.py: -------------------------------------------------------------------------------- 1 | import collections 2 | class TrieNode: 3 | # Initialize your data structure here. 4 | def __init__(self): 5 | self.children = collections.defaultdict(TrieNode) 6 | self.is_word = False 7 | 8 | class Trie: 9 | """ 10 | In fact, this Trie is a letter three. 11 | root is a fake node, its function is only the begin of a word, same as 12 | 13 | the the first layer is all the word's possible first letter, for example, '中国' 14 | its first letter is '中' 15 | the second the layer is all the word's possible second letter. 
16 | and so on 17 | """ 18 | def __init__(self, use_single): 19 | self.root = TrieNode() 20 | if use_single: 21 | self.min_len = 0 22 | else: 23 | self.min_len = 1 24 | 25 | def insert(self, word): 26 | 27 | current = self.root 28 | # Traversing all the letter in the chinese word, util the last letter 29 | for letter in word: 30 | current = current.children[letter] 31 | current.is_word = True 32 | 33 | def search(self, word): 34 | current = self.root 35 | for letter in word: 36 | current = current.children.get(letter) 37 | 38 | if current is None: 39 | return False 40 | return current.is_word 41 | 42 | def startsWith(self, prefix): 43 | current = self.root 44 | for letter in prefix: 45 | current = current.children.get(letter) 46 | if current is None: 47 | return False 48 | return True 49 | 50 | 51 | def enumerateMatch(self, word, space="_", backward=False): 52 | matched = [] 53 | ## while len(word) > 1 does not keep character itself, while word keed character itself 54 | while len(word) > self.min_len: 55 | if self.search(word): 56 | matched.append(space.join(word[:])) 57 | 58 | # del the last letter one by one 59 | del word[-1] 60 | # the matched is all the possible sub-sequence word in the give sequence word 61 | return matched 62 | 63 | -------------------------------------------------------------------------------- /ner/utils/weibo_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | import numpy as np 5 | import pickle 6 | from ner.utils.alphabet import Alphabet 7 | from ner.utils.functions import * 8 | from ner.utils.gazetteer import Gazetteer 9 | 10 | 11 | 12 | START = "" 13 | UNKNOWN = "" 14 | PADDING = "" 15 | NULLKEY = "-null-" 16 | 17 | class Data: 18 | def __init__(self): 19 | self.MAX_SENTENCE_LENGTH = 250 20 | self.MAX_WORD_LENGTH = -1 21 | self.number_normalized = True 22 | self.norm_word_emb = True 23 | self.norm_biword_emb = True 24 | self.norm_gaz_emb = False 25 | self.use_single = True 26 | self.word_alphabet = Alphabet('word') 27 | self.biword_alphabet = Alphabet('biword') 28 | self.char_alphabet = Alphabet('character') 29 | # self.word_alphabet.add(START) 30 | # self.word_alphabet.add(UNKNOWN) 31 | # self.char_alphabet.add(START) 32 | # self.char_alphabet.add(UNKNOWN) 33 | # self.char_alphabet.add(PADDING) 34 | self.label_alphabet = Alphabet('label', True) 35 | self.gaz_lower = False 36 | self.gaz = Gazetteer(self.gaz_lower, self.use_single) 37 | self.gaz_alphabet = Alphabet('gaz') 38 | self.HP_fix_gaz_emb = False 39 | self.HP_use_gaz = True 40 | 41 | self.tagScheme = "NoSeg" 42 | self.char_features = "LSTM" 43 | 44 | self.train_texts = [] 45 | self.dev_texts = [] 46 | self.test_texts = [] 47 | self.raw_texts = [] 48 | 49 | self.train_Ids = [] 50 | self.dev_Ids = [] 51 | self.test_Ids = [] 52 | self.raw_Ids = [] 53 | self.use_bigram = True 54 | self.word_emb_dim = 100 55 | self.biword_emb_dim = 50 56 | self.char_emb_dim = 30 57 | self.gaz_emb_dim = 50 58 | self.gaz_dropout = 0.5 59 | self.pretrain_word_embedding = None 60 | self.pretrain_biword_embedding = None 61 | self.pretrain_gaz_embedding = None 62 | self.label_size = 0 63 | self.word_alphabet_size = 0 64 | self.biword_alphabet_size = 0 65 | self.char_alphabet_size = 0 66 | self.label_alphabet_size = 0 67 | ### hyperparameters 68 | self.HP_iteration = 100 69 | self.HP_batch_size = 10 70 | self.HP_char_hidden_dim = 100 71 | self.HP_hidden_dim = 100 72 | self.HP_dropout = 0.5 73 | self.HP_lstm_layer = 1 74 | self.HP_bilstm = True 
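# --- Illustrative note (added by the editor, not part of the original file) ---
# The HP_* fields around this point are only default hyperparameters. The driver
# scripts override them on the Data instance before training; for example,
# note4.py (shown later in this repository) does roughly the following, where
# `gpu` is torch.cuda.is_available():
#
#     data = Data()
#     data.HP_gpu = gpu
#     data.HP_use_char = False
#     data.HP_batch_size = 1
#     data.use_bigram = False
#     data.gaz_dropout = 0.5
#     data.norm_gaz_emb = False
#     data.HP_fix_gaz_emb = False
# -------------------------------------------------------------------------------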
75 | self.HP_use_char = False 76 | self.HP_gpu = False 77 | self.HP_lr = 0.015 78 | self.HP_lr_decay = 0.05 79 | self.HP_clip = 5.0 80 | self.HP_momentum = 0 81 | 82 | 83 | def show_data_summary(self): 84 | print("DATA SUMMARY START:") 85 | print(" Tag scheme: %s"%(self.tagScheme)) 86 | print(" MAX SENTENCE LENGTH: %s"%(self.MAX_SENTENCE_LENGTH)) 87 | print(" MAX WORD LENGTH: %s"%(self.MAX_WORD_LENGTH)) 88 | print(" Number normalized: %s"%(self.number_normalized)) 89 | print(" Use bigram: %s"%(self.use_bigram)) 90 | print(" Word alphabet size: %s"%(self.word_alphabet_size)) 91 | print(" Biword alphabet size: %s"%(self.biword_alphabet_size)) 92 | print(" Char alphabet size: %s"%(self.char_alphabet_size)) 93 | print(" Gaz alphabet size: %s"%(self.gaz_alphabet.size())) 94 | print(" Label alphabet size: %s"%(self.label_alphabet_size)) 95 | print(" Word embedding size: %s"%(self.word_emb_dim)) 96 | print(" Biword embedding size: %s"%(self.biword_emb_dim)) 97 | print(" Char embedding size: %s"%(self.char_emb_dim)) 98 | print(" Gaz embedding size: %s"%(self.gaz_emb_dim)) 99 | print(" Norm word emb: %s"%(self.norm_word_emb)) 100 | print(" Norm biword emb: %s"%(self.norm_biword_emb)) 101 | print(" Norm gaz emb: %s"%(self.norm_gaz_emb)) 102 | print(" Norm gaz dropout: %s"%(self.gaz_dropout)) 103 | print(" Train instance number: %s"%(len(self.train_texts))) 104 | print(" Dev instance number: %s"%(len(self.dev_texts))) 105 | print(" Test instance number: %s"%(len(self.test_texts))) 106 | print(" Raw instance number: %s"%(len(self.raw_texts))) 107 | print(" Hyperpara iteration: %s"%(self.HP_iteration)) 108 | print(" Hyperpara batch size: %s"%(self.HP_batch_size)) 109 | print(" Hyperpara lr: %s"%(self.HP_lr)) 110 | print(" Hyperpara lr_decay: %s"%(self.HP_lr_decay)) 111 | print(" Hyperpara HP_clip: %s"%(self.HP_clip)) 112 | print(" Hyperpara momentum: %s"%(self.HP_momentum)) 113 | print(" Hyperpara hidden_dim: %s"%(self.HP_hidden_dim)) 114 | print(" Hyperpara dropout: %s"%(self.HP_dropout)) 115 | print(" Hyperpara lstm_layer: %s"%(self.HP_lstm_layer)) 116 | print(" Hyperpara bilstm: %s"%(self.HP_bilstm)) 117 | print(" Hyperpara GPU: %s"%(self.HP_gpu)) 118 | print(" Hyperpara use_gaz: %s"%(self.HP_use_gaz)) 119 | print(" Hyperpara fix gaz emb: %s"%(self.HP_fix_gaz_emb)) 120 | print(" Hyperpara use_char: %s"%(self.HP_use_char)) 121 | if self.HP_use_char: 122 | print(" Char_features: %s"%(self.char_features)) 123 | print("DATA SUMMARY END.") 124 | sys.stdout.flush() 125 | 126 | def refresh_label_alphabet(self, input_file): 127 | old_size = self.label_alphabet_size 128 | self.label_alphabet.clear(True) 129 | in_lines = open(input_file,'r', encoding="utf-8").readlines() 130 | for line in in_lines: 131 | if len(line) > 2: 132 | pairs = line.strip().split() 133 | label = pairs[-1] 134 | self.label_alphabet.add(label) 135 | self.label_alphabet_size = self.label_alphabet.size() 136 | startS = False 137 | startB = False 138 | for label,_ in self.label_alphabet.iteritems(): 139 | if "S-" in label.upper(): 140 | startS = True 141 | elif "B-" in label.upper(): 142 | startB = True 143 | if startB: 144 | if startS: 145 | self.tagScheme = "BMES" 146 | else: 147 | self.tagScheme = "BIO" 148 | self.fix_alphabet() 149 | print("Refresh label alphabet finished: old:%s -> new:%s"%(old_size, self.label_alphabet_size)) 150 | 151 | 152 | 153 | def build_alphabet(self, input_file): 154 | in_lines = open(input_file, 'r', encoding="utf-8").readlines() 155 | for idx in range(len(in_lines)): 156 | line = in_lines[idx] 157 | if 
len(line) > 2: 158 | pairs = line.strip().split() 159 | word = pairs[0] 160 | if self.number_normalized: 161 | word = normalize_word(word) 162 | label = pairs[-1] 163 | self.label_alphabet.add(label) 164 | self.word_alphabet.add(word) 165 | if idx < len(in_lines) - 1 and len(in_lines[idx+1]) > 2: 166 | biword = word + in_lines[idx+1].strip().split()[0] 167 | else: 168 | biword = word + NULLKEY 169 | self.biword_alphabet.add(biword) 170 | for char in word: 171 | self.char_alphabet.add(char) 172 | self.word_alphabet_size = self.word_alphabet.size() 173 | self.biword_alphabet_size = self.biword_alphabet.size() 174 | self.char_alphabet_size = self.char_alphabet.size() 175 | self.label_alphabet_size = self.label_alphabet.size() 176 | startS = False 177 | startB = False 178 | for label,_ in self.label_alphabet.iteritems(): 179 | if "S-" in label.upper(): 180 | startS = True 181 | elif "B-" in label.upper(): 182 | startB = True 183 | if startB: 184 | if startS: 185 | self.tagScheme = "BMES" 186 | else: 187 | self.tagScheme = "BIO" 188 | 189 | 190 | def build_gaz_file(self, gaz_file): 191 | ## build gaz file,initial read gaz embedding file 192 | ## we only get the word, do not read embedding this step 193 | if gaz_file: 194 | fins = open(gaz_file, 'r', encoding="utf-8").readlines() 195 | for fin in fins: 196 | fin = fin.strip().split()[0] 197 | if fin: 198 | self.gaz.insert(fin, "one_source") 199 | print("Load gaz file: ", gaz_file, " total size:", self.gaz.size()) 200 | else: 201 | print("Gaz file is None, load nothing") 202 | 203 | 204 | def build_gaz_alphabet(self, input_file): 205 | """ 206 | based on the train, dev, test file, we only save the seb-sequence word that my be appear 207 | """ 208 | in_lines = open(input_file,'r', encoding="utf-8").readlines() 209 | word_list = [] 210 | for line in in_lines: 211 | if len(line) > 3: 212 | word = line.split()[0] 213 | if self.number_normalized: 214 | word = normalize_word(word) 215 | word_list.append(word) 216 | else: 217 | w_length = len(word_list) 218 | ## Travser from [0: n], [1: n] to [n-1: n] 219 | for idx in range(w_length): 220 | matched_entity = self.gaz.enumerateMatchList(word_list[idx:]) 221 | for entity in matched_entity: 222 | # print entity, self.gaz.searchId(entity),self.gaz.searchType(entity) 223 | self.gaz_alphabet.add(entity) 224 | word_list = [] 225 | print("gaz alphabet size:", self.gaz_alphabet.size()) 226 | 227 | 228 | def fix_alphabet(self): 229 | self.word_alphabet.close() 230 | self.biword_alphabet.close() 231 | self.char_alphabet.close() 232 | self.label_alphabet.close() 233 | self.gaz_alphabet.close() 234 | 235 | 236 | def build_word_pretrain_emb(self, emb_path): 237 | print("build word pretrain emb...") 238 | self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(emb_path, self.word_alphabet, self.word_emb_dim, self.norm_word_emb) 239 | 240 | def build_biword_pretrain_emb(self, emb_path): 241 | print("build biword pretrain emb...") 242 | self.pretrain_biword_embedding, self.biword_emb_dim = build_pretrain_embedding(emb_path, self.biword_alphabet, self.biword_emb_dim, self.norm_biword_emb) 243 | 244 | def build_gaz_pretrain_emb(self, emb_path): 245 | print("build gaz pretrain emb...") 246 | self.pretrain_gaz_embedding, self.gaz_emb_dim = build_pretrain_embedding(emb_path, self.gaz_alphabet, self.gaz_emb_dim, self.norm_gaz_emb) 247 | 248 | 249 | def generate_word_instance(self, input_file, name): 250 | """ 251 | every instance include: words, labels, word_ids, label_ids 252 | """ 253 | self.fix_alphabet() 
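# --- Illustrative note (added by the editor, not part of the original file) ---
# fix_alphabet() closes all alphabets, so reading dev/test/raw data below cannot
# grow them further; tokens unseen during build_alphabet() are presumably mapped
# to the alphabet's unknown entry (assumed behaviour of ner/utils/alphabet.Alphabet,
# which is not reproduced in this dump). A minimal end-to-end sketch, with file
# paths assumed from the data/ tree and mirroring data_initialization() in the
# driver scripts shown later:
#
#     files = ("data/weibo/train.char.bmes",
#              "data/weibo/dev.char.bmes",
#              "data/weibo/test.char.bmes")
#     data = Data()
#     for f in files:
#         data.build_alphabet(f)
#     data.build_gaz_file("data/ctb.50d.vec")
#     for f in files:
#         data.build_gaz_alphabet(f)
#     data.fix_alphabet()
#     data.generate_word_instance(files[0], "train")
#     # data.train_texts / data.train_Ids now hold the surface strings and their
#     # integer ids; the exact tuple layout comes from read_word_instance, which
#     # is presumably defined in ner/utils/functions.py (wildcard-imported above)
#     # and is not shown in this file.
# -------------------------------------------------------------------------------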
254 | if name == "train": 255 | self.train_texts, self.train_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 256 | self.number_normalized, self.MAX_SENTENCE_LENGTH) 257 | elif name == "dev": 258 | self.dev_texts, self.dev_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 259 | self.number_normalized, self.MAX_SENTENCE_LENGTH) 260 | elif name == "test": 261 | self.test_texts, self.test_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 262 | self.number_normalized, self.MAX_SENTENCE_LENGTH) 263 | elif name == "raw": 264 | self.raw_texts, self.raw_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet, 265 | self.number_normalized, self.MAX_SENTENCE_LENGTH) 266 | else: 267 | print("Error: you can only generate train/dev/test instance! Illegal input:%s" % (name)) 268 | 269 | 270 | def generate_instance(self, input_file, name): 271 | """ 272 | every instance include: words, biwords, chars, labels, 273 | word_ids, biword_ids, char_ids, label_ids 274 | """ 275 | self.fix_alphabet() 276 | if name == "train": 277 | self.train_texts, self.train_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 278 | elif name == "dev": 279 | self.dev_texts, self.dev_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 280 | elif name == "test": 281 | self.test_texts, self.test_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 282 | elif name == "raw": 283 | self.raw_texts, self.raw_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 284 | else: 285 | print("Error: you can only generate train/dev/test instance! Illegal input:%s"%(name)) 286 | 287 | 288 | def generate_instance_with_gaz(self, input_file, name): 289 | """ 290 | every instance include: words, biwords, chars, labels, gazs, 291 | word_ids, biword_ids, char_ids, label_ids, gaz_ids 292 | """ 293 | self.fix_alphabet() 294 | if name == "train": 295 | self.train_texts, self.train_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 296 | elif name == "dev": 297 | self.dev_texts, self.dev_Ids = read_instance_with_gaz(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 298 | elif name == "test": 299 | self.test_texts, self.test_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 300 | elif name == "raw": 301 | self.raw_texts, self.raw_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) 302 | else: 303 | print("Error: you can only generate train/dev/test instance! 
Illegal input:%s"%(name)) 304 | 305 | def generate_instance_with_gaz_no_char(self, input_file, name): 306 | """ 307 | every instance include: 308 | words, biwords, gazs, labels 309 | word_Ids, biword_Ids, gazs_Ids, label_Ids 310 | """ 311 | self.fix_alphabet() 312 | if name == "train": 313 | self.train_texts, self.train_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 314 | elif name == "dev": 315 | self.dev_texts, self.dev_Ids = read_instance_with_gaz_no_char(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 316 | elif name == "test": 317 | self.test_texts, self.test_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 318 | elif name == "raw": 319 | self.raw_texts, self.raw_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single) 320 | else: 321 | print("Error: you can only generate train/dev/test instance! Illegal input:%s"%(name)) 322 | 323 | 324 | def write_decoded_results(self, output_file, predict_results, name): 325 | fout = open(output_file,'w', encoding="utf-8") 326 | sent_num = len(predict_results) 327 | content_list = [] 328 | if name == 'raw': 329 | content_list = self.raw_texts 330 | elif name == 'test': 331 | content_list = self.test_texts 332 | elif name == 'dev': 333 | content_list = self.dev_texts 334 | elif name == 'train': 335 | content_list = self.train_texts 336 | else: 337 | print("Error: illegal name during writing predict result, name should be within train/dev/test/raw !") 338 | assert(sent_num == len(content_list)) 339 | for idx in range(sent_num): 340 | sent_length = len(predict_results[idx]) 341 | for idy in range(sent_length): 342 | ## content_list[idx] is a list with [word, char, label] 343 | fout.write(content_list[idx][0][idy].encode('utf-8') + " " + predict_results[idx][idy] + '\n') 344 | 345 | fout.write('\n') 346 | fout.close() 347 | print("Predict %s result has been written into file. 
%s"%(name, output_file)) 348 | 349 | 350 | 351 | 352 | 353 | 354 | 355 | 356 | -------------------------------------------------------------------------------- /note4.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = "liuwei" 3 | 4 | 5 | import time 6 | import sys 7 | import argparse 8 | import random 9 | import copy 10 | import torch 11 | import gc 12 | import pickle 13 | import torch.autograd as autograd 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.optim as optim 17 | import numpy as np 18 | import logging 19 | import re 20 | import os 21 | 22 | from ner.utils.metric import get_ner_fmeasure 23 | from ner.lw.cw_ner import CW_NER as SeqModel 24 | from ner.utils.note4_data import Data 25 | from tensorboardX import SummaryWriter 26 | from ner.lw.tbx_writer import TensorboardWriter 27 | 28 | 29 | seed_num = 100 30 | random.seed(seed_num) 31 | torch.manual_seed(seed_num) 32 | np.random.seed(seed_num) 33 | 34 | # set logger 35 | logger = logging.getLogger(__name__) 36 | logger.setLevel(logging.INFO) 37 | BASIC_FORMAT = "%(asctime)s:%(levelname)s: %(message)s" 38 | DATE_FORMAT = '%Y-%m-%d %H:%M:%S' 39 | formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT) 40 | chlr = logging.StreamHandler() # 输出到控制台的handler 41 | chlr.setFormatter(formatter) 42 | logger.addHandler(chlr) 43 | 44 | 45 | # for checkpoint 46 | max_model_num = 100 47 | old_model_paths = [] 48 | 49 | # tensorboard writer 50 | log_dir = "data/log" 51 | 52 | train_log = SummaryWriter(os.path.join(log_dir, "train")) 53 | validation_log = SummaryWriter(os.path.join(log_dir, "validation")) 54 | test_log = SummaryWriter(os.path.join(log_dir, "test")) 55 | tensorboard = TensorboardWriter(train_log, validation_log, test_log) 56 | 57 | 58 | def data_initialization(data, gaz_file, train_file, dev_file, test_file): 59 | data.build_alphabet(train_file) 60 | data.build_alphabet(dev_file) 61 | data.build_alphabet(test_file) 62 | data.build_gaz_file(gaz_file) 63 | data.build_gaz_alphabet(train_file) 64 | data.build_gaz_alphabet(dev_file) 65 | data.build_gaz_alphabet(test_file) 66 | data.fix_alphabet() 67 | return data 68 | 69 | 70 | def predict_check(pred_variable, gold_variable, mask_variable): 71 | """ 72 | input: 73 | pred_variable (batch_size, sent_len): pred tag result, in numpy format 74 | gold_variable (batch_size, sent_len): gold result variable 75 | mask_variable (batch_size, sent_len): mask variable 76 | """ 77 | pred = pred_variable.cpu().data.numpy() 78 | gold = gold_variable.cpu().data.numpy() 79 | mask = mask_variable.cpu().data.numpy() 80 | overlaped = (pred == gold) 81 | right_token = np.sum(overlaped * mask) 82 | total_token = mask.sum() 83 | # print("right: %s, total: %s"%(right_token, total_token)) 84 | return right_token, total_token 85 | 86 | 87 | def recover_label(pred_variable, gold_variable, mask_variable, label_alphabet, word_recover): 88 | """ 89 | input: 90 | pred_variable (batch_size, sent_len): pred tag result 91 | gold_variable (batch_size, sent_len): gold result variable 92 | mask_variable (batch_size, sent_len): mask variable 93 | """ 94 | 95 | pred_variable = pred_variable[word_recover] 96 | gold_variable = gold_variable[word_recover] 97 | mask_variable = mask_variable[word_recover] 98 | batch_size = gold_variable.size(0) 99 | seq_len = gold_variable.size(1) 100 | mask = mask_variable.cpu().data.numpy() 101 | pred_tag = pred_variable.cpu().data.numpy() 102 | gold_tag = 
gold_variable.cpu().data.numpy() 103 | batch_size = mask.shape[0] 104 | pred_label = [] 105 | gold_label = [] 106 | for idx in range(batch_size): 107 | pred = [label_alphabet.get_instance(pred_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0] 108 | gold = [label_alphabet.get_instance(gold_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0] 109 | # print "p:",pred, pred_tag.tolist() 110 | # print "g:", gold, gold_tag.tolist() 111 | assert (len(pred) == len(gold)) 112 | pred_label.append(pred) 113 | gold_label.append(gold) 114 | return pred_label, gold_label 115 | 116 | 117 | def lr_decay(optimizer, epoch, decay_rate, init_lr): 118 | lr = init_lr * ((1 - decay_rate) ** epoch) 119 | print(" Learning rate is setted as:", lr) 120 | for param_group in optimizer.param_groups: 121 | param_group['lr'] = lr 122 | return optimizer 123 | 124 | 125 | def save_model(epoch, state, models_dir): 126 | """ 127 | save the model state 128 | Args: 129 | epoch: the number of epoch 130 | state: [model_state, training_state] 131 | models_dir: the dir to save model 132 | """ 133 | if models_dir is not None: 134 | model_path = os.path.join(models_dir, "model_state_epoch_{}.th".format(epoch)) 135 | train_path = os.path.join(models_dir, "training_state_epoch_{}.th".format(epoch)) 136 | 137 | model_state, training_state = state 138 | torch.save(model_state, model_path) 139 | torch.save(training_state, train_path) 140 | 141 | if max_model_num > 0: 142 | old_model_paths.append([model_path, train_path]) 143 | if len(old_model_paths) > max_model_num: 144 | paths_to_remove = old_model_paths.pop(0) 145 | 146 | for fname in paths_to_remove: 147 | os.remove(fname) 148 | 149 | def find_last_model(models_dir): 150 | """ 151 | find the lastes checkpoint file 152 | Args: 153 | models_dir: the dir save models 154 | """ 155 | epoch_num = 0 156 | # return models_dir + "/model_state_epoch_{}.th".format(epoch_num), models_dir + "/training_state_epoch_{}.th".format(epoch_num) 157 | 158 | if models_dir is None: 159 | return None 160 | 161 | saved_models_path = os.listdir(models_dir) 162 | saved_models_path = [x for x in saved_models_path if 'model_state_epoch' in x] 163 | 164 | if len(saved_models_path) == 0: 165 | return None 166 | 167 | found_epochs = [ 168 | re.search("model_state_epoch_([0-9]+).th", x).group(1) 169 | for x in saved_models_path 170 | ] 171 | int_epochs = [int(epoch) for epoch in found_epochs] 172 | print("len: ", len(int_epochs)) 173 | last_epoch = sorted(int_epochs, reverse=True)[0] 174 | epoch_to_load = "{}".format(last_epoch) 175 | 176 | model_path = os.path.join(models_dir, "model_state_epoch_{}.th".format(epoch_to_load)) 177 | training_state_path = os.path.join(models_dir, "training_state_epoch_{}.th".format(epoch_to_load)) 178 | 179 | return model_path, training_state_path 180 | 181 | 182 | def restore_model(models_dir): 183 | """ 184 | restore the lastes checkpoint file 185 | """ 186 | lastest_checkpoint = find_last_model(models_dir) 187 | 188 | if lastest_checkpoint is None: 189 | return None 190 | else: 191 | model_path, training_state_path = lastest_checkpoint 192 | 193 | model_state = torch.load(model_path) 194 | training_state = torch.load(training_state_path) 195 | 196 | return (model_state, training_state) 197 | 198 | 199 | def evaluate(data, model, name): 200 | if name == "train": 201 | instances = data.train_Ids 202 | elif name == "dev": 203 | instances = data.dev_Ids 204 | elif name == 'test': 205 | instances = data.test_Ids 206 | elif name == 'raw': 207 | instances = 
data.raw_Ids 208 | else: 209 | print("Error: wrong evaluate name,", name) 210 | right_token = 0 211 | whole_token = 0 212 | pred_results = [] 213 | gold_results = [] 214 | ## set model in eval model 215 | model.eval() 216 | batch_size = 8 217 | start_time = time.time() 218 | train_num = len(instances) 219 | total_batch = train_num // batch_size + 1 220 | for batch_id in range(total_batch): 221 | start = batch_id * batch_size 222 | end = (batch_id + 1) * batch_size 223 | if end > train_num: 224 | end = train_num 225 | instance = instances[start:end] 226 | if not instance: 227 | continue 228 | gaz_list, reverse_gaz_list, batch_word, batch_biword, batch_wordlen, batch_wordrecover, batch_label, mask = batchify_with_label(instance, data.HP_gpu, True) 229 | tag_seq = model(gaz_list, reverse_gaz_list, batch_word, batch_wordlen, mask) 230 | pred_label, gold_label = recover_label(tag_seq, batch_label, mask, data.label_alphabet, batch_wordrecover) 231 | pred_results += pred_label 232 | gold_results += gold_label 233 | decode_time = time.time() - start_time 234 | speed = len(instances) / decode_time 235 | acc, p, r, f = get_ner_fmeasure(gold_results, pred_results, data.tagScheme) 236 | return speed, acc, p, r, f, pred_results 237 | 238 | 239 | def batchify_with_label(input_batch_list, gpu, volatile_flag=False): 240 | """ 241 | input: list of words, chars and labels, various length. [[words,biwords,chars,gaz, labels],[words,biwords,chars,labels],...] 242 | words: word ids for one sentence. (batch_size, sent_len) 243 | chars: char ids for on sentences, various length. (batch_size, sent_len, each_word_length) 244 | output: 245 | zero padding for word and char, with their batch length 246 | word_seq_tensor: (batch_size, max_sent_len) Variable 247 | word_seq_lengths: (batch_size,1) Tensor 248 | char_seq_recover: (batch_size*max_sent_len,1) recover char sequence order 249 | label_seq_tensor: (batch_size, max_sent_len) 250 | mask: (batch_size, max_sent_len) 251 | """ 252 | batch_size = len(input_batch_list) 253 | words = [sent[0] for sent in input_batch_list] 254 | biwords = [sent[1] for sent in input_batch_list] 255 | gazs = [sent[2] for sent in input_batch_list] 256 | reverse_gazs = [sent[3] for sent in input_batch_list] 257 | labels = [sent[4] for sent in input_batch_list] 258 | word_seq_lengths = torch.LongTensor(list(map(len, words))) 259 | max_seq_len = word_seq_lengths.max() 260 | word_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long() 261 | biword_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long() 262 | label_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long() 263 | mask = autograd.Variable(torch.zeros((batch_size, max_seq_len))).byte() 264 | for idx, (seq, biseq, label, seqlen) in enumerate(zip(words, biwords, labels, word_seq_lengths)): 265 | word_seq_tensor[idx, :seqlen] = torch.LongTensor(seq) 266 | biword_seq_tensor[idx, :seqlen] = torch.LongTensor(biseq) 267 | label_seq_tensor[idx, :seqlen] = torch.LongTensor(label) 268 | mask[idx, :seqlen] = torch.Tensor([1] * int(seqlen)) 269 | word_seq_lengths, word_perm_idx = word_seq_lengths.sort(0, descending=True) 270 | word_seq_tensor = word_seq_tensor[word_perm_idx] 271 | biword_seq_tensor = biword_seq_tensor[word_perm_idx] 272 | label_seq_tensor = label_seq_tensor[word_perm_idx] 273 | mask = mask[word_perm_idx] 274 | 275 | _, word_seq_recover = word_perm_idx.sort(0, descending=False) 276 | 277 | gaz_list = [gazs[i] for i in word_perm_idx] 278 | reverse_gaz_list = 
[reverse_gazs[i] for i in word_perm_idx] 279 | 280 | if gpu: 281 | word_seq_tensor = word_seq_tensor.cuda() 282 | biword_seq_tensor = biword_seq_tensor.cuda() 283 | word_seq_lengths = word_seq_lengths.cuda() 284 | word_seq_recover = word_seq_recover.cuda() 285 | label_seq_tensor = label_seq_tensor.cuda() 286 | mask = mask.cuda() 287 | 288 | # gaz_seq_tensor = gaz_seq_tensor.cuda() 289 | # gaz_seq_length = gaz_seq_length.cuda() 290 | # gaz_mask_tensor = gaz_mask_tensor.cuda() 291 | 292 | return gaz_list, reverse_gaz_list, word_seq_tensor, biword_seq_tensor, word_seq_lengths, word_seq_recover, label_seq_tensor, mask 293 | 294 | 295 | def train(data, save_model_dir, seg=True): 296 | print("Training model...") 297 | data.show_data_summary() 298 | model = SeqModel(data, type=5) 299 | print("finished built model.") 300 | loss_function = nn.NLLLoss() 301 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 302 | optimizer = optim.SGD(parameters, lr=data.HP_lr, momentum=data.HP_momentum) 303 | best_dev = -1 304 | data.HP_iteration = 75 305 | 306 | ## here we should restore the model 307 | state = restore_model(save_model_dir) 308 | epoch = 0 309 | if state is not None: 310 | model_state = state[0] 311 | training_state = state[1] 312 | 313 | model.load_state_dict(model_state) 314 | optimizer.load_state_dict(training_state['optimizer']) 315 | epoch = int(training_state['epoch']) 316 | 317 | batch_size = 8 318 | now_batch_num = epoch * len(data.train_Ids) // batch_size 319 | 320 | ## start training 321 | for idx in range(epoch, data.HP_iteration): 322 | epoch_start = time.time() 323 | temp_start = epoch_start 324 | print("Epoch: %s/%s" % (idx, data.HP_iteration)) 325 | optimizer = lr_decay(optimizer, idx, data.HP_lr_decay, data.HP_lr) 326 | instance_count = 0 327 | sample_id = 0 328 | sample_loss = 0 329 | batch_loss = 0 330 | total_loss = 0 331 | right_token = 0 332 | whole_token = 0 333 | random.shuffle(data.train_Ids) 334 | ## set model in train model 335 | model.train() 336 | model.zero_grad() 337 | batch_size = 8 338 | batch_id = 0 339 | train_num = len(data.train_Ids) 340 | total_batch = train_num // batch_size + 1 341 | for batch_id in range(total_batch): 342 | start = batch_id * batch_size 343 | end = (batch_id + 1) * batch_size 344 | if end > train_num: 345 | end = train_num 346 | instance = data.train_Ids[start:end] 347 | if not instance: 348 | continue 349 | gaz_list, reverse_gaz_list, batch_word, batch_biword, batch_wordlen, batch_wordrecover, batch_label, mask = batchify_with_label(instance, data.HP_gpu) 350 | # print "gaz_list:",gaz_list 351 | # exit(0) 352 | instance_count += (end - start) 353 | loss, tag_seq = model.neg_log_likelihood_loss(gaz_list, reverse_gaz_list, batch_word, batch_wordlen, batch_label, mask) 354 | right, whole = predict_check(tag_seq, batch_label, mask) 355 | right_token += right 356 | whole_token += whole 357 | sample_loss += loss.item() 358 | total_loss += loss.item() 359 | batch_loss += loss 360 | 361 | if end % 4000 == 0: 362 | temp_time = time.time() 363 | temp_cost = temp_time - temp_start 364 | temp_start = temp_time 365 | print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % (end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) 
/ whole_token)) 366 | sys.stdout.flush() 367 | sample_loss = 0 368 | if end % data.HP_batch_size == 0: 369 | batch_loss.backward() 370 | optimizer.step() 371 | model.zero_grad() 372 | batch_loss = 0 373 | 374 | tensorboard.add_train_scalar("train_loss", total_loss / (batch_id + 1), now_batch_num) 375 | tensorboard.add_train_scalar("train_accu", right_token / whole_token, now_batch_num) 376 | now_batch_num += 1 377 | 378 | temp_time = time.time() 379 | temp_cost = temp_time - temp_start 380 | print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % (end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token)) 381 | epoch_finish = time.time() 382 | epoch_cost = epoch_finish - epoch_start 383 | print("Epoch: %s training finished. Time: %.2fs, speed: %.2fst/s, total loss: %s" % ( 384 | idx, epoch_cost, train_num / epoch_cost, total_loss)) 385 | # exit(0) 386 | # continue 387 | speed, acc, p, r, f, _ = evaluate(data, model, "dev") 388 | dev_finish = time.time() 389 | dev_cost = dev_finish - epoch_finish 390 | 391 | if seg: 392 | current_score = f 393 | print("Dev: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (dev_cost, speed, acc, p, r, f)) 394 | 395 | tensorboard.add_validation_scalar("dev_accu", acc, idx) 396 | tensorboard.add_validation_scalar("dev_p", p, idx) 397 | tensorboard.add_validation_scalar("dev_r", r, idx) 398 | tensorboard.add_validation_scalar("dev_f", f, idx) 399 | 400 | else: 401 | current_score = acc 402 | print("Dev: time: %.2fs speed: %.2fst/s; acc: %.4f" % (dev_cost, speed, acc)) 403 | 404 | if current_score > best_dev: 405 | if seg: 406 | print("Exceed previous best f score:", best_dev) 407 | else: 408 | print("Exceed previous best acc score:", best_dev) 409 | # model_name = save_model_dir + '.' 
+ str(idx) + ".model" 410 | # torch.save(model.state_dict(), model_name) 411 | best_dev = current_score 412 | # ## decode test 413 | speed, acc, p, r, f, _ = evaluate(data, model, "test") 414 | test_finish = time.time() 415 | test_cost = test_finish - dev_finish 416 | if seg: 417 | print("Test: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (test_cost, speed, acc, p, r, f)) 418 | 419 | tensorboard.add_test_scalar("test_accu", acc, idx) 420 | tensorboard.add_test_scalar("test_p", p, idx) 421 | tensorboard.add_test_scalar("test_r", r, idx) 422 | tensorboard.add_test_scalar("test_f", f, idx) 423 | 424 | else: 425 | print("Test: time: %.2fs, speed: %.2fst/s; acc: %.4f" % (test_cost, speed, acc)) 426 | gc.collect() 427 | 428 | ## save the model every epoch 429 | model_state = model.state_dict() 430 | training_state = { 431 | "epoch": idx, 432 | "optimizer": optimizer.state_dict() 433 | } 434 | save_model(idx, (model_state, training_state), save_model_dir) 435 | 436 | 437 | if __name__ == '__main__': 438 | parser = argparse.ArgumentParser(description='Tuning with bi-directional LSTM-CRF') 439 | parser.add_argument('--embedding', help='Embedding for words', default='None') 440 | parser.add_argument('--status', choices=['train', 'test', 'decode'], help='update algorithm', default='train') 441 | parser.add_argument('--savemodel', default="data/model/note4") 442 | parser.add_argument('--savedset', help='Dir of saved data setting', default="data/save.dset") 443 | parser.add_argument('--train', default="data/note4/train.char.bmes") 444 | parser.add_argument('--dev', default="data/note4/dev.char.bmes") 445 | parser.add_argument('--test', default="data/note4/test.char.bmes") 446 | parser.add_argument('--seg', default="True") 447 | parser.add_argument('--extendalphabet', default="True") 448 | parser.add_argument('--raw') 449 | parser.add_argument('--loadmodel') 450 | parser.add_argument('--output') 451 | args = parser.parse_args() 452 | 453 | train_file = args.train 454 | dev_file = args.dev 455 | test_file = args.test 456 | raw_file = args.raw 457 | model_dir = args.loadmodel 458 | dset_dir = args.savedset 459 | output_file = args.output 460 | if args.seg.lower() == "true": 461 | seg = True 462 | else: 463 | seg = False 464 | status = args.status.lower() 465 | 466 | save_model_dir = args.savemodel 467 | gpu = torch.cuda.is_available() 468 | 469 | char_emb = "data/gigaword_chn.all.a2b.uni.ite50.vec" 470 | bichar_emb = None 471 | gaz_file = "data/ctb.50d.vec" 472 | # gaz_file = None 473 | # char_emb = None 474 | # bichar_emb = None 475 | 476 | print("CuDNN:", torch.backends.cudnn.enabled) 477 | # gpu = False 478 | print("GPU available:", gpu) 479 | print("Status:", status) 480 | print("Seg: ", seg) 481 | print("Train file:", train_file) 482 | print("Dev file:", dev_file) 483 | print("Test file:", test_file) 484 | print("Raw file:", raw_file) 485 | print("Char emb:", char_emb) 486 | print("Bichar emb:", bichar_emb) 487 | print("Gaz file:", gaz_file) 488 | if status == 'train': 489 | print("Model saved to:", save_model_dir) 490 | sys.stdout.flush() 491 | 492 | if status == 'train': 493 | data = Data() 494 | data.HP_gpu = gpu 495 | data.HP_use_char = False 496 | data.HP_batch_size = 1 497 | data.use_bigram = False 498 | data.gaz_dropout = 0.5 499 | data.norm_gaz_emb = False 500 | data.HP_fix_gaz_emb = False 501 | data_initialization(data, gaz_file, train_file, dev_file, test_file) 502 | data.generate_instance_with_gaz_no_char(train_file, 'train') 503 | 
data.generate_instance_with_gaz_no_char(dev_file, 'dev') 504 | data.generate_instance_with_gaz_no_char(test_file, 'test') 505 | 506 | data.build_word_pretrain_emb(char_emb) 507 | data.build_biword_pretrain_emb(bichar_emb) 508 | data.build_gaz_pretrain_emb(gaz_file) 509 | 510 | train(data, save_model_dir, seg) 511 | 512 | 513 | 514 | -------------------------------------------------------------------------------- /resume.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | __author__ = "liuwei" 3 | 4 | import time 5 | import sys 6 | import argparse 7 | import random 8 | import copy 9 | import torch 10 | import gc 11 | import pickle 12 | import torch.autograd as autograd 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | import numpy as np 17 | import logging 18 | import re 19 | import os 20 | 21 | from ner.utils.metric import get_ner_fmeasure 22 | from ner.lw.cw_ner import CW_NER as SeqModel 23 | from ner.utils.resume_data import Data 24 | from tensorboardX import SummaryWriter 25 | from ner.lw.tbx_writer import TensorboardWriter 26 | 27 | 28 | seed_num = 100 29 | random.seed(seed_num) 30 | torch.manual_seed(seed_num) 31 | np.random.seed(seed_num) 32 | 33 | # set logger 34 | logger = logging.getLogger(__name__) 35 | logger.setLevel(logging.INFO) 36 | BASIC_FORMAT = "%(asctime)s:%(levelname)s: %(message)s" 37 | DATE_FORMAT = '%Y-%m-%d %H:%M:%S' 38 | formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT) 39 | chlr = logging.StreamHandler() # handler that writes log records to the console 40 | chlr.setFormatter(formatter) 41 | logger.addHandler(chlr) 42 | 43 | 44 | # for checkpoint 45 | max_model_num = 100 46 | old_model_paths = [] 47 | 48 | # tensorboard writer 49 | log_dir = "data/log" 50 | 51 | train_log = SummaryWriter(os.path.join(log_dir, "train")) 52 | validation_log = SummaryWriter(os.path.join(log_dir, "validation")) 53 | test_log = SummaryWriter(os.path.join(log_dir, "test")) 54 | tensorboard = TensorboardWriter(train_log, validation_log, test_log) 55 | 56 | 57 | def data_initialization(data, gaz_file, train_file, dev_file, test_file): 58 | data.build_alphabet(train_file) 59 | data.build_alphabet(dev_file) 60 | data.build_alphabet(test_file) 61 | data.build_gaz_file(gaz_file) 62 | data.build_gaz_alphabet(train_file) 63 | data.build_gaz_alphabet(dev_file) 64 | data.build_gaz_alphabet(test_file) 65 | data.fix_alphabet() 66 | return data 67 | 68 | 69 | def predict_check(pred_variable, gold_variable, mask_variable): 70 | """ 71 | input: 72 | pred_variable (batch_size, sent_len): pred tag result, in numpy format 73 | gold_variable (batch_size, sent_len): gold result variable 74 | mask_variable (batch_size, sent_len): mask variable 75 | """ 76 | pred = pred_variable.cpu().data.numpy() 77 | gold = gold_variable.cpu().data.numpy() 78 | mask = mask_variable.cpu().data.numpy() 79 | overlapped = (pred == gold) 80 | right_token = np.sum(overlapped * mask) 81 | total_token = mask.sum() 82 | # print("right: %s, total: %s"%(right_token, total_token)) 83 | return right_token, total_token 84 | 85 | 86 | def recover_label(pred_variable, gold_variable, mask_variable, label_alphabet, word_recover): 87 | """ 88 | input: 89 | pred_variable (batch_size, sent_len): pred tag result 90 | gold_variable (batch_size, sent_len): gold result variable 91 | mask_variable (batch_size, sent_len): mask variable 92 | """ 93 | 94 | pred_variable = pred_variable[word_recover] 95 | gold_variable = gold_variable[word_recover] 96 | 
mask_variable = mask_variable[word_recover] 97 | batch_size = gold_variable.size(0) 98 | seq_len = gold_variable.size(1) 99 | mask = mask_variable.cpu().data.numpy() 100 | pred_tag = pred_variable.cpu().data.numpy() 101 | gold_tag = gold_variable.cpu().data.numpy() 102 | batch_size = mask.shape[0] 103 | pred_label = [] 104 | gold_label = [] 105 | for idx in range(batch_size): 106 | pred = [label_alphabet.get_instance(pred_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0] 107 | gold = [label_alphabet.get_instance(gold_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0] 108 | # print "p:",pred, pred_tag.tolist() 109 | # print "g:", gold, gold_tag.tolist() 110 | assert (len(pred) == len(gold)) 111 | pred_label.append(pred) 112 | gold_label.append(gold) 113 | return pred_label, gold_label 114 | 115 | 116 | def lr_decay(optimizer, epoch, decay_rate, init_lr): 117 | lr = init_lr * ((1 - decay_rate) ** epoch) 118 | print(" Learning rate is set to:", lr) 119 | for param_group in optimizer.param_groups: 120 | param_group['lr'] = lr 121 | return optimizer 122 | 123 | 124 | def save_model(epoch, state, models_dir): 125 | """ 126 | save the model state 127 | Args: 128 | epoch: the epoch number 129 | state: [model_state, training_state] 130 | models_dir: the dir to save the model 131 | """ 132 | if models_dir is not None: 133 | model_path = os.path.join(models_dir, "model_state_epoch_{}.th".format(epoch)) 134 | train_path = os.path.join(models_dir, "training_state_epoch_{}.th".format(epoch)) 135 | 136 | model_state, training_state = state 137 | torch.save(model_state, model_path) 138 | torch.save(training_state, train_path) 139 | 140 | if max_model_num > 0: 141 | old_model_paths.append([model_path, train_path]) 142 | if len(old_model_paths) > max_model_num: 143 | paths_to_remove = old_model_paths.pop(0) 144 | 145 | for fname in paths_to_remove: 146 | os.remove(fname) 147 | 148 | def find_last_model(models_dir): 149 | """ 150 | find the latest checkpoint file 151 | Args: 152 | models_dir: the dir where models are saved 153 | """ 154 | epoch_num = 0 155 | # return models_dir + "/model_state_epoch_{}.th".format(epoch_num), models_dir + "/training_state_epoch_{}.th".format(epoch_num) 156 | 157 | if models_dir is None: 158 | return None 159 | 160 | saved_models_path = os.listdir(models_dir) 161 | saved_models_path = [x for x in saved_models_path if 'model_state_epoch' in x] 162 | 163 | if len(saved_models_path) == 0: 164 | return None 165 | 166 | found_epochs = [ 167 | re.search("model_state_epoch_([0-9]+).th", x).group(1) 168 | for x in saved_models_path 169 | ] 170 | int_epochs = [int(epoch) for epoch in found_epochs] 171 | print("len: ", len(int_epochs)) 172 | last_epoch = sorted(int_epochs, reverse=True)[0] 173 | epoch_to_load = "{}".format(last_epoch) 174 | 175 | model_path = os.path.join(models_dir, "model_state_epoch_{}.th".format(epoch_to_load)) 176 | training_state_path = os.path.join(models_dir, "training_state_epoch_{}.th".format(epoch_to_load)) 177 | 178 | return model_path, training_state_path 179 | 180 | 181 | def restore_model(models_dir): 182 | """ 183 | restore the latest checkpoint file 184 | """ 185 | latest_checkpoint = find_last_model(models_dir) 186 | 187 | if latest_checkpoint is None: 188 | return None 189 | else: 190 | model_path, training_state_path = latest_checkpoint 191 | 192 | model_state = torch.load(model_path) 193 | training_state = torch.load(training_state_path) 194 | 195 | return (model_state, training_state) 196 | 197 | 198 | def evaluate(data, 
model, name): 199 | if name == "train": 200 | instances = data.train_Ids 201 | elif name == "dev": 202 | instances = data.dev_Ids 203 | elif name == 'test': 204 | instances = data.test_Ids 205 | elif name == 'raw': 206 | instances = data.raw_Ids 207 | else: 208 | print("Error: wrong evaluate name,", name) 209 | right_token = 0 210 | whole_token = 0 211 | pred_results = [] 212 | gold_results = [] 213 | ## set model in eval mode 214 | model.eval() 215 | batch_size = 8 216 | start_time = time.time() 217 | train_num = len(instances) 218 | total_batch = train_num // batch_size + 1 219 | for batch_id in range(total_batch): 220 | start = batch_id * batch_size 221 | end = (batch_id + 1) * batch_size 222 | if end > train_num: 223 | end = train_num 224 | instance = instances[start:end] 225 | if not instance: 226 | continue 227 | gaz_list, reverse_gaz_list, batch_word, batch_biword, batch_wordlen, batch_wordrecover, batch_label, mask = batchify_with_label(instance, data.HP_gpu, True) 228 | tag_seq = model(gaz_list, reverse_gaz_list, batch_word, batch_wordlen, mask) 229 | pred_label, gold_label = recover_label(tag_seq, batch_label, mask, data.label_alphabet, batch_wordrecover) 230 | pred_results += pred_label 231 | gold_results += gold_label 232 | decode_time = time.time() - start_time 233 | speed = len(instances) / decode_time 234 | acc, p, r, f = get_ner_fmeasure(gold_results, pred_results, data.tagScheme) 235 | return speed, acc, p, r, f, pred_results 236 | 237 | 238 | def batchify_with_label(input_batch_list, gpu, volatile_flag=False): 239 | """ 240 | input: list of words, chars and labels, of various lengths. [[words,biwords,chars,gaz, labels],[words,biwords,chars,labels],...] 241 | words: word ids for one sentence. (batch_size, sent_len) 242 | chars: char ids for one sentence, of various lengths. 
(batch_size, sent_len, each_word_length) 243 | output: 244 | zero padding for word and char, with their batch length 245 | word_seq_tensor: (batch_size, max_sent_len) Variable 246 | word_seq_lengths: (batch_size,1) Tensor 247 | char_seq_recover: (batch_size*max_sent_len,1) recover char sequence order 248 | label_seq_tensor: (batch_size, max_sent_len) 249 | mask: (batch_size, max_sent_len) 250 | """ 251 | batch_size = len(input_batch_list) 252 | words = [sent[0] for sent in input_batch_list] 253 | biwords = [sent[1] for sent in input_batch_list] 254 | gazs = [sent[2] for sent in input_batch_list] 255 | reverse_gazs = [sent[3] for sent in input_batch_list] 256 | labels = [sent[4] for sent in input_batch_list] 257 | word_seq_lengths = torch.LongTensor(list(map(len, words))) 258 | max_seq_len = word_seq_lengths.max() 259 | word_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long() 260 | biword_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long() 261 | label_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long() 262 | mask = autograd.Variable(torch.zeros((batch_size, max_seq_len))).byte() 263 | for idx, (seq, biseq, label, seqlen) in enumerate(zip(words, biwords, labels, word_seq_lengths)): 264 | word_seq_tensor[idx, :seqlen] = torch.LongTensor(seq) 265 | biword_seq_tensor[idx, :seqlen] = torch.LongTensor(biseq) 266 | label_seq_tensor[idx, :seqlen] = torch.LongTensor(label) 267 | mask[idx, :seqlen] = torch.Tensor([1] * int(seqlen)) 268 | word_seq_lengths, word_perm_idx = word_seq_lengths.sort(0, descending=True) 269 | word_seq_tensor = word_seq_tensor[word_perm_idx] 270 | biword_seq_tensor = biword_seq_tensor[word_perm_idx] 271 | label_seq_tensor = label_seq_tensor[word_perm_idx] 272 | mask = mask[word_perm_idx] 273 | 274 | _, word_seq_recover = word_perm_idx.sort(0, descending=False) 275 | 276 | gaz_list = [gazs[i] for i in word_perm_idx] 277 | reverse_gaz_list = [reverse_gazs[i] for i in word_perm_idx] 278 | 279 | if gpu: 280 | word_seq_tensor = word_seq_tensor.cuda() 281 | biword_seq_tensor = biword_seq_tensor.cuda() 282 | word_seq_lengths = word_seq_lengths.cuda() 283 | word_seq_recover = word_seq_recover.cuda() 284 | label_seq_tensor = label_seq_tensor.cuda() 285 | mask = mask.cuda() 286 | 287 | # gaz_seq_tensor = gaz_seq_tensor.cuda() 288 | # gaz_seq_length = gaz_seq_length.cuda() 289 | # gaz_mask_tensor = gaz_mask_tensor.cuda() 290 | 291 | return gaz_list, reverse_gaz_list, word_seq_tensor, biword_seq_tensor, word_seq_lengths, word_seq_recover, label_seq_tensor, mask 292 | 293 | 294 | def train(data, save_model_dir, seg=True): 295 | print("Training model...") 296 | data.show_data_summary() 297 | model = SeqModel(data, type=2) 298 | print("finished building model.") 299 | loss_function = nn.NLLLoss() 300 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 301 | optimizer = optim.SGD(parameters, lr=data.HP_lr, momentum=data.HP_momentum) 302 | best_dev = -1 303 | data.HP_iteration = 75 304 | 305 | ## restore the model from the latest checkpoint, if one exists 306 | state = restore_model(save_model_dir) 307 | epoch = 0 308 | if state is not None: 309 | model_state = state[0] 310 | training_state = state[1] 311 | 312 | model.load_state_dict(model_state) 313 | optimizer.load_state_dict(training_state['optimizer']) 314 | epoch = int(training_state['epoch']) 315 | 316 | batch_size = 1 ## model can be trained in any batch_size, here we set it to 1 317 | now_batch_num = epoch * len(data.train_Ids) // batch_size 318 | 319 | ## start 
training 320 | for idx in range(epoch, data.HP_iteration): 321 | epoch_start = time.time() 322 | temp_start = epoch_start 323 | print("Epoch: %s/%s" % (idx, data.HP_iteration)) 324 | optimizer = lr_decay(optimizer, idx, data.HP_lr_decay, data.HP_lr) 325 | instance_count = 0 326 | sample_id = 0 327 | sample_loss = 0 328 | batch_loss = 0 329 | total_loss = 0 330 | right_token = 0 331 | whole_token = 0 332 | random.shuffle(data.train_Ids) 333 | ## set model in train model 334 | model.train() 335 | model.zero_grad() 336 | batch_size = 1 ## we can train in any batch_size, here we set it to 1 337 | batch_id = 0 338 | train_num = len(data.train_Ids) 339 | total_batch = train_num // batch_size + 1 340 | for batch_id in range(total_batch): 341 | start = batch_id * batch_size 342 | end = (batch_id + 1) * batch_size 343 | if end > train_num: 344 | end = train_num 345 | instance = data.train_Ids[start:end] 346 | if not instance: 347 | continue 348 | gaz_list, reverse_gaz_list, batch_word, batch_biword, batch_wordlen, batch_wordrecover, batch_label, mask = batchify_with_label(instance, data.HP_gpu) 349 | # print "gaz_list:",gaz_list 350 | # exit(0) 351 | instance_count += (end - start) 352 | loss, tag_seq = model.neg_log_likelihood_loss(gaz_list, reverse_gaz_list, batch_word, batch_wordlen, batch_label, mask) 353 | right, whole = predict_check(tag_seq, batch_label, mask) 354 | right_token += right 355 | whole_token += whole 356 | sample_loss += loss.item() 357 | total_loss += loss.item() 358 | batch_loss += loss 359 | 360 | if end % 4000 == 0: 361 | temp_time = time.time() 362 | temp_cost = temp_time - temp_start 363 | temp_start = temp_time 364 | print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % (end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token)) 365 | sys.stdout.flush() 366 | sample_loss = 0 367 | if end % data.HP_batch_size == 0: 368 | batch_loss.backward() 369 | optimizer.step() 370 | model.zero_grad() 371 | batch_loss = 0 372 | 373 | tensorboard.add_train_scalar("train_loss", total_loss / (batch_id + 1), now_batch_num) 374 | tensorboard.add_train_scalar("train_accu", right_token / whole_token, now_batch_num) 375 | now_batch_num += 1 376 | 377 | temp_time = time.time() 378 | temp_cost = temp_time - temp_start 379 | print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % (end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token)) 380 | epoch_finish = time.time() 381 | epoch_cost = epoch_finish - epoch_start 382 | print("Epoch: %s training finished. 
Time: %.2fs, speed: %.2fst/s, total loss: %s" % ( 383 | idx, epoch_cost, train_num / epoch_cost, total_loss)) 384 | # exit(0) 385 | # continue 386 | speed, acc, p, r, f, _ = evaluate(data, model, "dev") 387 | dev_finish = time.time() 388 | dev_cost = dev_finish - epoch_finish 389 | 390 | if seg: 391 | current_score = f 392 | print("Dev: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (dev_cost, speed, acc, p, r, f)) 393 | 394 | tensorboard.add_validation_scalar("dev_accu", acc, idx) 395 | tensorboard.add_validation_scalar("dev_p", p, idx) 396 | tensorboard.add_validation_scalar("dev_r", r, idx) 397 | tensorboard.add_validation_scalar("dev_f", f, idx) 398 | 399 | else: 400 | current_score = acc 401 | print("Dev: time: %.2fs speed: %.2fst/s; acc: %.4f" % (dev_cost, speed, acc)) 402 | 403 | if current_score > best_dev: 404 | if seg: 405 | print("Exceed previous best f score:", best_dev) 406 | else: 407 | print("Exceed previous best acc score:", best_dev) 408 | # model_name = save_model_dir + '.' + str(idx) + ".model" 409 | # torch.save(model.state_dict(), model_name) 410 | best_dev = current_score 411 | # ## decode test 412 | speed, acc, p, r, f, _ = evaluate(data, model, "test") 413 | test_finish = time.time() 414 | test_cost = test_finish - dev_finish 415 | if seg: 416 | print("Test: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (test_cost, speed, acc, p, r, f)) 417 | 418 | tensorboard.add_test_scalar("test_accu", acc, idx) 419 | tensorboard.add_test_scalar("test_p", p, idx) 420 | tensorboard.add_test_scalar("test_r", r, idx) 421 | tensorboard.add_test_scalar("test_f", f, idx) 422 | 423 | else: 424 | print("Test: time: %.2fs, speed: %.2fst/s; acc: %.4f" % (test_cost, speed, acc)) 425 | gc.collect() 426 | 427 | ## save the model every epoch 428 | model_state = model.state_dict() 429 | training_state = { 430 | "epoch": idx, 431 | "optimizer": optimizer.state_dict() 432 | } 433 | save_model(idx, (model_state, training_state), save_model_dir) 434 | 435 | 436 | if __name__ == '__main__': 437 | parser = argparse.ArgumentParser(description='Tuning with bi-directional LSTM-CRF') 438 | parser.add_argument('--embedding', help='Embedding for words', default='None') 439 | parser.add_argument('--status', choices=['train', 'test', 'decode'], help='update algorithm', default='train') 440 | parser.add_argument('--savemodel', default="data/model/resume") 441 | parser.add_argument('--savedset', help='Dir of saved data setting', default="data/save.dset") 442 | parser.add_argument('--train', default="data/resume/train.char.bmes") 443 | parser.add_argument('--dev', default="data/resume/dev.char.bmes") 444 | parser.add_argument('--test', default="data/resume/test.char.bmes") 445 | parser.add_argument('--seg', default="True") 446 | parser.add_argument('--extendalphabet', default="True") 447 | parser.add_argument('--raw') 448 | parser.add_argument('--loadmodel') 449 | parser.add_argument('--output') 450 | args = parser.parse_args() 451 | 452 | train_file = args.train 453 | dev_file = args.dev 454 | test_file = args.test 455 | raw_file = args.raw 456 | model_dir = args.loadmodel 457 | dset_dir = args.savedset 458 | output_file = args.output 459 | if args.seg.lower() == "true": 460 | seg = True 461 | else: 462 | seg = False 463 | status = args.status.lower() 464 | 465 | save_model_dir = args.savemodel 466 | gpu = torch.cuda.is_available() 467 | 468 | char_emb = "data/gigaword_chn.all.a2b.uni.ite50.vec" 469 | bichar_emb = None 470 | gaz_file = "data/ctb.50d.vec" 471 | # 
gaz_file = None 472 | # char_emb = None 473 | # bichar_emb = None 474 | 475 | print("CuDNN:", torch.backends.cudnn.enabled) 476 | # gpu = False 477 | print("GPU available:", gpu) 478 | print("Status:", status) 479 | print("Seg: ", seg) 480 | print("Train file:", train_file) 481 | print("Dev file:", dev_file) 482 | print("Test file:", test_file) 483 | print("Raw file:", raw_file) 484 | print("Char emb:", char_emb) 485 | print("Bichar emb:", bichar_emb) 486 | print("Gaz file:", gaz_file) 487 | if status == 'train': 488 | print("Model saved to:", save_model_dir) 489 | sys.stdout.flush() 490 | 491 | if status == 'train': 492 | data = Data() 493 | data.HP_gpu = gpu 494 | data.HP_use_char = False 495 | data.HP_batch_size = 1 496 | data.use_bigram = False 497 | data.gaz_dropout = 0.5 498 | data.norm_gaz_emb = False 499 | data.HP_fix_gaz_emb = False 500 | data_initialization(data, gaz_file, train_file, dev_file, test_file) 501 | data.generate_instance_with_gaz_no_char(train_file, 'train') 502 | data.generate_instance_with_gaz_no_char(dev_file, 'dev') 503 | data.generate_instance_with_gaz_no_char(test_file, 'test') 504 | 505 | data.build_word_pretrain_emb(char_emb) 506 | data.build_biword_pretrain_emb(bichar_emb) 507 | data.build_gaz_pretrain_emb(gaz_file) 508 | 509 | train(data, save_model_dir, seg) 510 | 511 | 512 | 513 | --------------------------------------------------------------------------------
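Note on batching: both training scripts above pad each mini-batch, sort it by sentence length (longest first), and later undo that sort inside `batchify_with_label` / `recover_label` via the `word_seq_recover` indices. The snippet below is a minimal, self-contained sketch of that pad/sort/recover idiom only; the toy word-id sequences and variable names are illustrative assumptions, not code from this repository.

```python
# Minimal sketch of the pad / length-sort / recover idiom used by
# batchify_with_label. The toy word-id sequences below are hypothetical.
import torch

seqs = [[3, 1, 4], [5, 9], [2, 6, 5, 3]]                    # variable-length id lists
lengths = torch.LongTensor([len(s) for s in seqs])
padded = torch.zeros(len(seqs), int(lengths.max())).long()  # zero-padded batch
for i, s in enumerate(seqs):
    padded[i, :len(s)] = torch.LongTensor(s)

# Sort longest-first (as packed RNN input expects) and remember how to undo it.
lengths, perm = lengths.sort(0, descending=True)
padded = padded[perm]
_, recover = perm.sort(0, descending=False)                 # plays the role of word_seq_recover

# Indexing with `recover` restores the original sentence order after the forward pass.
assert torch.equal(padded[recover][0, :3], torch.LongTensor([3, 1, 4]))
```

The same recover indices are what `recover_label` applies to the predicted and gold tag tensors, so that the decoded label sequences line up with the original order of the instances in the batch.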