├── README.md
├── data
│   ├── log
│   │   ├── readme.txt
│   │   ├── test
│   │   │   └── readme.txt
│   │   ├── train
│   │   │   └── readme.txt
│   │   └── validation
│   │       └── readme.txt
│   ├── model
│   │   ├── msra
│   │   │   └── readme.txt
│   │   ├── note4
│   │   │   └── readme.txt
│   │   ├── readme.txt
│   │   ├── resume
│   │   │   └── readme.txt
│   │   └── weibo
│   │       └── readme.txt
│   ├── msra
│   │   └── readme.txt
│   ├── note4
│   │   └── readme.txt
│   ├── result
│   │   ├── msra
│   │   │   └── readme.txt
│   │   ├── note4
│   │   │   └── readme.txt
│   │   ├── resume
│   │   │   └── readme.txt
│   │   └── weibo
│   │       └── readme.txt
│   ├── resume
│   │   └── readme.txt
│   └── weibo
│       ├── dev.char.bmes
│       ├── test.char.bmes
│       └── train.char.bmes
├── msra.py
├── ner
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-36.pyc
│   ├── functions
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── gaz_opt.cpython-36.pyc
│   │   │   ├── iniatlize.cpython-36.pyc
│   │   │   ├── save_res.cpython-36.pyc
│   │   │   └── utils.cpython-36.pyc
│   │   ├── gaz_opt.py
│   │   ├── iniatlize.py
│   │   ├── save_res.py
│   │   └── utils.py
│   ├── lw
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── cw_ner.cpython-36.pyc
│   │   │   └── tbx_writer.cpython-36.pyc
│   │   ├── cw_ner.py
│   │   └── tbx_writer.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── bilstm.cpython-36.pyc
│   │   │   ├── bilstmcrf.cpython-36.pyc
│   │   │   ├── charbilstm.cpython-36.pyc
│   │   │   ├── charcnn.cpython-36.pyc
│   │   │   ├── crf.cpython-36.pyc
│   │   │   └── latticelstm.cpython-36.pyc
│   │   └── crf.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── gaz_bilstm.cpython-36.pyc
│   │   │   ├── gaz_embed.cpython-36.pyc
│   │   │   └── highway.cpython-36.pyc
│   │   ├── gaz_bilstm.py
│   │   ├── gaz_embed.py
│   │   └── highway.py
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── alphabet.cpython-36.pyc
│       │   ├── data.cpython-36.pyc
│       │   ├── functions.cpython-36.pyc
│       │   ├── gazetteer.cpython-36.pyc
│       │   ├── metric.cpython-36.pyc
│       │   ├── msra_data.cpython-36.pyc
│       │   ├── trie.cpython-36.pyc
│       │   └── weibo_data.cpython-36.pyc
│       ├── alphabet.py
│       ├── functions.py
│       ├── gazetteer.py
│       ├── metric.py
│       ├── msra_data.py
│       ├── note4_data.py
│       ├── resume_data.py
│       ├── trie.py
│       └── weibo_data.py
├── note4.py
├── resume.py
└── weibo.py
/README.md:
--------------------------------------------------------------------------------
1 | An Encoding Strategy based Word-Character LSTM for Chinese NER
2 | =============================================================
3 | A model that takes both characters and words as input for the Chinese NER task.
4 |
5 |
6 | Models and results can be found in our NAACL 2019 paper "[An Encoding Strategy based Word-Character LSTM for Chinese NER](https://www.aclweb.org/anthology/papers/N/N19/N19-1247/)". It achieves state-of-the-art performance on most of the datasets.
7 |
8 |
9 | Most of the code is written with reference to Yang Jie's "NCRF++". To know more about "NCRF++", please refer to the paper "NCRF++: An Open-source Neural Sequence Labeling Toolkit".
10 |
11 |
12 | Requirements:
13 | ============================
14 | Python 3.6
15 | PyTorch: 0.4.0
16 |
17 | If you want to use TensorBoard with our code, you should also install the following:
18 | tensorboardX 1.2
19 | tensorflow 1.6.0
20 |
21 |
22 | Input format:
23 | =============================
24 | CoNLL format (BIOES tag scheme preferred), with one character and its label per line. Sentences are separated by a blank line.
25 | ```
26 | 美 B-LOC
27 | 国 E-LOC
28 | 的 O
29 | 华 B-PER
30 | 莱 I-PER
31 | 士 E-PER
32 |
33 | 我 O
34 | 跟 O
35 | 他 O
36 | 谈 O
37 | 笑 O
38 | 风 O
39 | 生 O
40 | ```
41 |
42 | Pretrained Embeddings:
43 | ===============
44 | Character embeddings: [gigword_chn.all.a2b.uni.ite50.vec](https://pan.baidu.com/s/1pLO6T9D)
45 | Word embeddings: [ctb.50d.vec](https://pan.baidu.com/s/1pLO6T9D)
46 |
47 |
48 | Run:
49 | ============
50 | Put each dataset in its directory under data, and then simply run the corresponding .py file.
51 | For example, to run the Weibo experiment, just run: `python3 weibo.py`
52 |
53 | Cite:
54 | ========
55 | @inproceedings{liu-etal-2019-encoding, \
56 | title = "An Encoding Strategy Based Word-Character {LSTM} for {C}hinese {NER}", \
57 | author = "Liu, Wei and
58 | Xu, Tongge and
59 | Xu, Qinghua and
60 | Song, Jiayu and
61 | Zu, Yueran", \
62 | booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies", \
63 | year = "2019", \
64 | publisher = "Association for Computational Linguistics", \
65 | url = "https://www.aclweb.org/anthology/N19-1247", \
66 | pages = "2379--2389" \
67 | }
68 |
69 |
--------------------------------------------------------------------------------
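A minimal sketch (not part of the repository) of reading the character-level CoNLL/BMES files described in the README's "Input format" section into (characters, labels) sentence pairs; the helper name `read_bmes` and the example path are hypothetical.

```python
def read_bmes(path):
    """Read one-character-per-line CoNLL/BMES data into (chars, labels) pairs."""
    sentences, chars, labels = [], [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:                    # a blank line ends a sentence
                if chars:
                    sentences.append((chars, labels))
                    chars, labels = [], []
                continue
            char, label = line.split()
            chars.append(char)
            labels.append(label)
    if chars:                               # last sentence without a trailing blank line
        sentences.append((chars, labels))
    return sentences

# e.g. read_bmes("data/weibo/train.char.bmes")
```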
/data/log/readme.txt:
--------------------------------------------------------------------------------
1 | The dir for tensorboard log file.
--------------------------------------------------------------------------------
/data/log/test/readme.txt:
--------------------------------------------------------------------------------
1 | for test.
--------------------------------------------------------------------------------
/data/log/train/readme.txt:
--------------------------------------------------------------------------------
1 | for train.
--------------------------------------------------------------------------------
/data/log/validation/readme.txt:
--------------------------------------------------------------------------------
1 | for dev.
--------------------------------------------------------------------------------
/data/model/msra/readme.txt:
--------------------------------------------------------------------------------
1 | for MSRA model checkpoints.
--------------------------------------------------------------------------------
/data/model/note4/readme.txt:
--------------------------------------------------------------------------------
1 | for OntoNote4.0 model checkpoints.
--------------------------------------------------------------------------------
/data/model/readme.txt:
--------------------------------------------------------------------------------
1 | the dir to save checkpoint files.
--------------------------------------------------------------------------------
/data/model/resume/readme.txt:
--------------------------------------------------------------------------------
1 | for resume model checkpoints.
--------------------------------------------------------------------------------
/data/model/weibo/readme.txt:
--------------------------------------------------------------------------------
1 | for weibo model checkpoints.
--------------------------------------------------------------------------------
/data/msra/readme.txt:
--------------------------------------------------------------------------------
1 | The dir to place MSRA dataset
--------------------------------------------------------------------------------
/data/note4/readme.txt:
--------------------------------------------------------------------------------
1 | The dir to place OntoNote4 dataset
--------------------------------------------------------------------------------
/data/result/msra/readme.txt:
--------------------------------------------------------------------------------
1 | The output results of MSRA
--------------------------------------------------------------------------------
/data/result/note4/readme.txt:
--------------------------------------------------------------------------------
1 | The output results of OntoNote4
--------------------------------------------------------------------------------
/data/result/resume/readme.txt:
--------------------------------------------------------------------------------
1 | The output results of Resume
--------------------------------------------------------------------------------
/data/result/weibo/readme.txt:
--------------------------------------------------------------------------------
1 | The output results of Weibo
--------------------------------------------------------------------------------
/data/resume/readme.txt:
--------------------------------------------------------------------------------
1 | The dir to place Resume dataset
--------------------------------------------------------------------------------
/ner/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/__init__.py
--------------------------------------------------------------------------------
/ner/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/functions/__init__.py
--------------------------------------------------------------------------------
/ner/functions/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/functions/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/functions/__pycache__/gaz_opt.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/functions/__pycache__/gaz_opt.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/functions/__pycache__/iniatlize.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/functions/__pycache__/iniatlize.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/functions/__pycache__/save_res.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/functions/__pycache__/save_res.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/functions/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/functions/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/functions/gaz_opt.py:
--------------------------------------------------------------------------------
1 | __author__ = "liuwei"
2 |
3 |
4 | import numpy as np
5 | import torch
6 | import torch.autograd as autograd
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | def get_batch_gaz(gazs, batch_size, max_seq_len, gpu=False):
11 | """
12 | build a padded, batched gaz tensor for training from the batch's gaz lists
13 | Args:
14 | gazs: a nested list of shape [batch, seq_len, num_gaz], the matched gaz word ids for the batch
15 | batch_size: the size of batch
16 | max_seq_len: the max seq length
17 | """
18 | # we need to guarantee that every position has the same number of gazs, so we pad
19 | # record the real lengths
20 | gaz_seq_length = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long()
21 | max_gaz_len = 1
22 | for i in range(batch_size):
23 | this_gaz_len = len(gazs[i])
24 | gaz_lens = [len(gazs[i][j]) for j in range(this_gaz_len)]
25 | gaz_seq_length[i, :this_gaz_len] = torch.LongTensor(gaz_lens)
26 | l = max(gaz_lens)
27 | if max_gaz_len < l:
28 | max_gaz_len = l
29 |
30 | # do padding
31 | gaz_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len, max_gaz_len))).long()
32 | for i in range(batch_size):
33 | for j in range(len(gazs[i])):
34 | l = int(gaz_seq_length[i][j])
35 | gaz_seq_tensor[i, j, :l] = torch.LongTensor(gazs[i][j][:l])
36 |
37 | # get mask
38 | empty_tensor = (gaz_seq_length == 0).long()
39 | empty_tensor = empty_tensor * max_gaz_len
40 | gaz_seq_length = gaz_seq_length + empty_tensor
41 |
42 | gaz_mask_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len, max_gaz_len))).long()
43 | for i in range(batch_size):
44 | for j in range(max_seq_len):
45 | l = int(gaz_seq_length[i][j])
46 | gaz_mask_tensor[i, j, :l] = 1
47 | del empty_tensor
48 |
49 | if gpu:
50 | gaz_seq_tensor = gaz_seq_tensor.cuda()
51 | gaz_seq_length = gaz_seq_length.cuda()
52 | gaz_mask_tensor = gaz_mask_tensor.cuda()
53 |
54 | return gaz_seq_tensor, gaz_seq_length, gaz_mask_tensor
55 |
--------------------------------------------------------------------------------
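A minimal usage sketch for `get_batch_gaz` (my addition, assuming the repo root is on PYTHONPATH): each sentence is a list, per character, of matched gazetteer word ids, and the function pads them into batched tensors.

```python
from ner.functions.gaz_opt import get_batch_gaz

# Two toy sentences; every inner list holds the gaz word ids matched at that character.
gazs = [
    [[3], [4, 7], [5]],      # sentence 1: 3 characters
    [[2], [6, 8, 9]],        # sentence 2: 2 characters (padded up to max_seq_len)
]
ids, lengths, mask = get_batch_gaz(gazs, batch_size=2, max_seq_len=3, gpu=False)
print(ids.shape)      # torch.Size([2, 3, 3]) -> [batch, max_seq_len, max_gaz_len]
print(lengths.shape)  # torch.Size([2, 3])
print(mask.shape)     # torch.Size([2, 3, 3])
```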
/ner/functions/iniatlize.py:
--------------------------------------------------------------------------------
1 | __author__ = "liuwei"
2 |
3 | """
4 | some init function
5 | """
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 |
11 | def init_cnn_weight(cnn_layer, seed=100):
12 | """
13 | init the weight of cnn net
14 | Args:
15 | cnn_layer: weight.size() = [out_channels, in_channels, kernel_size]
16 | seed: int
17 | """
18 | torch.manual_seed(seed)
19 | nn.init.xavier_uniform_(cnn_layer.weight)
20 | cnn_layer.bias.data.zero_()
21 |
22 | def init_embedding(input_embedding, seed=100):
23 | """
24 | Args:
25 | input_embedding: the weight of embedding need to be init
26 | seed: int
27 | """
28 | # torch.manual_seed(seed)
29 | # scope = np.sqrt(3.0 / input_embedding.size(1))
30 | # nn.init.uniform_(input_embedding, -scope, scope)
31 |
32 | torch.manual_seed(seed)
33 | nn.init.normal_(input_embedding, 0, 0.1)
34 |
35 | def init_linear(input_linear, seed=100):
36 | """
37 | init the weight of linear net
38 | Args:
39 | input_linear: a linear layer
40 | """
41 | torch.manual_seed(seed)
42 | scope = np.sqrt(6.0 / (input_linear.weight.size(0) + input_linear.weight.size(1)))
43 | nn.init.uniform_(input_linear.weight, -scope, scope)
44 | if input_linear.bias is not None:
45 | input_linear.bias.data.zero_()
46 |
47 | def init_maxtrix_weight(weights, seed=100):
48 | """
49 | init the weight of a matrix
50 | """
51 | torch.manual_seed(seed)
52 | scope = np.sqrt(6.0 / (weights.size(0) + weights.size(1)))
53 | nn.init.uniform_(weights, -scope, scope)
54 |
55 | def init_vector(vector, seed=100):
56 | """
57 | init a vector, note that vector is 1-D
58 | """
59 | torch.manual_seed(seed)
60 | v_size = vector.size(0)
61 | scale = np.sqrt(3.0 / v_size)
62 | nn.init.uniform_(vector, -scale, scale)
63 |
64 | def init_highway(highway_layer, seed=100):
65 | """
66 | init the weight of highway net. do not init the bias
67 | Args:
68 | highway_layer: a highway layer
69 | seed:
70 | """
71 | torch.manual_seed(seed)
72 | scope = np.sqrt(6.0 / (highway_layer.weight.size(0) + highway_layer.weight.size(1)))
73 | nn.init.uniform_(highway_layer.weight, -scope, scope)
74 |
75 |
76 | def get_embedding_weight(weight_file, vocab_size, word_dim):
77 | """
78 | get the embedding from weight file, and then use it to init embedding
79 | Args:
80 | weight_file: embedding weight file
81 | vocab_size: the size of vocab
82 | word_dim: the dim of word embedding
83 | """
84 |
--------------------------------------------------------------------------------
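A short sketch (my addition) of how these initializers are applied to ordinary PyTorch layers:

```python
import torch.nn as nn
from ner.functions.iniatlize import init_linear, init_embedding

linear = nn.Linear(50, 100)
init_linear(linear)                # xavier-style uniform weight, zero bias

embedding = nn.Embedding(1000, 50)
init_embedding(embedding.weight)   # N(0, 0.1) initialisation of the embedding table
```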
/ner/functions/save_res.py:
--------------------------------------------------------------------------------
1 | __author__ = "liuwei"
2 |
3 |
4 | """
5 | some util functions
6 | """
7 |
8 | import torch
9 | import numpy as np
10 | import os
11 |
12 | def save_gold_pred(instances_text, preds, golds, name):
13 | """
14 | save the gold and pred result to do compare
15 | Args:
16 | instances_text: is a list list, each list is a sentence
17 | preds: is also a list list, each list is a sentence predict tag
18 | golds: is also a list list, each list is a sentence gold tag
19 | name: train? dev? or test
20 | """
21 | sent_len = len(instances_text)
22 |
23 | assert len(instances_text) == len(preds)
24 | assert len(preds) == len(golds)
25 |
26 | dir = "data/result/resume/"
27 | file_path = os.path.join(dir, name)
28 | num = 1
29 | with open(file_path, 'w') as f:
30 | f.write("word gold pred\n")
31 | for sent, gold, pred in zip(instances_text, golds, preds):
32 | # for each sentence
33 | for word, w_g, w_p in zip(sent[0], gold, pred):
34 | if w_g != w_p:
35 | f.write(word)
36 | f.write(" ")
37 | f.write(w_g)
38 | f.write(" ")
39 | f.write(w_p)
40 | f.write(" ")
41 | f.write(str(num))
42 | f.write("\n")
43 | num += 1
44 | else:
45 | f.write(word)
46 | f.write(" ")
47 | f.write(w_g)
48 | f.write(" ")
49 | f.write(w_p)
50 | f.write("\n")
51 |
52 | f.write("\n")
53 |
54 |
55 |
--------------------------------------------------------------------------------
/ner/functions/utils.py:
--------------------------------------------------------------------------------
1 | __author__ = "liuwei"
2 |
3 |
4 | """
5 | some util functions
6 | """
7 |
8 | import torch
9 | import numpy as np
10 |
11 | def reverse_padded_sequence(inputs, lengths, batch_first=True):
12 | """Reverses sequences according to their lengths.
13 | Inputs should have size ``T x B x *`` if ``batch_first`` is False, or
14 | ``B x T x *`` if True. T is the length of the longest sequence (or larger),
15 | B is the batch size, and * is any number of dimensions (including 0).
16 | Arguments:
17 | inputs (Variable): padded batch of variable length sequences.
18 | lengths (list[int]): list of sequence lengths
19 | batch_first (bool, optional): if True, inputs should be B x T x *.
20 | Returns:
21 | A Variable with the same size as inputs, but with each sequence
22 | reversed according to its length.
23 | """
24 | if batch_first:
25 | inputs = inputs.transpose(0, 1)
26 | max_length, batch_size = inputs.size(0), inputs.size(1)
27 | if len(lengths) != batch_size:
28 | raise ValueError("inputs is incompatible with lengths.")
29 | ind = [list(reversed(range(0, length))) + list(range(length, max_length))
30 | for length in lengths]
31 | ind = torch.LongTensor(ind).transpose(0, 1)
32 | for dim in range(2, inputs.dim()):
33 | ind = ind.unsqueeze(dim)
34 | ind = ind.expand_as(inputs)
35 | if inputs.is_cuda:
36 | ind = ind.cuda()
37 | reversed_inputs = torch.gather(inputs, 0, ind)
38 | if batch_first:
39 | reversed_inputs = reversed_inputs.transpose(0, 1)
40 |
41 | return reversed_inputs
42 |
43 |
44 | def random_embedding(vocab_size, embedding_dim):
45 | pretrain_emb = np.empty([vocab_size, embedding_dim])
46 | scale = np.sqrt(3.0 / embedding_dim)
47 | for index in range(vocab_size):
48 | pretrain_emb[index,:] = np.random.uniform(-scale, scale, [1, embedding_dim])
49 | return pretrain_emb
50 |
--------------------------------------------------------------------------------
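A minimal sketch (my addition) of what `reverse_padded_sequence` does: each row is reversed only up to its true length while the padding stays in place, which is how the backward LSTM in Gaz_BiLSTM receives reversed inputs.

```python
import torch
from ner.functions.utils import reverse_padded_sequence

x = torch.tensor([[1, 2, 3, 0],
                  [4, 5, 0, 0]])       # padded batch, batch_first=True
lengths = [3, 2]
print(reverse_padded_sequence(x, lengths, batch_first=True))
# tensor([[3, 2, 1, 0],
#         [5, 4, 0, 0]])
```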
/ner/lw/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/lw/__init__.py
--------------------------------------------------------------------------------
/ner/lw/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/lw/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/lw/__pycache__/cw_ner.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/lw/__pycache__/cw_ner.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/lw/__pycache__/tbx_writer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/lw/__pycache__/tbx_writer.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/lw/cw_ner.py:
--------------------------------------------------------------------------------
1 | __author__ = "liuwei"
2 |
3 | """
4 | A char word model
5 | """
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 |
10 | from ner.modules.gaz_embed import Gaz_Embed
11 | from ner.modules.gaz_bilstm import Gaz_BiLSTM
12 | from ner.model.crf import CRF
13 | from ner.functions.utils import random_embedding
14 | from ner.functions.gaz_opt import get_batch_gaz
15 | from ner.functions.utils import reverse_padded_sequence
16 |
17 | class CW_NER(torch.nn.Module):
18 | def __init__(self, data, type=1):
19 | print("Build char-word based NER Task...")
20 | super(CW_NER, self).__init__()
21 |
22 | self.gpu = data.HP_gpu
23 | label_size = data.label_alphabet_size
24 | self.type = type
25 | self.gaz_embed = Gaz_Embed(data, type)
26 |
27 | self.word_embedding = nn.Embedding(data.word_alphabet.size(), data.word_emb_dim)
28 |
29 | self.lstm = Gaz_BiLSTM(data, data.word_emb_dim + data.gaz_emb_dim, data.HP_hidden_dim)
30 |
31 | self.crf = CRF(data.label_alphabet_size, self.gpu)
32 |
33 | self.hidden2tag = nn.Linear(data.HP_hidden_dim * 2, data.label_alphabet_size + 2)
34 |
35 | if data.pretrain_word_embedding is not None:
36 | self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding))
37 | else:
38 | self.word_embedding.weight.data.copy_(
39 | torch.from_numpy(random_embedding(data.word_alphabet_size, data.word_emb_dim))
40 | )
41 |
42 | if self.gpu:
43 | self.word_embedding = self.word_embedding.cuda()
44 | self.hidden2tag = self.hidden2tag.cuda()
45 |
46 |
47 | def neg_log_likelihood_loss(self, gaz_list, reverse_gaz_list, word_inputs, word_seq_lengths, batch_label, mask):
48 | """
49 | get the neg_log_likelihood_loss
50 | Args:
51 | gaz_list: the batch data's gaz, for every chinese char
52 | reverse_gaz_list: the reverse list
53 | word_inputs: word input ids, [batch_size, seq_len]
54 | word_seq_lengths: [batch_size]
55 | batch_label: [batch_size, seq_len]
56 | mask: [batch_size, seq_len]
57 | """
58 | batch_size = word_inputs.size(0)
59 | seq_len = word_inputs.size(1)
60 | lengths = list(map(int, word_seq_lengths))
61 |
62 | # print('one ', reverse_gaz_list[0][:10])
63 |
64 | ## get batch gaz ids
65 | batch_gaz_ids, batch_gaz_length, batch_gaz_mask = get_batch_gaz(reverse_gaz_list, batch_size, seq_len, self.gpu)
66 |
67 | # print('two ', batch_gaz_ids[0][:10])
68 |
69 | reverse_batch_gaz_ids, reverse_batch_gaz_length, reverse_batch_gaz_mask = get_batch_gaz(gaz_list, batch_size, seq_len, self.gpu)
70 | reverse_batch_gaz_ids = reverse_padded_sequence(reverse_batch_gaz_ids, lengths)
71 | reverse_batch_gaz_length = reverse_padded_sequence(reverse_batch_gaz_length, lengths)
72 | reverse_batch_gaz_mask = reverse_padded_sequence(reverse_batch_gaz_mask, lengths)
73 |
74 | ## word embedding
75 | word_embs = self.word_embedding(word_inputs)
76 | reverse_word_embs = reverse_padded_sequence(word_embs, lengths)
77 |
78 | ## gaz embedding
79 | gaz_embs = self.gaz_embed((batch_gaz_ids, batch_gaz_length, batch_gaz_mask))
80 | reverse_gaz_embs = self.gaz_embed((reverse_batch_gaz_ids, reverse_batch_gaz_length, reverse_batch_gaz_mask))
81 | # print(gaz_embs[0][0][:20])
82 |
83 | ## lstm
84 | forward_inputs = torch.cat((word_embs, gaz_embs), dim=-1)
85 | backward_inputs = torch.cat((reverse_word_embs, reverse_gaz_embs), dim=-1)
86 |
87 | lstm_outs, _ = self.lstm((forward_inputs, backward_inputs), word_seq_lengths)
88 |
89 | ## hidden2tag
90 | outs = self.hidden2tag(lstm_outs)
91 |
92 | ## crf and loss
93 | loss = self.crf.neg_log_likelihood_loss(outs, mask, batch_label)
94 | _, tag_seq = self.crf._viterbi_decode(outs, mask)
95 |
96 | return loss, tag_seq
97 |
98 | def forward(self, gaz_list, reverse_gaz_list, word_inputs, word_seq_lengths, mask):
99 | """
100 | Args:
101 | gaz_list: the forward gaz_list
102 | reverse_gaz_list: the backward gaz list
103 | word_inputs: word ids
104 | word_seq_lengths: each sentence length
105 | mask: sentence mask
106 | """
107 | batch_size = word_inputs.size(0)
108 | seq_len = word_inputs.size(1)
109 | lengths = list(map(int, word_seq_lengths))
110 |
111 | ## get batch gaz ids
112 | batch_gaz_ids, batch_gaz_length, batch_gaz_mask = get_batch_gaz(reverse_gaz_list, batch_size, seq_len, self.gpu)
113 | reverse_batch_gaz_ids, reverse_batch_gaz_length, reverse_batch_gaz_mask = get_batch_gaz(gaz_list, batch_size, seq_len, self.gpu)
114 | reverse_batch_gaz_ids = reverse_padded_sequence(reverse_batch_gaz_ids, lengths)
115 | reverse_batch_gaz_length = reverse_padded_sequence(reverse_batch_gaz_length, lengths)
116 | reverse_batch_gaz_mask = reverse_padded_sequence(reverse_batch_gaz_mask, lengths)
117 |
118 | ## word embedding
119 | word_embs = self.word_embedding(word_inputs)
120 | reverse_word_embs = reverse_padded_sequence(word_embs, lengths)
121 |
122 | ## gaz embedding
123 | gaz_embs = self.gaz_embed((batch_gaz_ids, batch_gaz_length, batch_gaz_mask))
124 | reverse_gaz_embs = self.gaz_embed((reverse_batch_gaz_ids, reverse_batch_gaz_length, reverse_batch_gaz_mask))
125 |
126 | ## lstm
127 | forward_inputs = torch.cat((word_embs, gaz_embs), dim=-1)
128 | backward_inputs = torch.cat((reverse_word_embs, reverse_gaz_embs), dim=-1)
129 |
130 | lstm_outs, _ = self.lstm((forward_inputs, backward_inputs), word_seq_lengths)
131 |
132 | ## hidden2tag
133 | outs = self.hidden2tag(lstm_outs)
134 |
135 | ## crf and loss
136 | _, tag_seq = self.crf._viterbi_decode(outs, mask)
137 |
138 | return tag_seq
139 |
--------------------------------------------------------------------------------
/ner/lw/tbx_writer.py:
--------------------------------------------------------------------------------
1 | __author__ = "liuwei"
2 |
3 | """
4 | write the loss and accu to the tensorboardx
5 | """
6 | from typing import Any
7 | from tensorboardX import SummaryWriter
8 |
9 | class TensorboardWriter:
10 | """
11 | wrap the SummaryWriter, print the value to the tensorboard
12 | """
13 | def __init__(self, train_log: SummaryWriter = None, validation_log: SummaryWriter = None,
14 | test_log: SummaryWriter = None):
15 | self._train_log = train_log
16 | self._validation_log = validation_log
17 | self._test_log = test_log
18 |
19 | @staticmethod
20 | def _item(value: Any):
21 | if hasattr(value, 'item'):
22 | val = value.item()
23 | else:
24 | val = value
25 |
26 | return val
27 |
28 | def add_train_scalar(self, name: str, value: float, global_step: int):
29 | """
30 | add train scalar value to tensorboardX
31 | Args:
32 | name: the name of the value
33 | value:
34 | global_step: the steps
35 | """
36 | if self._train_log is not None:
37 | self._train_log.add_scalar(name, self._item(value), global_step)
38 |
39 |
40 | def add_validation_scalar(self, name: str, value: float, global_step: int):
41 | """
42 | add validation scalar value to tensorboardX
43 | Args:
44 | name: the name of the value
45 | value:
46 | global_step:
47 | """
48 | if self._validation_log is not None:
49 | self._validation_log.add_scalar(name, self._item(value), global_step)
50 |
51 | def add_test_scalar(self, name: str, value: float, global_step: int):
52 | """
53 | add test scalar value to tensorboardX
54 | Args:
55 | name: the name of the value
56 | value:
57 | global_step:
58 | """
59 | if self._test_log is not None:
60 | self._test_log.add_scalar(name, self._item(value), global_step)
61 |
62 |
--------------------------------------------------------------------------------
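A minimal usage sketch (my addition); the log directories match the data/log/* folders shipped with the repo.

```python
from tensorboardX import SummaryWriter
from ner.lw.tbx_writer import TensorboardWriter

writer = TensorboardWriter(
    train_log=SummaryWriter("data/log/train"),
    validation_log=SummaryWriter("data/log/validation"),
    test_log=SummaryWriter("data/log/test"),
)
writer.add_train_scalar("loss", 1.23, global_step=100)
writer.add_validation_scalar("f1", 0.55, global_step=100)
```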
/ner/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__init__.py
--------------------------------------------------------------------------------
/ner/model/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/model/__pycache__/bilstm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__pycache__/bilstm.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/model/__pycache__/bilstmcrf.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__pycache__/bilstmcrf.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/model/__pycache__/charbilstm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__pycache__/charbilstm.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/model/__pycache__/charcnn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__pycache__/charcnn.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/model/__pycache__/crf.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__pycache__/crf.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/model/__pycache__/latticelstm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/model/__pycache__/latticelstm.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/model/crf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.autograd as autograd
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 | START_TAG = -2
9 | STOP_TAG = -1
10 |
11 |
12 | # Compute log sum exp in a numerically stable way for the forward algorithm
13 | def log_sum_exp(vec, m_size):
14 | """
15 | calculate log of exp sum
16 | args:
17 | vec (batch_size, vanishing_dim, hidden_dim) : input tensor
18 | m_size : hidden_dim
19 | return:
20 | batch_size, hidden_dim
21 | """
22 | _, idx = torch.max(vec, 1) # B * 1 * M
23 | max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size) # B * M
24 | return max_score.view(-1, m_size) + torch.log(torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size) # B * M
25 |
26 | class CRF(nn.Module):
27 |
28 | def __init__(self, tagset_size, gpu):
29 | super(CRF, self).__init__()
30 | print("build batched crf...")
31 | self.gpu = gpu
32 | # Matrix of transition parameters. Entry i,j is the score of transitioning *to* i *from* j.
33 | self.average_batch = False
34 | self.tagset_size = tagset_size
35 | # # We add 2 here, because of START_TAG and STOP_TAG
36 | # # transitions (f_tag_size, t_tag_size), transition value from f_tag to t_tag
37 | init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2)
38 | # init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2)
39 | # init_transitions[:,START_TAG] = -1000.0
40 | # init_transitions[STOP_TAG,:] = -1000.0
41 | # init_transitions[:,0] = -1000.0
42 | # init_transitions[0,:] = -1000.0
43 | if self.gpu:
44 | init_transitions = init_transitions.cuda()
45 | self.transitions = nn.Parameter(init_transitions)
46 |
47 | # self.transitions = nn.Parameter(torch.Tensor(self.tagset_size+2, self.tagset_size+2))
48 | # self.transitions.data.zero_()
49 |
50 | def _calculate_PZ(self, feats, mask):
51 | """
52 | input:
53 | feats: (batch, seq_len, self.tag_size+2)
54 | masks: (batch, seq_len)
55 | """
56 | batch_size = feats.size(0)
57 | seq_len = feats.size(1)
58 | tag_size = feats.size(2)
59 | # print feats.view(seq_len, tag_size)
60 | assert(tag_size == self.tagset_size+2)
61 | mask = mask.transpose(1,0).contiguous()
62 | ins_num = seq_len * batch_size
63 | ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1)
64 | feats = feats.transpose(1,0).contiguous().view(ins_num,1, tag_size).expand(ins_num, tag_size, tag_size)
65 | ## need to consider start
66 | scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)
67 | scores = scores.view(seq_len, batch_size, tag_size, tag_size)
68 | # build iter
69 | seq_iter = enumerate(scores)
70 | #_, inivalues = seq_iter.next() # bat_size * from_target_size * to_target_size
71 | _, inivalues = next(seq_iter)
72 | # only need start from start_tag
73 | partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size
74 |
75 | ## add start score (from start to all tag, duplicate to batch_size)
76 | # partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1)
77 | # iter over last scores
78 | for idx, cur_values in seq_iter:
79 | # previous to_target is current from_target
80 | # partition: previous results log(exp(from_target)), #(batch_size * from_target)
81 | # cur_values: bat_size * from_target * to_target
82 |
83 | cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
84 | cur_partition = log_sum_exp(cur_values, tag_size)
85 | # print cur_partition.data
86 |
87 | # (bat_size * from_target * to_target) -> (bat_size * to_target)
88 | # partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1)
89 | mask_idx = mask[idx, :]
90 | mask_idx = mask_idx.view(batch_size, 1).expand(batch_size, tag_size)
91 |
92 | ## effective updated partition part, only keep the partition value of mask value = 1
93 | masked_cur_partition = cur_partition.masked_select(mask_idx)
94 | ## let mask_idx broadcastable, to disable warning
95 | mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)
96 |
97 | ## replace the partition where the maskvalue=1, other partition value keeps the same
98 | partition.masked_scatter_(mask_idx, masked_cur_partition)
99 | # until the last state, add transition score for all partition (and do log_sum_exp) then select the value in STOP_TAG
100 | cur_values = self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
101 | cur_partition = log_sum_exp(cur_values, tag_size)
102 | final_partition = cur_partition[:, STOP_TAG]
103 | return final_partition.sum(), scores
104 |
105 |
106 | def _viterbi_decode(self, feats, mask):
107 | """
108 | input:
109 | feats: (batch, seq_len, self.tag_size+2)
110 | mask: (batch, seq_len)
111 | output:
112 | decode_idx: (batch, seq_len) decoded sequence
113 | path_score: (batch, 1) corresponding score for each sequence (to be implemented)
114 | """
115 | batch_size = feats.size(0)
116 | seq_len = feats.size(1)
117 | tag_size = feats.size(2)
118 | assert(tag_size == self.tagset_size+2)
119 | ## calculate sentence length for each sentence
120 | length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long()
121 | ## mask to (seq_len, batch_size)
122 | mask = mask.transpose(1,0).contiguous()
123 | ins_num = seq_len * batch_size
124 | ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1)
125 | feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
126 | ## need to consider start
127 | scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)
128 | scores = scores.view(seq_len, batch_size, tag_size, tag_size)
129 |
130 | # build iter
131 | seq_iter = enumerate(scores)
132 | ## record the position of best score
133 | back_points = list()
134 | partition_history = list()
135 |
136 |
137 | ## reverse mask (bug for mask = 1- mask, use this as alternative choice)
138 | # mask = 1 + (-1)*mask
139 | mask = (1 - mask.long()).byte()
140 | #_, inivalues = seq_iter.next() # bat_size * from_target_size * to_target_size
141 | _, inivalues = next(seq_iter)
142 | # only need start from start_tag
143 | partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size
144 | partition_history.append(partition)
145 | # iter over last scores
146 | for idx, cur_values in seq_iter:
147 | # previous to_target is current from_target
148 | # partition: previous results log(exp(from_target)), #(batch_size * from_target)
149 | # cur_values: batch_size * from_target * to_target
150 | cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
151 | ## forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1) # do not consider START_TAG/STOP_TAG
152 | partition, cur_bp = torch.max(cur_values, 1)
153 |
154 | # add by liuwei
155 | partition = partition.view(partition.size()[0], partition.size()[1], 1)
156 |
157 | partition_history.append(partition)
158 | ## cur_bp: (batch_size, tag_size) max source score position in current tag
159 | ## set padded label as 0, which will be filtered in post processing
160 | cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
161 | back_points.append(cur_bp)
162 | ### add score to final STOP_TAG
163 | partition_history = torch.cat(partition_history,0).view(seq_len, batch_size,-1).transpose(1,0).contiguous() ## (batch_size, seq_len. tag_size)
164 | ### get the last position for each setences, and select the last partitions using gather()
165 | last_position = length_mask.view(batch_size,1,1).expand(batch_size, 1, tag_size) -1
166 | last_partition = torch.gather(partition_history, 1, last_position).view(batch_size,tag_size,1)
167 | ### calculate the score from last partition to end state (and then select the STOP_TAG from it)
168 | last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size)
169 | _, last_bp = torch.max(last_values, 1)
170 | pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size)).long()
171 | if self.gpu:
172 | pad_zero = pad_zero.cuda()
173 | back_points.append(pad_zero)
174 | back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)
175 |
176 | ## select end ids in STOP_TAG
177 | pointer = last_bp[:, STOP_TAG]
178 | insert_last = pointer.contiguous().view(batch_size,1,1).expand(batch_size,1, tag_size)
179 | back_points = back_points.transpose(1,0).contiguous()
180 | ## move the end ids(expand to tag_size) to the corresponding position of back_points to replace the 0 values
181 | # print "lp:",last_position
182 | # print "il:",insert_last
183 | back_points.scatter_(1, last_position, insert_last)
184 | # print "bp:",back_points
185 | # exit(0)
186 | back_points = back_points.transpose(1,0).contiguous()
187 | ## decode from the end, padded position ids are 0, which will be filtered if following evaluation
188 | decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size))
189 | if self.gpu:
190 | decode_idx = decode_idx.cuda()
191 | decode_idx[-1] = pointer.data
192 | for idx in range(len(back_points)-2, -1, -1):
193 | pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1))
194 | decode_idx[idx] = pointer.data.view(-1)
195 | path_score = None
196 | decode_idx = decode_idx.transpose(1,0)
197 | return path_score, decode_idx
198 |
199 |
200 |
201 | def forward(self, feats):
202 | path_score, best_path = self._viterbi_decode(feats)
203 | return path_score, best_path
204 |
205 |
206 | def _score_sentence(self, scores, mask, tags):
207 | """
208 | input:
209 | scores: variable (seq_len, batch, tag_size, tag_size)
210 | mask: (batch, seq_len)
211 | tags: tensor (batch, seq_len)
212 | output:
213 | score: sum of score for gold sequences within whole batch
214 | """
215 | # Gives the score of a provided tag sequence
216 | batch_size = scores.size(1)
217 | seq_len = scores.size(0)
218 | tag_size = scores.size(2)
219 | ## convert tag value into a new format, recorded label bigram information to index
220 | new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len))
221 | if self.gpu:
222 | new_tags = new_tags.cuda()
223 | for idx in range(seq_len):
224 | if idx == 0:
225 | ## start -> first score
226 | new_tags[:,0] = (tag_size - 2)*tag_size + tags[:,0]
227 |
228 | else:
229 | new_tags[:,idx] = tags[:,idx-1]*tag_size + tags[:,idx]
230 |
231 | ## transition for label to STOP_TAG
232 | end_transition = self.transitions[:,STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size)
233 | ## length for batch, last word position = length - 1
234 | length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long()
235 | ## index the label id of last word
236 | end_ids = torch.gather(tags, 1, length_mask - 1)
237 |
238 | ## index the transition score for end_id to STOP_TAG
239 | end_energy = torch.gather(end_transition, 1, end_ids)
240 |
241 | ## convert tag as (seq_len, batch_size, 1)
242 | new_tags = new_tags.transpose(1,0).contiguous().view(seq_len, batch_size, 1)
243 | ### need convert tags id to search from 400 positions of scores
244 | tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size) # seq_len * bat_size
245 | ## mask transpose to (seq_len, batch_size)
246 | tg_energy = tg_energy.masked_select(mask.transpose(1,0))
247 |
248 | # ## calculate the score from START_TAG to first label
249 | # start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size)
250 | # start_energy = torch.gather(start_transition, 1, tags[0,:])
251 |
252 | ## add all score together
253 | # gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum()
254 | gold_score = tg_energy.sum() + end_energy.sum()
255 | return gold_score
256 |
257 | def neg_log_likelihood_loss(self, feats, mask, tags):
258 | # negative log likelihood
259 | batch_size = feats.size(0)
260 | forward_score, scores = self._calculate_PZ(feats, mask)
261 | gold_score = self._score_sentence(scores, mask, tags)
262 | # print "batch, f:", forward_score.data[0], " g:", gold_score.data[0], " dis:", forward_score.data[0] - gold_score.data[0]
263 | # exit(0)
264 | if self.average_batch:
265 | return (forward_score - gold_score)/batch_size
266 | else:
267 | return forward_score - gold_score
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
--------------------------------------------------------------------------------
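A minimal sketch (my addition, written against the repo's pinned PyTorch 0.4 API): note that `feats` must carry `tagset_size + 2` scores per position because the CRF reserves extra columns for START_TAG and STOP_TAG.

```python
import torch
from ner.model.crf import CRF

batch_size, seq_len, num_labels = 2, 5, 4
crf = CRF(tagset_size=num_labels, gpu=False)

feats = torch.randn(batch_size, seq_len, num_labels + 2)   # emission scores
mask = torch.tensor([[1, 1, 1, 1, 0],
                     [1, 1, 1, 0, 0]], dtype=torch.uint8)  # 1 = real token
tags = torch.randint(0, num_labels, (batch_size, seq_len))

loss = crf.neg_log_likelihood_loss(feats, mask, tags)
_, best_paths = crf._viterbi_decode(feats, mask)
print(loss.item(), best_paths.shape)   # scalar loss, torch.Size([2, 5])
```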
/ner/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/modules/__init__.py
--------------------------------------------------------------------------------
/ner/modules/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/modules/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/modules/__pycache__/gaz_bilstm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/modules/__pycache__/gaz_bilstm.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/modules/__pycache__/gaz_embed.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/modules/__pycache__/gaz_embed.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/modules/__pycache__/highway.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/modules/__pycache__/highway.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/modules/gaz_bilstm.py:
--------------------------------------------------------------------------------
1 | __author__ = "liuwei"
2 |
3 | """
4 | A BiLSTM used for sentences that carry gaz (matched lexicon word) information
5 | """
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 |
11 | from ner.functions.utils import reverse_padded_sequence
12 |
13 | class Gaz_BiLSTM(torch.nn.Module):
14 | def __init__(self, data, input_size, hidden_size):
15 | print("Build the Gaz bilstm...")
16 | super(Gaz_BiLSTM, self).__init__()
17 |
18 | self.gpu = data.HP_gpu
19 | self.batch_size = data.HP_batch_size
20 | self.drop = nn.Dropout(data.HP_dropout)
21 | self.droplstm = nn.Dropout(data.HP_dropout)
22 |
23 | self.f_lstm = nn.LSTM(input_size, hidden_size, num_layers=data.HP_lstm_layer, batch_first=True)
24 | self.b_lstm = nn.LSTM(input_size, hidden_size, num_layers=data.HP_lstm_layer, batch_first=True)
25 |
26 | if self.gpu:
27 | self.drop = self.drop.cuda()
28 | self.droplstm = self.droplstm.cuda()
29 | self.f_lstm = self.f_lstm.cuda()
30 | self.b_lstm = self.b_lstm.cuda()
31 |
32 |
33 | def get_lstm_features(self, inputs, word_seq_length):
34 | """
35 | get the output of bilstm. Note that inputs is forward and backward inputs
36 | Args:
37 | inputs: a tuple, each item size is [batch_size, sent_len, dim]
38 | word_seq_length: a [batch_size] tensor
39 | """
40 | # lengths = list(map(int, word_seq_length))
41 | f_inputs, b_inputs = inputs
42 | f_inputs = self.drop(f_inputs)
43 | b_inputs = self.drop(b_inputs)
44 |
45 | f_lstm_out, f_hidden = self.f_lstm(f_inputs)
46 | b_lstm_out, b_hidden = self.b_lstm(b_inputs)
47 | # b_lstm_out = reverse_padded_sequence(b_lstm_out, lengths)
48 |
49 | f_lstm_out = self.droplstm(f_lstm_out)
50 | b_lstm_out = self.droplstm(b_lstm_out)
51 |
52 | return f_lstm_out, b_lstm_out
53 |
54 |
55 | def forward(self, inputs, word_seq_length):
56 | """
57 | """
58 | f_lstm_out, b_lstm_out = self.get_lstm_features(inputs, word_seq_length)
59 |
60 | lengths = list(map(int, word_seq_length))
61 | rb_lstm_out = reverse_padded_sequence(b_lstm_out, lengths)
62 |
63 | lstm_out = torch.cat((f_lstm_out, rb_lstm_out), dim=-1)
64 |
65 | return lstm_out, (f_lstm_out, b_lstm_out)
66 |
--------------------------------------------------------------------------------
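A minimal sketch (my addition) using a stand-in config object in place of the repository's data/config object; it shows that the module takes a (forward, backward) input pair and returns the concatenated hidden states.

```python
import torch
from types import SimpleNamespace
from ner.modules.gaz_bilstm import Gaz_BiLSTM

cfg = SimpleNamespace(HP_gpu=False, HP_batch_size=2, HP_dropout=0.5, HP_lstm_layer=1)
bilstm = Gaz_BiLSTM(cfg, input_size=10, hidden_size=16)

fwd = torch.randn(2, 7, 10)    # forward-ordered char+gaz embeddings
bwd = torch.randn(2, 7, 10)    # reversed inputs for the backward LSTM
out, _ = bilstm((fwd, bwd), word_seq_length=torch.tensor([7, 5]))
print(out.shape)               # torch.Size([2, 7, 32]) -> 2 * hidden_size
```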
/ner/modules/gaz_embed.py:
--------------------------------------------------------------------------------
1 | __author__ = "liuwei"
2 |
3 | """
4 | the strategies to obtain the gaz embedding; this embedding is concatenated to the char embedding
5 | """
6 |
7 | import numpy
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from ner.modules.highway import Highway
13 | from ner.functions.utils import random_embedding
14 | from ner.functions.utils import reverse_padded_sequence
15 | from ner.functions.iniatlize import init_cnn_weight
16 | from ner.functions.iniatlize import init_maxtrix_weight
17 | from ner.functions.iniatlize import init_vector
18 |
19 | class Gaz_Embed(torch.nn.Module):
20 | def __init__(self, data, type=1):
21 | """
22 | Args:
23 | data: all the data information
24 | type: the encoding strategy, 1 for shortest-word first, 2 for longest-word first, 3 for average, 4 for CNN, 5 for self-attention (see forward())
25 | """
26 | print('build gaz embedding...')
27 |
28 | super(Gaz_Embed, self).__init__()
29 |
30 | self.gpu = data.HP_gpu
31 | self.data = data
32 | self.type = type
33 | self.gaz_dim = data.gaz_emb_dim
34 | self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(), data.gaz_emb_dim)
35 | self.dropout = nn.Dropout(p=0.5)
36 |
37 | if data.pretrain_gaz_embedding is not None:
38 | self.gaz_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_gaz_embedding))
39 | else:
40 | self.gaz_embedding.weight.data.copy_(
41 | torch.from_numpy(random_embedding(data.gaz_alphabet.size(), data.gaz_emb_dim))
42 | )
43 |
44 | self.filters = [[1, 20], [2, 30]]
45 | if self.type == 4:
46 | # use conv, so we need to define some conv
47 | # here we use 20 1-d conv, and 30 2-d conv
48 | self.build_cnn(self.filters)
49 |
50 | ## also use highway, 2 layers highway
51 | # self.highway = Highway(self.gaz_dim, num_layers=2)
52 | # if self.gpu:
53 | # self.highway = self.highway.cuda()
54 |
55 | if self.type == 5:
56 | # use self-attention
57 | self.build_attention()
58 |
59 | if self.gpu:
60 | self.gaz_embedding = self.gaz_embedding.cuda()
61 |
62 | def build_cnn(self, filters):
63 | """
64 | build cnn for convolution the gaz embeddings
65 | Args:
66 | filters: the filter definitions
67 | """
68 | for filter in filters:
69 | k_size = filter[0]
70 | channel_size = filter[1]
71 |
72 | conv = torch.nn.Conv1d(
73 | in_channels=self.gaz_dim,
74 | out_channels=channel_size,
75 | kernel_size=k_size,
76 | bias=True
77 | )
78 |
79 | if self.gpu:
80 | conv = conv.cuda()
81 |
82 | init_cnn_weight(conv)
83 | self.add_module('conv_{}d'.format(k_size), conv)
84 | del conv
85 |
86 | def build_attention(self):
87 | """
88 | build a self-attention to weight add the values
89 | Args:
90 | max_gaz_num: the max gaz number
91 | """
92 | w1 = torch.zeros(self.gaz_dim, self.gaz_dim)
93 | w2 = torch.zeros(self.gaz_dim)
94 |
95 | if self.gpu:
96 | w1 = w1.cuda()
97 | w2 = w2.cuda()
98 | init_maxtrix_weight(w1)
99 | init_vector(w2)
100 |
101 | self.W1 = nn.Parameter(w1)
102 | self.W2 = nn.Parameter(w2)
103 |
104 | def forward(self, inputs):
105 | """
106 | the inputs is a tuple, include gaz_seq_tensor, gaz_seq_length, gaz_mask_tensor
107 | Args:
108 | gaz_seq_tensor: [batch_size, seq_len, gaz_num]
109 | gaz_seq_length: [batch_size, seq_len]
110 | gaz_mask_tensor: [batch_size, seq_len, gaz_num]
111 | """
112 | gaz_seq_tensor, gaz_seq_lengths, gaz_mask_tensor = inputs
113 | batch_size = gaz_seq_tensor.size(0)
114 | seq_len = gaz_seq_tensor.size(1)
115 | gaz_num = gaz_seq_tensor.size(2)
116 |
117 | # type = 1, short first; type = 2, long first; type = 3, avg; type = 4, cnn
118 | if self.type == 1:
119 | # short first
120 | gaz_ids = gaz_seq_tensor[:, :, 0]
121 | gaz_ids = gaz_ids.view(batch_size, seq_len, -1)
122 | gaz_ids = torch.squeeze(gaz_ids, dim=-1)
123 | if self.gpu:
124 | gaz_ids = gaz_ids.cuda()
125 |
126 | gaz_embs = self.gaz_embedding(gaz_ids)
127 |
128 | return gaz_embs
129 |
130 | elif self.type == 2:
131 | # long first
132 | select_ids = gaz_seq_lengths
133 | select_ids = select_ids.view(batch_size, seq_len, -1)
134 |
135 | # the max index = len - 1
136 | select_ids = select_ids - 1
137 | # print(select_ids[0])
138 | gaz_ids = torch.gather(gaz_seq_tensor, dim=2, index=select_ids)
139 | gaz_ids = gaz_ids.view(batch_size, seq_len, -1)
140 | gaz_ids = torch.squeeze(gaz_ids, dim=-1)
141 | if self.gpu:
142 | gaz_ids = gaz_ids.cuda()
143 |
144 | gaz_embs = self.gaz_embedding(gaz_ids)
145 |
146 | return gaz_embs
147 |
148 | elif self.type == 3:
149 | ## avg first
150 | # [batch_size, seq_len, gaz_num, gaz_dim]
151 | if self.gpu:
152 | gaz_seq_tensor = gaz_seq_tensor.cuda()
153 |
154 | gaz_embs = self.gaz_embedding(gaz_seq_tensor)
155 |
156 | # use mask to do select sum, mask: [batch_size, seq_len, gaz_num]
157 | pad_mask = gaz_mask_tensor.view(batch_size, seq_len, gaz_num, -1).float()
158 |
159 | # padding embedding transform to 0
160 | gaz_embs = gaz_embs * pad_mask
161 |
162 | # do sum at the gaz_num axis, result is [batch_size, seq_len, gaz_dim]
163 | gaz_embs = torch.sum(gaz_embs, dim=2)
164 | gaz_seq_lengths = gaz_seq_lengths.view(batch_size, seq_len, -1).float()
165 | gaz_embs = gaz_embs / gaz_seq_lengths
166 |
167 | return gaz_embs
168 |
169 | elif self.type == 4:
170 | ## use convolution
171 | # first get all the gaz embedding representation
172 | # [batch_size, seq_len, gaz_num, gaz_dim]
173 | input_embs = self.gaz_embedding(gaz_seq_tensor)
174 |
175 | # transform to [batch_size * seq_len, gaz_num, gaz_dim] to use Conv1d
176 | input_embs = input_embs.view(-1, gaz_num, self.gaz_dim)
177 | input_embs = torch.transpose(input_embs, 2, 1)
178 | input_embs = self.dropout(input_embs)
179 |
180 | gaz_embs = []
181 |
182 | for filter in self.filters:
183 | k_size = filter[0]
184 |
185 | conv = getattr(self, 'conv_{}d'.format(k_size))
186 | convolved = conv(input_embs)
187 |
188 | # convolved is [batch_size * seq_len, channel, width-k_size+1]
189 | # do active and max-pool, [batch_size * seq_len, channel]
190 | convolved, _ = torch.max(convolved, dim=-1)
191 | if True:
192 | convolved = F.tanh(convolved)
193 |
194 | gaz_embs.append(convolved)
195 |
196 | # transpose to [batch_size, seq_len, gaz_dim]
197 | gaz_embs = torch.cat(gaz_embs, dim=-1)
198 |
199 | # gaz_embs = self.dropout(gaz_embs)
200 | # gaz_embs = self.highway(gaz_embs)
201 | gaz_embs = gaz_embs.view(batch_size, seq_len, -1)
202 |
203 | return gaz_embs
204 |
205 | elif self.type == 5:
206 | ## self attention
207 | gaz_embs = self.gaz_embedding(gaz_seq_tensor)
208 | # print('origin: ', gaz_embs[0][0][:20])
209 | input_embs = gaz_embs.view(-1, gaz_num, self.gaz_dim)
210 | input_embs = torch.transpose(input_embs, 2, 1)
211 |
212 | ### step1, cal alpha
213 | # cal the alpha, result is [batch * seq_len, d1, gaz_num]
214 | alpha = torch.matmul(self.W1, input_embs)
215 | alpha = F.tanh(alpha)
216 |
217 | # result is [batch * seq_len, gaz_num]
218 | alpha = torch.transpose(alpha, 2, 1)
219 | weight2 = self.W2
220 | alpha = torch.matmul(alpha, weight2)
221 |
222 | # before softmax, we need to mask
223 | # alpha = alpha * gaz_mask_tensor.contiguous().view(-1, gaz_num).float()
224 | zero_mask = (1 - gaz_mask_tensor.float()) * (-2**31 + 1)
225 | zero_mask = zero_mask.contiguous().view(-1, gaz_num)
226 | alpha = alpha + zero_mask
227 |
228 | ### step2 do softmax,
229 | # [batch * seq_len, gaz_num]
230 | # alpha = torch.exp(alpha)
231 | # total_alpha = torch.sum(alpha, dim=-1, keepdim=True)
232 | # alpha = torch.div(alpha, total_alpha)
233 | alpha = F.softmax(alpha, dim=-1)
234 | alpha = alpha.view(batch_size, seq_len, gaz_num, -1)
235 |
236 | ### step3, weighted add, [batch_size, seq_len, gaz_num, gaz_dim]
237 | gaz_embs = gaz_embs * alpha
238 | gaz_embs = torch.sum(gaz_embs, dim=2)
239 |
240 | return gaz_embs
241 |
242 |
243 |
244 |
245 |
--------------------------------------------------------------------------------
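A minimal sketch (my addition) of the type-3 "average" strategy on plain tensors: padded gaz embeddings are zeroed out by the mask, summed over the gaz axis, and divided by the true number of matched words per character, mirroring the `type == 3` branch above.

```python
import torch

batch_size, seq_len, gaz_num, gaz_dim = 2, 3, 4, 5
gaz_embs = torch.randn(batch_size, seq_len, gaz_num, gaz_dim)
gaz_mask = torch.zeros(batch_size, seq_len, gaz_num)
gaz_mask[:, :, :2] = 1                    # pretend every character matched 2 words
gaz_lens = gaz_mask.sum(dim=2)            # [batch_size, seq_len]

avg = (gaz_embs * gaz_mask.unsqueeze(-1)).sum(dim=2) / gaz_lens.unsqueeze(-1)
print(avg.shape)                          # torch.Size([2, 3, 5])
```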
/ner/modules/highway.py:
--------------------------------------------------------------------------------
1 | __author__ = 'liuwei'
2 |
3 | """
4 | implementation of the highway net
5 | includes two gates: a transform gate (G_T) and a carry gate (G_C)
6 | H = w_h * x + b_h
7 | G_T = sigmoid(w_t * x + b_t)
8 | G_C = sigmoid(w_c * x + b_c)
9 | outputs:
10 | outputs = G_T * H + G_C * x
11 |
12 | for simplicity:
13 | G_C = (1 - G_T), then:
14 | outputs = G_T * H + (1 - G_T) * x
15 | and generally b_t is set to -1 or -3, i.e. b_c to 1 or 3 (this implementation initialises the carry-gate bias to 1)
16 | """
17 |
18 | import torch
19 | import torch.nn as nn
20 | import numpy as np
21 |
22 | from ner.functions.iniatlize import init_highway
23 |
24 | class Highway(nn.Module):
25 | def __init__(self, input_dim, num_layers=1, activation=nn.functional.relu,
26 | require_grad=True):
27 | """
28 | Args:
29 | input_dim: the dim
30 | num_layers: the number of highway layers
31 | activation: activation function, tanh or relu
32 | """
33 | super(Highway, self).__init__()
34 |
35 | self._input_dim = input_dim
36 | self._num_layers = num_layers
37 |
38 | # output is input_dim * 2, because one is candidate status, and another
39 | # is transform gate
40 | self._layers = torch.nn.ModuleList(
41 | [nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)]
42 | )
43 | self._activation = activation
44 | i = 0
45 | for layer in self._layers:
46 | layer.weight.requires_grad = require_grad
47 | layer.bias.requires_grad = require_grad
48 | init_highway(layer, 100)
49 | layer.bias[input_dim:].data.fill_(1)
50 |
51 | i += 1
52 |
53 |
54 | def forward(self, inputs):
55 | """
56 | Args:
57 | inputs: a tensor, size is [batch_size, n_tokens, input_dim]
58 | """
59 | current_input = inputs
60 | for layer in self._layers:
61 | proj_inputs = layer(current_input)
62 | linear_part = current_input
63 |
64 | del current_input
65 |
66 |             # the gate here is the carry gate; if you switch to a transform gate,
67 |             # the bias init should change as well (e.g. to -1 or -3)
68 | nonlinear_part, carry_gate = proj_inputs.chunk(2, dim=-1)
69 | nonlinear_part = self._activation(nonlinear_part)
70 | carry_gate = torch.nn.functional.sigmoid(carry_gate)
71 | current_input = (1 - carry_gate) * nonlinear_part + carry_gate * linear_part
72 |
73 | return current_input
74 |
75 |
--------------------------------------------------------------------------------
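A minimal usage sketch for the Highway block above (not part of the repository), assuming the `ner` package is on the PYTHONPATH; the dimensions are arbitrary.

```python
import torch
from ner.modules.highway import Highway

highway = Highway(input_dim=50, num_layers=2)   # carry-gate bias starts at +1, so inputs are mostly carried at first
x = torch.randn(4, 10, 50)                      # [batch_size, n_tokens, input_dim]
y = highway(x)
print(y.shape)                                  # torch.Size([4, 10, 50]), same shape as the input
```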
/ner/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__init__.py
--------------------------------------------------------------------------------
/ner/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/utils/__pycache__/alphabet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/alphabet.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/utils/__pycache__/data.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/data.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/utils/__pycache__/functions.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/functions.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/utils/__pycache__/gazetteer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/gazetteer.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/utils/__pycache__/metric.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/metric.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/utils/__pycache__/msra_data.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/msra_data.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/utils/__pycache__/trie.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/trie.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/utils/__pycache__/weibo_data.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuwei1206/CCW-NER/064be07ec6e9747aef35e4afcfac97184f25b42c/ner/utils/__pycache__/weibo_data.cpython-36.pyc
--------------------------------------------------------------------------------
/ner/utils/alphabet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Alphabet maps objects to integer ids. It provides two way mapping from the index to the objects.
5 | """
6 | import json
7 | import os
8 |
9 |
10 | class Alphabet:
11 | def __init__(self, name, label=False, keep_growing=True):
12 | self.__name = name
13 |         self.UNKNOWN = "</unk>"
14 | self.label = label
15 | self.instance2index = {}
16 | self.instances = []
17 | self.keep_growing = keep_growing
18 |
19 | # Index 0 is occupied by default, all else following.
20 | self.default_index = 0
21 | self.next_index = 1
22 | if not self.label:
23 | self.add(self.UNKNOWN)
24 |
25 | def clear(self, keep_growing=True):
26 | self.instance2index = {}
27 | self.instances = []
28 | self.keep_growing = keep_growing
29 |
30 | # Index 0 is occupied by default, all else following.
31 | self.default_index = 0
32 | self.next_index = 1
33 |
34 | def add(self, instance):
35 | if instance not in self.instance2index:
36 | self.instances.append(instance)
37 | self.instance2index[instance] = self.next_index
38 | self.next_index += 1
39 |
40 | def get_index(self, instance):
41 | try:
42 | return self.instance2index[instance]
43 | except KeyError:
44 | if self.keep_growing:
45 | index = self.next_index
46 | self.add(instance)
47 | return index
48 | else:
49 | return self.instance2index[self.UNKNOWN]
50 |
51 | def get_instance(self, index):
52 | if index == 0:
53 | # First index is occupied by the wildcard element.
54 | return None
55 | try:
56 | return self.instances[index - 1]
57 | except IndexError:
58 |             print('WARNING: Alphabet get_instance, unknown instance, returning the first label.')
59 | return self.instances[0]
60 |
61 | def size(self):
62 | # if self.label:
63 | # return len(self.instances)
64 | # else:
65 | return len(self.instances) + 1
66 |
67 | def iteritems(self):
68 | return self.instance2index.items()
69 |
70 | def enumerate_items(self, start=1):
71 | if start < 1 or start >= self.size():
72 | raise IndexError("Enumerate is allowed between [1 : size of the alphabet)")
73 | return zip(range(start, len(self.instances) + 1), self.instances[start - 1:])
74 |
75 | def close(self):
76 | self.keep_growing = False
77 |
78 | def open(self):
79 | self.keep_growing = True
80 |
81 | def get_content(self):
82 | return {'instance2index': self.instance2index, 'instances': self.instances}
83 |
84 | def from_json(self, data):
85 | self.instances = data["instances"]
86 | self.instance2index = data["instance2index"]
87 |
88 | def save(self, output_directory, name=None):
89 | """
90 |         Save both alphabet records to the given directory.
91 | :param output_directory: Directory to save model and weights.
92 | :param name: The alphabet saving name, optional.
93 | :return:
94 | """
95 | saving_name = name if name else self.__name
96 | try:
97 | json.dump(self.get_content(), open(os.path.join(output_directory, saving_name + ".json"), 'w'))
98 | except Exception as e:
99 |             print("Exception: Alphabet is not saved: %s" % repr(e))
100 |
101 | def load(self, input_directory, name=None):
102 | """
103 |         Load the alphabet records from the given directory. This allows old models to be reused even if
104 |         the structure changes.
105 |         :param input_directory: Directory to load the alphabet records from
106 | :return:
107 | """
108 | loading_name = name if name else self.__name
109 | self.from_json(json.load(open(os.path.join(input_directory, loading_name + ".json"))))
110 |
--------------------------------------------------------------------------------
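A short round-trip sketch for the Alphabet class above (not part of the repository), assuming the `ner` package is importable.

```python
from ner.utils.alphabet import Alphabet

word_alphabet = Alphabet('word')          # index 1 holds the UNKNOWN placeholder
for ch in "美国":
    word_alphabet.add(ch)
idx = word_alphabet.get_index("美")       # 2
print(idx, word_alphabet.get_instance(idx), word_alphabet.size())   # 2 美 4
word_alphabet.close()                     # freeze the alphabet before decoding / evaluation
```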
/ner/utils/functions.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import sys
4 | import numpy as np
5 | from ner.utils.alphabet import Alphabet
6 | NULLKEY = "-null-"
7 | def normalize_word(word):
8 | new_word = ""
9 | for char in word:
10 | if char.isdigit():
11 | new_word += '0'
12 | else:
13 | new_word += char
14 | return new_word
15 |
16 | def read_word_instance(input_file, word_alphabet, label_alphabet, number_normalized, max_sent_lengths):
17 | """
18 |     Read only word-level information, no chars.
19 |     Note that here a "word" is a single Chinese character.
20 | """
21 | instence_texts = []
22 | instence_ids = []
23 | entity_num = 0
24 | # max_sent_lengths = 1000
25 | with open(input_file, 'r', errors='ignore') as f:
26 | in_lines = f.readlines()
27 | words = []
28 | labels = []
29 | word_ids = []
30 | label_ids = []
31 |
32 | for line in in_lines:
33 | if len(line) > 2:
34 |                 # a line of length <= 2 is treated as a blank line (sentence separator)
35 | pairs = line.strip().split()
36 | word = pairs[0]
37 | if number_normalized:
38 | word = normalize_word(word)
39 | label = pairs[-1]
40 | if "B-" in label or "S-" in label:
41 | entity_num += 1
42 | words.append(word)
43 | labels.append(label)
44 | word_ids.append(word_alphabet.get_index(word))
45 | label_ids.append(label_alphabet.get_index(label))
46 | else:
47 | if (max_sent_lengths < 0) or (len(words) < max_sent_lengths):
48 | instence_texts.append([words, labels])
49 | instence_ids.append([word_ids, label_ids])
50 | # else:
51 | # print(len(words))
52 | # print('so long!!!')
53 | words = []
54 | labels = []
55 | word_ids = []
56 | label_ids = []
57 | # if len(words) > 0:
58 | # instence_texts.append([words, labels])
59 | # instence_ids.append([word_ids, label_ids])
60 |
61 | # print("entity num: ", entity_num)
62 | return instence_texts, instence_ids
63 |
64 |
65 | def read_instance(input_file, word_alphabet, char_alphabet, label_alphabet, number_normalized,max_sent_length, char_padding_size=-1, char_padding_symbol = ''):
66 | """
67 |     Read the word and char information into instance_texts and instance_ids.
68 | 
69 |     Note that in the file every line is a single Chinese character and its tag,
70 |     but each entry of instance_texts / instance_ids holds a whole sentence:
71 |     its characters and their ids.
72 | """
73 | in_lines = open(input_file,'r').readlines()
74 | instence_texts = []
75 | instence_Ids = []
76 | words = []
77 | chars = []
78 | labels = []
79 | word_Ids = []
80 | char_Ids = []
81 | label_Ids = []
82 | for line in in_lines:
83 | if len(line) > 2:
84 | pairs = line.strip().split()
85 | word = pairs[0]
86 | if number_normalized:
87 | word = normalize_word(word)
88 | label = pairs[-1]
89 | words.append(word)
90 | labels.append(label)
91 | word_Ids.append(word_alphabet.get_index(word))
92 | label_Ids.append(label_alphabet.get_index(label))
93 | char_list = []
94 | char_Id = []
95 | for char in word:
96 | char_list.append(char)
97 | if char_padding_size > 0:
98 | char_number = len(char_list)
99 | if char_number < char_padding_size:
100 | char_list = char_list + [char_padding_symbol]*(char_padding_size-char_number)
101 | assert(len(char_list) == char_padding_size)
102 | else:
103 | ### not padding
104 | pass
105 | for char in char_list:
106 | char_Id.append(char_alphabet.get_index(char))
107 | chars.append(char_list)
108 | char_Ids.append(char_Id)
109 | else:
110 | if (max_sent_length < 0) or (len(words) < max_sent_length):
111 | instence_texts.append([words, chars, labels])
112 | instence_Ids.append([word_Ids, char_Ids,label_Ids])
113 | words = []
114 | chars = []
115 | labels = []
116 | word_Ids = []
117 | char_Ids = []
118 | label_Ids = []
119 | return instence_texts, instence_Ids
120 |
121 |
122 | def read_seg_instance(input_file, word_alphabet, biword_alphabet, char_alphabet, label_alphabet, number_normalized, max_sent_length, char_padding_size=-1, char_padding_symbol = ''):
123 | """
124 |     Read the word, biword and char information into instance_texts and instance_ids.
125 | 
126 |     Note that in the file every line is a single Chinese character and its tag,
127 |     but each entry of instance_texts / instance_ids holds a whole sentence:
128 |     its characters and their ids.
129 | """
130 | in_lines = open(input_file,'r').readlines()
131 | instence_texts = []
132 | instence_Ids = []
133 | words = []
134 | biwords = []
135 | chars = []
136 | labels = []
137 | word_Ids = []
138 | biword_Ids = []
139 | char_Ids = []
140 | label_Ids = []
141 | for idx in range(len(in_lines)):
142 | line = in_lines[idx]
143 | if len(line) > 2:
144 | pairs = line.strip().split()
145 | word = pairs[0]
146 | if number_normalized:
147 | word = normalize_word(word)
148 | label = pairs[-1]
149 | words.append(word)
150 | if idx < len(in_lines) -1 and len(in_lines[idx+1]) > 2:
151 | biword = word + in_lines[idx+1].strip().split()[0]
152 | else:
153 | biword = word + NULLKEY
154 | biwords.append(biword)
155 | labels.append(label)
156 | word_Ids.append(word_alphabet.get_index(word))
157 | biword_Ids.append(biword_alphabet.get_index(biword))
158 | label_Ids.append(label_alphabet.get_index(label))
159 | char_list = []
160 | char_Id = []
161 | for char in word:
162 | char_list.append(char)
163 | if char_padding_size > 0:
164 | char_number = len(char_list)
165 | if char_number < char_padding_size:
166 | char_list = char_list + [char_padding_symbol]*(char_padding_size-char_number)
167 | assert(len(char_list) == char_padding_size)
168 | else:
169 | ### not padding
170 | pass
171 | for char in char_list:
172 | char_Id.append(char_alphabet.get_index(char))
173 | chars.append(char_list)
174 | char_Ids.append(char_Id)
175 | else:
176 | if (max_sent_length < 0) or (len(words) < max_sent_length):
177 | instence_texts.append([words, biwords, chars, labels])
178 | instence_Ids.append([word_Ids, biword_Ids, char_Ids,label_Ids])
179 | words = []
180 | biwords = []
181 | chars = []
182 | labels = []
183 | word_Ids = []
184 | biword_Ids = []
185 | char_Ids = []
186 | label_Ids = []
187 | return instence_texts, instence_Ids
188 |
189 | def read_instance_with_gaz_no_char(input_file, gaz, word_alphabet, biword_alphabet, gaz_alphabet, label_alphabet, number_normalized, max_sent_length, use_single):
190 | """
191 |     Read instances with word, biword, gaz and label information (no char).
192 |     Args:
193 |         input_file: the input file path
194 |         gaz: the Gazetteer object
195 |         word_alphabet: word alphabet
196 |         biword_alphabet: biword alphabet
197 |         gaz_alphabet: gaz alphabet
198 |         label_alphabet: label alphabet
199 |         number_normalized: whether digits are normalized to '0'
200 |         max_sent_length: the maximum sentence length
201 | """
202 | in_lines = open(input_file, 'r', encoding="utf-8").readlines()
203 | instence_texts = []
204 | instence_Ids = []
205 | words = []
206 | biwords = []
207 | labels = []
208 | word_Ids = []
209 | biword_Ids = []
210 | label_Ids = []
211 | for idx in range(len(in_lines)):
212 | line = in_lines[idx]
213 | if len(line) > 2:
214 | pairs = line.strip().split()
215 | word = pairs[0]
216 | if number_normalized:
217 | word = normalize_word(word)
218 | label = pairs[-1]
219 | if idx < len(in_lines) - 1 and len(in_lines[idx+1]) > 2:
220 | biword = word + in_lines[idx+1].strip().split()[0]
221 | else:
222 | biword = word + NULLKEY
223 | biwords.append(biword)
224 | words.append(word)
225 | labels.append(label)
226 | word_Ids.append(word_alphabet.get_index(word))
227 | biword_Ids.append(biword_alphabet.get_index(biword))
228 | label_Ids.append(label_alphabet.get_index(label))
229 | else:
230 | if ((max_sent_length < 0) or (len(words) < max_sent_length)) and (len(words) > 0):
231 | gazs = []
232 | gaz_Ids = []
233 | gazs_length = []
234 | w_length = len(words)
235 |
236 | reverse_gazs = [[] for i in range(w_length)]
237 | reverse_gaz_Ids = [[] for i in range(w_length)]
238 | flag = [0 for f in range(w_length)]
239 |                     # collect the lexicon sub-sequences that start at every Chinese character
240 | for i in range(w_length):
241 | matched_list = gaz.enumerateMatchList(words[i:])
242 |
243 | if use_single and len(matched_list) > 0:
244 | f_len = len(matched_list[0])
245 |
246 | if (flag[i] == 1 or len(matched_list) > 1) and len(matched_list[-1]) == 1:
247 | matched_list = matched_list[:-1]
248 |
249 | for f_pos in range(i, i+f_len):
250 | flag[f_pos] = 1
251 |
252 | matched_length = [len(a) for a in matched_list]
253 |
254 | gazs.append(matched_list)
255 | matched_Id = [gaz_alphabet.get_index(entity) for entity in matched_list]
256 | if matched_Id:
257 | # gaz_Ids.append([matched_Id, matched_length])
258 | gaz_Ids.append(matched_Id)
259 | gazs_length.append(matched_length)
260 | else:
261 | gaz_Ids.append([])
262 | gazs_length.append([])
263 |
264 | for i in range(w_length-1, -1, -1):
265 | now_pos_gaz = gazs[i]
266 | now_pos_gaz_Id = gaz_Ids[i]
267 | now_pos_gaz_len = gazs_length[i]
268 |
269 | ## Traversing it
270 | l = len(now_pos_gaz)
271 | assert len(now_pos_gaz) == len(now_pos_gaz_Id)
272 | for j in range(l):
273 | width = now_pos_gaz_len[j]
274 | end_char_pos = i + width - 1
275 |
276 | reverse_gazs[end_char_pos].append(now_pos_gaz[j])
277 | reverse_gaz_Ids[end_char_pos].append(now_pos_gaz_Id[j])
278 |
279 |
280 | instence_texts.append([words, biwords, gazs, reverse_gazs, labels])
281 | instence_Ids.append([word_Ids, biword_Ids, gaz_Ids, reverse_gaz_Ids, label_Ids])
282 | words = []
283 | biwords = []
284 | labels = []
285 | word_Ids = []
286 | biword_Ids = []
287 | label_Ids = []
288 | gazs = []
289 | reverse_gazs = []
290 | gaz_Ids = []
291 | reverse_gaz_Ids = []
292 | return instence_texts, instence_Ids
293 |
294 | def read_instance_with_gaz(input_file, gaz, word_alphabet, biword_alphabet, char_alphabet, gaz_alphabet, label_alphabet, number_normalized, max_sent_length, char_padding_size=-1, char_padding_symbol = ''):
295 | in_lines = open(input_file,'r').readlines()
296 | instence_texts = []
297 | instence_Ids = []
298 | words = []
299 | biwords = []
300 | chars = []
301 | labels = []
302 | word_Ids = []
303 | biword_Ids = []
304 | char_Ids = []
305 | label_Ids = []
306 | for idx in range(len(in_lines)):
307 | line = in_lines[idx]
308 | if len(line) > 2:
309 | pairs = line.strip().split()
310 | word = pairs[0]
311 | if number_normalized:
312 | word = normalize_word(word)
313 | label = pairs[-1]
314 | if idx < len(in_lines) -1 and len(in_lines[idx+1]) > 2:
315 | biword = word + in_lines[idx+1].strip().split()[0]
316 | else:
317 | biword = word + NULLKEY
318 | biwords.append(biword)
319 | words.append(word)
320 | labels.append(label)
321 | word_Ids.append(word_alphabet.get_index(word))
322 | biword_Ids.append(biword_alphabet.get_index(biword))
323 | label_Ids.append(label_alphabet.get_index(label))
324 | char_list = []
325 | char_Id = []
326 | for char in word:
327 | char_list.append(char)
328 | if char_padding_size > 0:
329 | char_number = len(char_list)
330 | if char_number < char_padding_size:
331 | char_list = char_list + [char_padding_symbol]*(char_padding_size-char_number)
332 | assert(len(char_list) == char_padding_size)
333 | else:
334 | ### not padding
335 | pass
336 | for char in char_list:
337 | char_Id.append(char_alphabet.get_index(char))
338 | chars.append(char_list)
339 | char_Ids.append(char_Id)
340 |
341 | else:
342 | if ((max_sent_length < 0) or (len(words) < max_sent_length)) and (len(words)>0):
343 | gazs = []
344 | gaz_Ids = []
345 | w_length = len(words)
346 | # print sentence
347 | # for w in words:
348 | # print w," ",
349 | # print
350 | for idx in range(w_length):
351 | matched_list = gaz.enumerateMatchList(words[idx:])
352 | matched_length = [len(a) for a in matched_list]
353 | # print idx,"----------"
354 | # print "forward...feed:","".join(words[idx:])
355 | # for a in matched_list:
356 | # print a,len(a)," ",
357 | # print
358 |
359 | # print matched_length
360 |
361 | gazs.append(matched_list)
362 | matched_Id = [gaz_alphabet.get_index(entity) for entity in matched_list]
363 | if matched_Id:
364 | gaz_Ids.append([matched_Id, matched_length])
365 | else:
366 | gaz_Ids.append([])
367 |
368 | instence_texts.append([words, biwords, chars, gazs, labels])
369 | instence_Ids.append([word_Ids, biword_Ids, char_Ids, gaz_Ids, label_Ids])
370 | words = []
371 | biwords = []
372 | chars = []
373 | labels = []
374 | word_Ids = []
375 | biword_Ids = []
376 | char_Ids = []
377 | label_Ids = []
378 | gazs = []
379 | gaz_Ids = []
380 | return instence_texts, instence_Ids
381 |
382 |
383 | def read_instance_with_gaz_in_sentence(input_file, gaz, word_alphabet, biword_alphabet, char_alphabet, gaz_alphabet, label_alphabet, number_normalized, max_sent_length, char_padding_size=-1, char_padding_symbol = ''):
384 | in_lines = open(input_file,'r').readlines()
385 | instence_texts = []
386 | instence_Ids = []
387 | for idx in range(len(in_lines)):
388 | pair = in_lines[idx].strip()
389 | orig_words = list(pair[0])
390 |
391 | if (max_sent_length > 0) and (len(orig_words) > max_sent_length):
392 | continue
393 | biwords = []
394 | biword_Ids = []
395 | if number_normalized:
396 | words = []
397 | for word in orig_words:
398 | word = normalize_word(word)
399 | words.append(word)
400 | else:
401 | words = orig_words
402 | word_num = len(words)
403 | for idy in range(word_num):
404 | if idy < word_num - 1:
405 | biword = words[idy]+words[idy+1]
406 | else:
407 | biword = words[idy]+NULLKEY
408 | biwords.append(biword)
409 | biword_Ids.append(biword_alphabet.get_index(biword))
410 | word_Ids = [word_alphabet.get_index(word) for word in words]
411 | label = pair[-1]
412 | label_Id = label_alphabet.get_index(label)
413 | gazs = []
414 | gaz_Ids = []
415 | word_num = len(words)
416 | chars = [[word] for word in words]
417 | char_Ids = [[char_alphabet.get_index(word)] for word in words]
418 | ## print sentence
419 | # for w in words:
420 | # print w," ",
421 | # print
422 | for idx in range(word_num):
423 | matched_list = gaz.enumerateMatchList(words[idx:])
424 | matched_length = [len(a) for a in matched_list]
425 | # print idx,"----------"
426 | # print "forward...feed:","".join(words[idx:])
427 | # for a in matched_list:
428 | # print a,len(a)," ",
429 | # print
430 | # print matched_length
431 | gazs.append(matched_list)
432 | matched_Id = [gaz_alphabet.get_index(entity) for entity in matched_list]
433 | if matched_Id:
434 | gaz_Ids.append([matched_Id, matched_length])
435 | else:
436 | gaz_Ids.append([])
437 | instence_texts.append([words, biwords, chars, gazs, label])
438 | instence_Ids.append([word_Ids, biword_Ids, char_Ids, gaz_Ids, label_Id])
439 | return instence_texts, instence_Ids
440 |
441 |
442 | def build_pretrain_embedding(embedding_path, word_alphabet, embedd_dim=100, norm=True):
443 | embedd_dict = dict()
444 | if embedding_path != None:
445 | embedd_dict, embedd_dim = load_pretrain_emb(embedding_path)
446 | scale = np.sqrt(3.0 / embedd_dim)
447 | pretrain_emb = np.empty([word_alphabet.size(), embedd_dim])
448 | perfect_match = 0
449 | case_match = 0
450 | not_match = 0
451 |
452 | ## we should also init the index 0
453 | pretrain_emb[0, :] = np.random.uniform(-scale, scale, [1, embedd_dim])
454 |
455 | for word, index in word_alphabet.iteritems():
456 | if word in embedd_dict:
457 | if norm:
458 | pretrain_emb[index,:] = norm2one(embedd_dict[word])
459 | else:
460 | pretrain_emb[index,:] = embedd_dict[word]
461 | perfect_match += 1
462 | elif word.lower() in embedd_dict:
463 | if norm:
464 | pretrain_emb[index,:] = norm2one(embedd_dict[word.lower()])
465 | else:
466 | pretrain_emb[index,:] = embedd_dict[word.lower()]
467 | case_match += 1
468 | else:
469 | pretrain_emb[index,:] = np.random.uniform(-scale, scale, [1, embedd_dim])
470 | not_match += 1
471 | pretrained_size = len(embedd_dict)
472 |     print("Embedding:\n     pretrain word:%s, perfect match:%s, case_match:%s, oov:%s, oov%%:%s"%(pretrained_size, perfect_match, case_match, not_match, (not_match+0.)/word_alphabet.size()))
473 | return pretrain_emb, embedd_dim
474 |
475 |
476 |
477 | def norm2one(vec):
478 | root_sum_square = np.sqrt(np.sum(np.square(vec)))
479 | return vec/root_sum_square
480 |
481 | def load_pretrain_emb(embedding_path):
482 | embedd_dim = -1
483 | embedd_dict = dict()
484 | with open(embedding_path, 'r', encoding="utf-8") as file:
485 | for line in file:
486 | line = line.strip()
487 | if len(line) == 0:
488 | continue
489 | tokens = line.split()
490 | if embedd_dim < 0:
491 | embedd_dim = len(tokens) - 1
492 | else:
493 | assert (embedd_dim + 1 == len(tokens))
494 | embedd = np.empty([1, embedd_dim])
495 | embedd[:] = tokens[1:]
496 | embedd_dict[tokens[0]] = embedd
497 | return embedd_dict, embedd_dim
498 |
499 | if __name__ == '__main__':
500 | a = np.arange(9.0)
501 | print(a)
502 | print(norm2one(a))
503 |
--------------------------------------------------------------------------------
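An illustrative sketch (not part of the repository) of the two small helpers above, `normalize_word` and `norm2one`, assuming the `ner` package is importable.

```python
import numpy as np
from ner.utils.functions import normalize_word, norm2one

print(normalize_word("12月3日"))   # "00月0日"  -- every digit is collapsed to '0'
vec = np.array([3.0, 4.0])
print(norm2one(vec))               # [0.6 0.8]  -- scaled to unit L2 norm
```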
/ner/utils/gazetteer.py:
--------------------------------------------------------------------------------
1 | from ner.utils.trie import Trie
2 |
3 | class Gazetteer:
4 | def __init__(self, lower, use_single):
5 | self.trie = Trie(use_single)
6 | self.ent2type = {} ## word list to type
7 |         self.ent2id = {"<UNK>": 0} ## word list to id
8 | self.lower = lower
9 | self.space = ""
10 |
11 | def enumerateMatchList(self, word_list):
12 | if self.lower:
13 | word_list = [word.lower() for word in word_list]
14 | match_list = self.trie.enumerateMatch(word_list, self.space)
15 | return match_list
16 |
17 | def insert(self, word_list, source):
18 | """
19 | Args:
20 |             word_list: the character list of a possible word
21 |             source: the source label stored for this word (e.g. "one_source")
22 | """
23 | if self.lower:
24 | word_list = [word.lower() for word in word_list]
25 | self.trie.insert(word_list)
26 | string = self.space.join(word_list)
27 | if string not in self.ent2type:
28 | self.ent2type[string] = source
29 | if string not in self.ent2id:
30 | self.ent2id[string] = len(self.ent2id)
31 |
32 | def searchId(self, word_list):
33 | if self.lower:
34 | word_list = [word.lower() for word in word_list]
35 | string = self.space.join(word_list)
36 | if string in self.ent2id:
37 | return self.ent2id[string]
38 |         return self.ent2id["<UNK>"]
39 |
40 | def searchType(self, word_list):
41 | if self.lower:
42 | word_list = [word.lower() for word in word_list]
43 | string = self.space.join(word_list)
44 | if string in self.ent2type:
45 | return self.ent2type[string]
46 | print("Error in finding entity type at gazetteer.py, exit program! String:", string)
47 | exit(0)
48 |
49 | def size(self):
50 | return len(self.ent2type)
51 |
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
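A usage sketch for the Gazetteer above (not part of the repository). It assumes the `ner` package is importable; the lexicon words are made up, and the exact ordering / single-character handling of the matches depends on `ner/utils/trie.py`.

```python
from ner.utils.gazetteer import Gazetteer

gaz = Gazetteer(lower=False, use_single=True)
for word in ["长江", "长江大桥", "大桥"]:
    gaz.insert(word, "one_source")          # insert() iterates the word character by character

chars = list("长江大桥")
for i in range(len(chars)):
    # all lexicon words that start at position i, e.g. position 0 -> 长江大桥, 长江
    print(chars[i], gaz.enumerateMatchList(chars[i:]))
```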
/ner/utils/metric.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | # from operator import add
5 | #
6 | import numpy as np
7 | import math
8 | import sys
9 | import os
10 |
11 |
12 |
13 | ## input as sentence level labels
14 | def get_ner_fmeasure(golden_lists, predict_lists, label_type="BMES"):
15 | sent_num = len(golden_lists)
16 | golden_full = []
17 | predict_full = []
18 | right_full = []
19 | right_tag = 0
20 | all_tag = 0
21 | for idx in range(0, sent_num):
22 | # word_list = sentence_lists[idx]
23 | golden_list = golden_lists[idx]
24 | predict_list = predict_lists[idx]
25 | for idy in range(len(golden_list)):
26 | if golden_list[idy] == predict_list[idy]:
27 | right_tag += 1
28 | all_tag += len(golden_list)
29 | if label_type == "BMES":
30 | gold_matrix = get_ner_BMES(golden_list)
31 | pred_matrix = get_ner_BMES(predict_list)
32 | else:
33 | gold_matrix = get_ner_BIO(golden_list)
34 | pred_matrix = get_ner_BIO(predict_list)
35 | # print "gold", gold_matrix
36 | # print "pred", pred_matrix
37 | right_ner = list(set(gold_matrix).intersection(set(pred_matrix)))
38 | golden_full += gold_matrix
39 | predict_full += pred_matrix
40 | right_full += right_ner
41 | right_num = len(right_full)
42 | golden_num = len(golden_full)
43 | predict_num = len(predict_full)
44 |
45 | if predict_num == 0:
46 | precision = -1
47 | else:
48 | precision = (right_num+0.0)/predict_num
49 |
50 | if golden_num == 0:
51 | recall = -1
52 | else:
53 | recall = (right_num+0.0)/golden_num
54 | if (precision == -1) or (recall == -1) or (precision + recall) <= 0.:
55 | f_measure = -1
56 | else:
57 | f_measure = 2 * precision * recall / (precision + recall)
58 | accuracy = (right_tag+0.0)/all_tag
59 | # print "Accuracy: ", right_tag,"/",all_tag,"=",accuracy
60 | print("gold_num = ", golden_num, " pred_num = ", predict_num, " right_num = ", right_num)
61 | return accuracy, precision, recall, f_measure
62 |
63 |
64 | def reverse_style(input_string):
65 | target_position = input_string.index('[')
66 | input_len = len(input_string)
67 | output_string = input_string[target_position:input_len] + input_string[0:target_position]
68 | return output_string
69 |
70 |
71 | def get_ner_BMES(label_list):
72 | # list_len = len(word_list)
73 | # assert(list_len == len(label_list)), "word list size unmatch with label list"
74 | list_len = len(label_list)
75 | begin_label = 'B-'
76 | end_label = 'E-'
77 | single_label = 'S-'
78 | whole_tag = ''
79 | index_tag = ''
80 | tag_list = []
81 | stand_matrix = []
82 | for i in range(0, list_len):
83 | # wordlabel = word_list[i]
84 | current_label = label_list[i].upper()
85 | if begin_label in current_label:
86 | if index_tag != '':
87 | tag_list.append(whole_tag + ',' + str(i-1))
88 | whole_tag = current_label.replace(begin_label,"",1) +'[' +str(i)
89 | index_tag = current_label.replace(begin_label,"",1)
90 |
91 | elif single_label in current_label:
92 | if index_tag != '':
93 | tag_list.append(whole_tag + ',' + str(i-1))
94 | whole_tag = current_label.replace(single_label,"",1) +'[' +str(i)
95 | tag_list.append(whole_tag)
96 | whole_tag = ""
97 | index_tag = ""
98 | elif end_label in current_label:
99 | if index_tag != '':
100 | tag_list.append(whole_tag +',' + str(i))
101 | whole_tag = ''
102 | index_tag = ''
103 | else:
104 | continue
105 | if (whole_tag != '') and (index_tag != ''):
106 | tag_list.append(whole_tag)
107 | tag_list_len = len(tag_list)
108 |
109 | for i in range(0, tag_list_len):
110 | if len(tag_list[i]) > 0:
111 | tag_list[i] = tag_list[i]+ ']'
112 | insert_list = reverse_style(tag_list[i])
113 | stand_matrix.append(insert_list)
114 | # print stand_matrix
115 | return stand_matrix
116 |
117 |
118 | def get_ner_BIO(label_list):
119 | # list_len = len(word_list)
120 | # assert(list_len == len(label_list)), "word list size unmatch with label list"
121 | list_len = len(label_list)
122 | begin_label = 'B-'
123 | inside_label = 'I-'
124 | whole_tag = ''
125 | index_tag = ''
126 | tag_list = []
127 | stand_matrix = []
128 | for i in range(0, list_len):
129 | # wordlabel = word_list[i]
130 | current_label = label_list[i].upper()
131 | if begin_label in current_label:
132 | if index_tag == '':
133 | whole_tag = current_label.replace(begin_label,"",1) +'[' +str(i)
134 | index_tag = current_label.replace(begin_label,"",1)
135 | else:
136 | tag_list.append(whole_tag + ',' + str(i-1))
137 | whole_tag = current_label.replace(begin_label,"",1) + '[' + str(i)
138 | index_tag = current_label.replace(begin_label,"",1)
139 |
140 | elif inside_label in current_label:
141 | if current_label.replace(inside_label,"",1) == index_tag:
142 | whole_tag = whole_tag
143 | else:
144 | if (whole_tag != '')&(index_tag != ''):
145 | tag_list.append(whole_tag +',' + str(i-1))
146 | whole_tag = ''
147 | index_tag = ''
148 | else:
149 | if (whole_tag != '')&(index_tag != ''):
150 | tag_list.append(whole_tag +',' + str(i-1))
151 | whole_tag = ''
152 | index_tag = ''
153 |
154 | if (whole_tag != '')&(index_tag != ''):
155 | tag_list.append(whole_tag)
156 | tag_list_len = len(tag_list)
157 |
158 | for i in range(0, tag_list_len):
159 | if len(tag_list[i]) > 0:
160 | tag_list[i] = tag_list[i]+ ']'
161 | insert_list = reverse_style(tag_list[i])
162 | stand_matrix.append(insert_list)
163 |
164 | return stand_matrix
165 |
166 |
167 |
168 | def readSentence(input_file):
169 | in_lines = open(input_file,'r').readlines()
170 | sentences = []
171 | labels = []
172 | sentence = []
173 | label = []
174 | for line in in_lines:
175 | if len(line) < 2:
176 | sentences.append(sentence)
177 | labels.append(label)
178 | sentence = []
179 | label = []
180 | else:
181 | pair = line.strip('\n').split(' ')
182 | sentence.append(pair[0])
183 | label.append(pair[-1])
184 | return sentences,labels
185 |
186 |
187 | def readTwoLabelSentence(input_file, pred_col=-1):
188 | in_lines = open(input_file,'r').readlines()
189 | sentences = []
190 | predict_labels = []
191 | golden_labels = []
192 | sentence = []
193 | predict_label = []
194 | golden_label = []
195 | for line in in_lines:
196 | if "##score##" in line:
197 | continue
198 | if len(line) < 2:
199 | sentences.append(sentence)
200 | golden_labels.append(golden_label)
201 | predict_labels.append(predict_label)
202 | sentence = []
203 | golden_label = []
204 | predict_label = []
205 | else:
206 | pair = line.strip('\n').split(' ')
207 | sentence.append(pair[0])
208 | golden_label.append(pair[1])
209 | predict_label.append(pair[pred_col])
210 |
211 | return sentences,golden_labels,predict_labels
212 |
213 |
214 | def fmeasure_from_file(golden_file, predict_file, label_type="BMES"):
215 | print("Get f measure from file:", golden_file, predict_file)
216 | print("Label format:",label_type)
217 | golden_sent,golden_labels = readSentence(golden_file)
218 | predict_sent,predict_labels = readSentence(predict_file)
219 | acc, P,R,F = get_ner_fmeasure(golden_labels, predict_labels, label_type)
220 | print ("Acc:%s, P:%s R:%s, F:%s"%(acc, P,R,F))
221 |
222 |
223 |
224 | def fmeasure_from_singlefile(twolabel_file, label_type="BMES", pred_col=-1):
225 | sent,golden_labels,predict_labels = readTwoLabelSentence(twolabel_file, pred_col)
226 |     acc, P, R, F = get_ner_fmeasure(golden_labels, predict_labels, label_type)
227 |     print ("Acc:%s, P:%s, R:%s, F:%s"%(acc, P, R, F))
228 |
229 |
230 |
231 | if __name__ == '__main__':
232 | # print "sys:",len(sys.argv)
233 | if len(sys.argv) == 3:
234 | fmeasure_from_singlefile(sys.argv[1],"BMES",int(sys.argv[2]))
235 | else:
236 | fmeasure_from_singlefile(sys.argv[1],"BMES")
237 |
238 |
--------------------------------------------------------------------------------
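A toy evaluation sketch for `get_ner_fmeasure` above (not part of the repository), using hand-made BMES label sequences.

```python
from ner.utils.metric import get_ner_fmeasure

gold = [["B-LOC", "E-LOC", "O", "S-PER"]]
pred = [["B-LOC", "E-LOC", "O", "O"]]      # the PER entity is missed
acc, p, r, f = get_ner_fmeasure(gold, pred, label_type="BMES")
print(acc, p, r, f)                        # 0.75 1.0 0.5 0.666...
```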
/ner/utils/msra_data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import sys
4 | import numpy as np
5 | import pickle
6 | from ner.utils.alphabet import Alphabet
7 | from ner.utils.functions import *
8 | from ner.utils.gazetteer import Gazetteer
9 |
10 |
11 |
12 | START = "</s>"
13 | UNKNOWN = "</unk>"
14 | PADDING = "</pad>"
15 | NULLKEY = "-null-"
16 |
17 | class Data:
18 | def __init__(self):
19 | self.MAX_SENTENCE_LENGTH = 5000
20 | self.MAX_WORD_LENGTH = -1
21 | self.number_normalized = True
22 | self.norm_word_emb = True
23 | self.norm_biword_emb = True
24 | self.norm_gaz_emb = False
25 | self.use_single = False
26 | self.word_alphabet = Alphabet('word')
27 | self.biword_alphabet = Alphabet('biword')
28 | self.char_alphabet = Alphabet('character')
29 | # self.word_alphabet.add(START)
30 | # self.word_alphabet.add(UNKNOWN)
31 | # self.char_alphabet.add(START)
32 | # self.char_alphabet.add(UNKNOWN)
33 | # self.char_alphabet.add(PADDING)
34 | self.label_alphabet = Alphabet('label', True)
35 | self.gaz_lower = False
36 | self.gaz = Gazetteer(self.gaz_lower, self.use_single)
37 | self.gaz_alphabet = Alphabet('gaz')
38 | self.HP_fix_gaz_emb = False
39 | self.HP_use_gaz = True
40 |
41 | self.tagScheme = "NoSeg"
42 | self.char_features = "LSTM"
43 |
44 | self.train_texts = []
45 | self.dev_texts = []
46 | self.test_texts = []
47 | self.raw_texts = []
48 |
49 | self.train_Ids = []
50 | self.dev_Ids = []
51 | self.test_Ids = []
52 | self.raw_Ids = []
53 | self.use_bigram = True
54 | self.word_emb_dim = 100
55 | self.biword_emb_dim = 50
56 | self.char_emb_dim = 30
57 | self.gaz_emb_dim = 50
58 | self.gaz_dropout = 0.5
59 | self.pretrain_word_embedding = None
60 | self.pretrain_biword_embedding = None
61 | self.pretrain_gaz_embedding = None
62 | self.label_size = 0
63 | self.word_alphabet_size = 0
64 | self.biword_alphabet_size = 0
65 | self.char_alphabet_size = 0
66 | self.label_alphabet_size = 0
67 | ### hyperparameters
68 | self.HP_iteration = 100
69 | self.HP_batch_size = 10
70 | self.HP_char_hidden_dim = 100
71 | self.HP_hidden_dim = 200
72 | self.HP_dropout = 0.5
73 | self.HP_lstm_layer = 1
74 | self.HP_bilstm = True
75 | self.HP_use_char = False
76 | self.HP_gpu = False
77 | self.HP_lr = 0.015
78 | self.HP_lr_decay = 0.05
79 | self.HP_clip = 5.0
80 | self.HP_momentum = 0
81 |
82 |
83 | def show_data_summary(self):
84 | print("DATA SUMMARY START:")
85 | print(" Tag scheme: %s"%(self.tagScheme))
86 | print(" MAX SENTENCE LENGTH: %s"%(self.MAX_SENTENCE_LENGTH))
87 | print(" MAX WORD LENGTH: %s"%(self.MAX_WORD_LENGTH))
88 | print(" Number normalized: %s"%(self.number_normalized))
89 | print(" Use bigram: %s"%(self.use_bigram))
90 | print(" Word alphabet size: %s"%(self.word_alphabet_size))
91 | print(" Biword alphabet size: %s"%(self.biword_alphabet_size))
92 | print(" Char alphabet size: %s"%(self.char_alphabet_size))
93 | print(" Gaz alphabet size: %s"%(self.gaz_alphabet.size()))
94 | print(" Label alphabet size: %s"%(self.label_alphabet_size))
95 | print(" Word embedding size: %s"%(self.word_emb_dim))
96 | print(" Biword embedding size: %s"%(self.biword_emb_dim))
97 | print(" Char embedding size: %s"%(self.char_emb_dim))
98 | print(" Gaz embedding size: %s"%(self.gaz_emb_dim))
99 | print(" Norm word emb: %s"%(self.norm_word_emb))
100 | print(" Norm biword emb: %s"%(self.norm_biword_emb))
101 | print(" Norm gaz emb: %s"%(self.norm_gaz_emb))
102 | print(" Norm gaz dropout: %s"%(self.gaz_dropout))
103 | print(" Train instance number: %s"%(len(self.train_texts)))
104 | print(" Dev instance number: %s"%(len(self.dev_texts)))
105 | print(" Test instance number: %s"%(len(self.test_texts)))
106 | print(" Raw instance number: %s"%(len(self.raw_texts)))
107 | print(" Hyperpara iteration: %s"%(self.HP_iteration))
108 | print(" Hyperpara batch size: %s"%(self.HP_batch_size))
109 | print(" Hyperpara lr: %s"%(self.HP_lr))
110 | print(" Hyperpara lr_decay: %s"%(self.HP_lr_decay))
111 | print(" Hyperpara HP_clip: %s"%(self.HP_clip))
112 | print(" Hyperpara momentum: %s"%(self.HP_momentum))
113 | print(" Hyperpara hidden_dim: %s"%(self.HP_hidden_dim))
114 | print(" Hyperpara dropout: %s"%(self.HP_dropout))
115 | print(" Hyperpara lstm_layer: %s"%(self.HP_lstm_layer))
116 | print(" Hyperpara bilstm: %s"%(self.HP_bilstm))
117 | print(" Hyperpara GPU: %s"%(self.HP_gpu))
118 | print(" Hyperpara use_gaz: %s"%(self.HP_use_gaz))
119 | print(" Hyperpara fix gaz emb: %s"%(self.HP_fix_gaz_emb))
120 | print(" Hyperpara use_char: %s"%(self.HP_use_char))
121 | if self.HP_use_char:
122 | print(" Char_features: %s"%(self.char_features))
123 | print("DATA SUMMARY END.")
124 | sys.stdout.flush()
125 |
126 | def refresh_label_alphabet(self, input_file):
127 | old_size = self.label_alphabet_size
128 | self.label_alphabet.clear(True)
129 | in_lines = open(input_file,'r', encoding="utf-8").readlines()
130 | for line in in_lines:
131 | if len(line) > 2:
132 | pairs = line.strip().split()
133 | label = pairs[-1]
134 | self.label_alphabet.add(label)
135 | self.label_alphabet_size = self.label_alphabet.size()
136 | startS = False
137 | startB = False
138 | for label,_ in self.label_alphabet.iteritems():
139 | if "S-" in label.upper():
140 | startS = True
141 | elif "B-" in label.upper():
142 | startB = True
143 | if startB:
144 | if startS:
145 | self.tagScheme = "BMES"
146 | else:
147 | self.tagScheme = "BIO"
148 | self.fix_alphabet()
149 | print("Refresh label alphabet finished: old:%s -> new:%s"%(old_size, self.label_alphabet_size))
150 |
151 |
152 |
153 | def build_alphabet(self, input_file):
154 | in_lines = open(input_file, 'r', encoding="utf-8").readlines()
155 | for idx in range(len(in_lines)):
156 | line = in_lines[idx]
157 | if len(line) > 2:
158 | pairs = line.strip().split()
159 | word = pairs[0]
160 | if self.number_normalized:
161 | word = normalize_word(word)
162 | label = pairs[-1]
163 | self.label_alphabet.add(label)
164 | self.word_alphabet.add(word)
165 | if idx < len(in_lines) - 1 and len(in_lines[idx+1]) > 2:
166 | biword = word + in_lines[idx+1].strip().split()[0]
167 | else:
168 | biword = word + NULLKEY
169 | self.biword_alphabet.add(biword)
170 | for char in word:
171 | self.char_alphabet.add(char)
172 | self.word_alphabet_size = self.word_alphabet.size()
173 | self.biword_alphabet_size = self.biword_alphabet.size()
174 | self.char_alphabet_size = self.char_alphabet.size()
175 | self.label_alphabet_size = self.label_alphabet.size()
176 | startS = False
177 | startB = False
178 | for label,_ in self.label_alphabet.iteritems():
179 | if "S-" in label.upper():
180 | startS = True
181 | elif "B-" in label.upper():
182 | startB = True
183 | if startB:
184 | if startS:
185 | self.tagScheme = "BMES"
186 | else:
187 | self.tagScheme = "BIO"
188 |
189 |
190 | def build_gaz_file(self, gaz_file):
191 |         ## build the gaz trie from the gaz (word) embedding file;
192 |         ## at this step we only collect the words, the embeddings are not read yet
193 | if gaz_file:
194 | fins = open(gaz_file, 'r', encoding="utf-8").readlines()
195 | for fin in fins:
196 | fin = fin.strip().split()[0]
197 | if fin:
198 | self.gaz.insert(fin, "one_source")
199 | print("Load gaz file: ", gaz_file, " total size:", self.gaz.size())
200 | else:
201 | print("Gaz file is None, load nothing")
202 |
203 |
204 | def build_gaz_alphabet(self, input_file):
205 | """
206 |         Based on the train, dev and test files, we only keep the sub-sequence (lexicon) words that may actually appear.
207 | """
208 | in_lines = open(input_file,'r', encoding="utf-8").readlines()
209 | word_list = []
210 | for line in in_lines:
211 | if len(line) > 3:
212 | word = line.split()[0]
213 | if self.number_normalized:
214 | word = normalize_word(word)
215 | word_list.append(word)
216 | else:
217 | w_length = len(word_list)
218 |                 ## Traverse from [0: n], [1: n], ... to [n-1: n]
219 | for idx in range(w_length):
220 | matched_entity = self.gaz.enumerateMatchList(word_list[idx:])
221 | for entity in matched_entity:
222 | # print entity, self.gaz.searchId(entity),self.gaz.searchType(entity)
223 | self.gaz_alphabet.add(entity)
224 | word_list = []
225 | print("gaz alphabet size:", self.gaz_alphabet.size())
226 |
227 |
228 | def fix_alphabet(self):
229 | self.word_alphabet.close()
230 | self.biword_alphabet.close()
231 | self.char_alphabet.close()
232 | self.label_alphabet.close()
233 | self.gaz_alphabet.close()
234 |
235 |
236 | def build_word_pretrain_emb(self, emb_path):
237 | print("build word pretrain emb...")
238 | self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(emb_path, self.word_alphabet, self.word_emb_dim, self.norm_word_emb)
239 |
240 | def build_biword_pretrain_emb(self, emb_path):
241 | print("build biword pretrain emb...")
242 | self.pretrain_biword_embedding, self.biword_emb_dim = build_pretrain_embedding(emb_path, self.biword_alphabet, self.biword_emb_dim, self.norm_biword_emb)
243 |
244 | def build_gaz_pretrain_emb(self, emb_path):
245 | print("build gaz pretrain emb...")
246 | self.pretrain_gaz_embedding, self.gaz_emb_dim = build_pretrain_embedding(emb_path, self.gaz_alphabet, self.gaz_emb_dim, self.norm_gaz_emb)
247 |
248 |
249 | def generate_word_instance(self, input_file, name):
250 | """
251 | every instance include: words, labels, word_ids, label_ids
252 | """
253 | self.fix_alphabet()
254 | if name == "train":
255 | self.train_texts, self.train_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
256 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
257 | elif name == "dev":
258 | self.dev_texts, self.dev_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
259 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
260 | elif name == "test":
261 | self.test_texts, self.test_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
262 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
263 | elif name == "raw":
264 | self.raw_texts, self.raw_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
265 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
266 | else:
267 | print("Error: you can only generate train/dev/test instance! Illegal input:%s" % (name))
268 |
269 |
270 | def generate_instance(self, input_file, name):
271 | """
272 | every instance include: words, biwords, chars, labels,
273 | word_ids, biword_ids, char_ids, label_ids
274 | """
275 | self.fix_alphabet()
276 | if name == "train":
277 | self.train_texts, self.train_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
278 | elif name == "dev":
279 | self.dev_texts, self.dev_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
280 | elif name == "test":
281 | self.test_texts, self.test_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
282 | elif name == "raw":
283 | self.raw_texts, self.raw_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
284 | else:
285 | print("Error: you can only generate train/dev/test instance! Illegal input:%s"%(name))
286 |
287 |
288 | def generate_instance_with_gaz(self, input_file, name):
289 | """
290 | every instance include: words, biwords, chars, labels, gazs,
291 | word_ids, biword_ids, char_ids, label_ids, gaz_ids
292 | """
293 | self.fix_alphabet()
294 | if name == "train":
295 | self.train_texts, self.train_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
296 | elif name == "dev":
297 | self.dev_texts, self.dev_Ids = read_instance_with_gaz(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
298 | elif name == "test":
299 | self.test_texts, self.test_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
300 | elif name == "raw":
301 | self.raw_texts, self.raw_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
302 | else:
303 | print("Error: you can only generate train/dev/test instance! Illegal input:%s"%(name))
304 |
305 | def generate_instance_with_gaz_no_char(self, input_file, name):
306 | """
307 | every instance include:
308 | words, biwords, gazs, labels
309 | word_Ids, biword_Ids, gazs_Ids, label_Ids
310 | """
311 | self.fix_alphabet()
312 | if name == "train":
313 | self.train_texts, self.train_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
314 | elif name == "dev":
315 | self.dev_texts, self.dev_Ids = read_instance_with_gaz_no_char(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
316 | elif name == "test":
317 | self.test_texts, self.test_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
318 | elif name == "raw":
319 | self.raw_texts, self.raw_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
320 | else:
321 | print("Error: you can only generate train/dev/test instance! Illegal input:%s"%(name))
322 |
323 |
324 | def write_decoded_results(self, output_file, predict_results, name):
325 | fout = open(output_file,'w', encoding="utf-8")
326 | sent_num = len(predict_results)
327 | content_list = []
328 | if name == 'raw':
329 | content_list = self.raw_texts
330 | elif name == 'test':
331 | content_list = self.test_texts
332 | elif name == 'dev':
333 | content_list = self.dev_texts
334 | elif name == 'train':
335 | content_list = self.train_texts
336 | else:
337 | print("Error: illegal name during writing predict result, name should be within train/dev/test/raw !")
338 | assert(sent_num == len(content_list))
339 | for idx in range(sent_num):
340 | sent_length = len(predict_results[idx])
341 | for idy in range(sent_length):
342 | ## content_list[idx] is a list with [word, char, label]
343 |                 fout.write(content_list[idx][0][idy] + " " + predict_results[idx][idy] + '\n')
344 |
345 | fout.write('\n')
346 | fout.close()
347 | print("Predict %s result has been written into file. %s"%(name, output_file))
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
--------------------------------------------------------------------------------
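A sketch of a data-preparation pipeline with the Data class above (not part of the repository). The call order shown here is one plausible way to wire things together, not copied from the repository's driver scripts, and the file paths are hypothetical placeholders.

```python
from ner.utils.msra_data import Data

data = Data()
data.build_gaz_file("data/ctb.50d.vec")                  # lexicon vocabulary (placeholder path)
for split in ["data/msra/train.char.bmes",
              "data/msra/dev.char.bmes",
              "data/msra/test.char.bmes"]:               # placeholder paths
    data.build_alphabet(split)
    data.build_gaz_alphabet(split)
data.fix_alphabet()
data.generate_instance_with_gaz_no_char("data/msra/train.char.bmes", "train")
data.show_data_summary()
```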
/ner/utils/note4_data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import sys
4 | import numpy as np
5 | import pickle
6 | from ner.utils.alphabet import Alphabet
7 | from ner.utils.functions import *
8 | from ner.utils.gazetteer import Gazetteer
9 |
10 |
11 |
12 | START = "</s>"
13 | UNKNOWN = "</unk>"
14 | PADDING = "</pad>"
15 | NULLKEY = "-null-"
16 |
17 | class Data:
18 | def __init__(self):
19 | self.MAX_SENTENCE_LENGTH = 250
20 | self.MAX_WORD_LENGTH = -1
21 | self.number_normalized = True
22 | self.norm_word_emb = True
23 | self.norm_biword_emb = True
24 | self.norm_gaz_emb = False
25 | self.use_single = True
26 | self.word_alphabet = Alphabet('word')
27 | self.biword_alphabet = Alphabet('biword')
28 | self.char_alphabet = Alphabet('character')
29 | # self.word_alphabet.add(START)
30 | # self.word_alphabet.add(UNKNOWN)
31 | # self.char_alphabet.add(START)
32 | # self.char_alphabet.add(UNKNOWN)
33 | # self.char_alphabet.add(PADDING)
34 | self.label_alphabet = Alphabet('label', True)
35 | self.gaz_lower = False
36 | self.gaz = Gazetteer(self.gaz_lower, self.use_single)
37 | self.gaz_alphabet = Alphabet('gaz')
38 | self.HP_fix_gaz_emb = False
39 | self.HP_use_gaz = True
40 |
41 | self.tagScheme = "NoSeg"
42 | self.char_features = "LSTM"
43 |
44 | self.train_texts = []
45 | self.dev_texts = []
46 | self.test_texts = []
47 | self.raw_texts = []
48 |
49 | self.train_Ids = []
50 | self.dev_Ids = []
51 | self.test_Ids = []
52 | self.raw_Ids = []
53 | self.use_bigram = True
54 | self.word_emb_dim = 100
55 | self.biword_emb_dim = 50
56 | self.char_emb_dim = 30
57 | self.gaz_emb_dim = 50
58 | self.gaz_dropout = 0.5
59 | self.pretrain_word_embedding = None
60 | self.pretrain_biword_embedding = None
61 | self.pretrain_gaz_embedding = None
62 | self.label_size = 0
63 | self.word_alphabet_size = 0
64 | self.biword_alphabet_size = 0
65 | self.char_alphabet_size = 0
66 | self.label_alphabet_size = 0
67 | ### hyperparameters
68 | self.HP_iteration = 100
69 | self.HP_batch_size = 10
70 | self.HP_char_hidden_dim = 100
71 | self.HP_hidden_dim = 200
72 | self.HP_dropout = 0.5
73 | self.HP_lstm_layer = 1
74 | self.HP_bilstm = True
75 | self.HP_use_char = False
76 | self.HP_gpu = False
77 | self.HP_lr = 0.015
78 | self.HP_lr_decay = 0.05
79 | self.HP_clip = 5.0
80 | self.HP_momentum = 0
81 |
82 |
83 | def show_data_summary(self):
84 | print("DATA SUMMARY START:")
85 | print(" Tag scheme: %s"%(self.tagScheme))
86 | print(" MAX SENTENCE LENGTH: %s"%(self.MAX_SENTENCE_LENGTH))
87 | print(" MAX WORD LENGTH: %s"%(self.MAX_WORD_LENGTH))
88 | print(" Number normalized: %s"%(self.number_normalized))
89 | print(" Use bigram: %s"%(self.use_bigram))
90 | print(" Word alphabet size: %s"%(self.word_alphabet_size))
91 | print(" Biword alphabet size: %s"%(self.biword_alphabet_size))
92 | print(" Char alphabet size: %s"%(self.char_alphabet_size))
93 | print(" Gaz alphabet size: %s"%(self.gaz_alphabet.size()))
94 | print(" Label alphabet size: %s"%(self.label_alphabet_size))
95 | print(" Word embedding size: %s"%(self.word_emb_dim))
96 | print(" Biword embedding size: %s"%(self.biword_emb_dim))
97 | print(" Char embedding size: %s"%(self.char_emb_dim))
98 | print(" Gaz embedding size: %s"%(self.gaz_emb_dim))
99 | print(" Norm word emb: %s"%(self.norm_word_emb))
100 | print(" Norm biword emb: %s"%(self.norm_biword_emb))
101 | print(" Norm gaz emb: %s"%(self.norm_gaz_emb))
102 | print(" Norm gaz dropout: %s"%(self.gaz_dropout))
103 | print(" Train instance number: %s"%(len(self.train_texts)))
104 | print(" Dev instance number: %s"%(len(self.dev_texts)))
105 | print(" Test instance number: %s"%(len(self.test_texts)))
106 | print(" Raw instance number: %s"%(len(self.raw_texts)))
107 | print(" Hyperpara iteration: %s"%(self.HP_iteration))
108 | print(" Hyperpara batch size: %s"%(self.HP_batch_size))
109 | print(" Hyperpara lr: %s"%(self.HP_lr))
110 | print(" Hyperpara lr_decay: %s"%(self.HP_lr_decay))
111 | print(" Hyperpara HP_clip: %s"%(self.HP_clip))
112 | print(" Hyperpara momentum: %s"%(self.HP_momentum))
113 | print(" Hyperpara hidden_dim: %s"%(self.HP_hidden_dim))
114 | print(" Hyperpara dropout: %s"%(self.HP_dropout))
115 | print(" Hyperpara lstm_layer: %s"%(self.HP_lstm_layer))
116 | print(" Hyperpara bilstm: %s"%(self.HP_bilstm))
117 | print(" Hyperpara GPU: %s"%(self.HP_gpu))
118 | print(" Hyperpara use_gaz: %s"%(self.HP_use_gaz))
119 | print(" Hyperpara fix gaz emb: %s"%(self.HP_fix_gaz_emb))
120 | print(" Hyperpara use_char: %s"%(self.HP_use_char))
121 | if self.HP_use_char:
122 | print(" Char_features: %s"%(self.char_features))
123 | print("DATA SUMMARY END.")
124 | sys.stdout.flush()
125 |
126 | def refresh_label_alphabet(self, input_file):
127 | old_size = self.label_alphabet_size
128 | self.label_alphabet.clear(True)
129 | in_lines = open(input_file,'r', encoding="utf-8").readlines()
130 | for line in in_lines:
131 | if len(line) > 2:
132 | pairs = line.strip().split()
133 | label = pairs[-1]
134 | self.label_alphabet.add(label)
135 | self.label_alphabet_size = self.label_alphabet.size()
136 | startS = False
137 | startB = False
138 | for label,_ in self.label_alphabet.iteritems():
139 | if "S-" in label.upper():
140 | startS = True
141 | elif "B-" in label.upper():
142 | startB = True
143 | if startB:
144 | if startS:
145 | self.tagScheme = "BMES"
146 | else:
147 | self.tagScheme = "BIO"
148 | self.fix_alphabet()
149 | print("Refresh label alphabet finished: old:%s -> new:%s"%(old_size, self.label_alphabet_size))
150 |
151 |
152 |
153 | def build_alphabet(self, input_file):
154 | in_lines = open(input_file, 'r', encoding="utf-8").readlines()
155 | for idx in range(len(in_lines)):
156 | line = in_lines[idx]
157 | if len(line) > 2:
158 | pairs = line.strip().split()
159 | word = pairs[0]
160 | if self.number_normalized:
161 | word = normalize_word(word)
162 | label = pairs[-1]
163 | self.label_alphabet.add(label)
164 | self.word_alphabet.add(word)
165 | if idx < len(in_lines) - 1 and len(in_lines[idx+1]) > 2:
166 | biword = word + in_lines[idx+1].strip().split()[0]
167 | else:
168 | biword = word + NULLKEY
169 | self.biword_alphabet.add(biword)
170 | for char in word:
171 | self.char_alphabet.add(char)
172 | self.word_alphabet_size = self.word_alphabet.size()
173 | self.biword_alphabet_size = self.biword_alphabet.size()
174 | self.char_alphabet_size = self.char_alphabet.size()
175 | self.label_alphabet_size = self.label_alphabet.size()
176 | startS = False
177 | startB = False
178 | for label,_ in self.label_alphabet.iteritems():
179 | if "S-" in label.upper():
180 | startS = True
181 | elif "B-" in label.upper():
182 | startB = True
183 | if startB:
184 | if startS:
185 | self.tagScheme = "BMES"
186 | else:
187 | self.tagScheme = "BIO"
188 |
189 |
190 | def build_gaz_file(self, gaz_file):
191 |         ## build the gaz trie from the gaz (word) embedding file;
192 |         ## at this step we only collect the words, the embeddings are not read yet
193 | if gaz_file:
194 | fins = open(gaz_file, 'r', encoding="utf-8").readlines()
195 | for fin in fins:
196 | fin = fin.strip().split()[0]
197 | if fin:
198 | self.gaz.insert(fin, "one_source")
199 | print("Load gaz file: ", gaz_file, " total size:", self.gaz.size())
200 | else:
201 | print("Gaz file is None, load nothing")
202 |
203 |
204 | def build_gaz_alphabet(self, input_file):
205 | """
206 | based on the train, dev and test files, we only save the sub-sequence words that may appear
207 | """
208 | in_lines = open(input_file,'r', encoding="utf-8").readlines()
209 | word_list = []
210 | for line in in_lines:
211 | if len(line) > 3:
212 | word = line.split()[0]
213 | if self.number_normalized:
214 | word = normalize_word(word)
215 | word_list.append(word)
216 | else:
217 | w_length = len(word_list)
218 | ## Traverse from [0: n], [1: n], ... to [n-1: n]
219 | for idx in range(w_length):
220 | matched_entity = self.gaz.enumerateMatchList(word_list[idx:])
221 | for entity in matched_entity:
222 | # print entity, self.gaz.searchId(entity),self.gaz.searchType(entity)
223 | self.gaz_alphabet.add(entity)
224 | word_list = []
225 | print("gaz alphabet size:", self.gaz_alphabet.size())
226 |
227 |
228 | def fix_alphabet(self):
229 | self.word_alphabet.close()
230 | self.biword_alphabet.close()
231 | self.char_alphabet.close()
232 | self.label_alphabet.close()
233 | self.gaz_alphabet.close()
234 |
235 |
236 | def build_word_pretrain_emb(self, emb_path):
237 | print("build word pretrain emb...")
238 | self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(emb_path, self.word_alphabet, self.word_emb_dim, self.norm_word_emb)
239 |
240 | def build_biword_pretrain_emb(self, emb_path):
241 | print("build biword pretrain emb...")
242 | self.pretrain_biword_embedding, self.biword_emb_dim = build_pretrain_embedding(emb_path, self.biword_alphabet, self.biword_emb_dim, self.norm_biword_emb)
243 |
244 | def build_gaz_pretrain_emb(self, emb_path):
245 | print("build gaz pretrain emb...")
246 | self.pretrain_gaz_embedding, self.gaz_emb_dim = build_pretrain_embedding(emb_path, self.gaz_alphabet, self.gaz_emb_dim, self.norm_gaz_emb)
247 |
248 |
249 | def generate_word_instance(self, input_file, name):
250 | """
251 | every instance includes: words, labels, word_ids, label_ids
252 | """
253 | self.fix_alphabet()
254 | if name == "train":
255 | self.train_texts, self.train_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
256 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
257 | elif name == "dev":
258 | self.dev_texts, self.dev_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
259 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
260 | elif name == "test":
261 | self.test_texts, self.test_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
262 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
263 | elif name == "raw":
264 | self.raw_texts, self.raw_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
265 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
266 | else:
267 | print("Error: you can only generate train/dev/test/raw instance! Illegal input:%s" % (name))
268 |
269 |
270 | def generate_instance(self, input_file, name):
271 | """
272 | every instance includes: words, biwords, chars, labels,
273 | word_ids, biword_ids, char_ids, label_ids
274 | """
275 | self.fix_alphabet()
276 | if name == "train":
277 | self.train_texts, self.train_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
278 | elif name == "dev":
279 | self.dev_texts, self.dev_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
280 | elif name == "test":
281 | self.test_texts, self.test_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
282 | elif name == "raw":
283 | self.raw_texts, self.raw_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
284 | else:
285 | print("Error: you can only generate train/dev/test/raw instance! Illegal input:%s"%(name))
286 |
287 |
288 | def generate_instance_with_gaz(self, input_file, name):
289 | """
290 | every instance includes: words, biwords, chars, labels, gazs,
291 | word_ids, biword_ids, char_ids, label_ids, gaz_ids
292 | """
293 | self.fix_alphabet()
294 | if name == "train":
295 | self.train_texts, self.train_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
296 | elif name == "dev":
297 | self.dev_texts, self.dev_Ids = read_instance_with_gaz(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
298 | elif name == "test":
299 | self.test_texts, self.test_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
300 | elif name == "raw":
301 | self.raw_texts, self.raw_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
302 | else:
303 | print("Error: you can only generate train/dev/test/raw instance! Illegal input:%s"%(name))
304 |
305 | def generate_instance_with_gaz_no_char(self, input_file, name):
306 | """
307 | every instance includes:
308 | words, biwords, gazs, labels
309 | word_Ids, biword_Ids, gazs_Ids, label_Ids
310 | """
311 | self.fix_alphabet()
312 | if name == "train":
313 | self.train_texts, self.train_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
314 | elif name == "dev":
315 | self.dev_texts, self.dev_Ids = read_instance_with_gaz_no_char(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
316 | elif name == "test":
317 | self.test_texts, self.test_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
318 | elif name == "raw":
319 | self.raw_texts, self.raw_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
320 | else:
321 | print("Error: you can only generate train/dev/test/raw instance! Illegal input:%s"%(name))
322 |
323 |
324 | def write_decoded_results(self, output_file, predict_results, name):
325 | fout = open(output_file,'w', encoding="utf-8")
326 | sent_num = len(predict_results)
327 | content_list = []
328 | if name == 'raw':
329 | content_list = self.raw_texts
330 | elif name == 'test':
331 | content_list = self.test_texts
332 | elif name == 'dev':
333 | content_list = self.dev_texts
334 | elif name == 'train':
335 | content_list = self.train_texts
336 | else:
337 | print("Error: illegal name during writing predict result, name should be within train/dev/test/raw !")
338 | assert(sent_num == len(content_list))
339 | for idx in range(sent_num):
340 | sent_length = len(predict_results[idx])
341 | for idy in range(sent_length):
342 | ## content_list[idx] is a list with [word, char, label]
343 | fout.write(content_list[idx][0][idy] + " " + predict_results[idx][idy] + '\n')
344 |
345 | fout.write('\n')
346 | fout.close()
347 | print("Predict %s result has been written into file. %s"%(name, output_file))
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
--------------------------------------------------------------------------------
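
The `Data` classes in `ner/utils/*_data.py` (one per dataset, essentially identical apart from a few defaults) are driven by the top-level scripts such as `note4.py`. A minimal sketch of that call sequence is shown below; it is illustrative only, the file paths are placeholders, and the exact arguments live in the scripts themselves.

```python
# Illustrative sketch only (not part of the repository); paths are placeholders.
from ner.utils.resume_data import Data

data = Data()
# 1. collect word / biword / char / label alphabets from the annotated splits
data.build_alphabet("data/resume/train.char.bmes")
data.build_alphabet("data/resume/dev.char.bmes")
data.build_alphabet("data/resume/test.char.bmes")
# 2. load the gazetteer word list and record which entries occur in the data
data.build_gaz_file("data/ctb.50d.vec")
data.build_gaz_alphabet("data/resume/train.char.bmes")
data.fix_alphabet()
# 3. convert each split into (texts, Ids) instances with the matched gaz words
data.generate_instance_with_gaz_no_char("data/resume/train.char.bmes", "train")
data.generate_instance_with_gaz_no_char("data/resume/dev.char.bmes", "dev")
# 4. read pretrained embeddings for the collected alphabets
data.build_word_pretrain_emb("data/gigaword_chn.all.a2b.uni.ite50.vec")
data.build_gaz_pretrain_emb("data/ctb.50d.vec")
data.show_data_summary()
```
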
/ner/utils/resume_data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import sys
4 | import numpy as np
5 | import pickle
6 | from ner.utils.alphabet import Alphabet
7 | from ner.utils.functions import *
8 | from ner.utils.gazetteer import Gazetteer
9 |
10 |
11 |
12 | START = ""
13 | UNKNOWN = ""
14 | PADDING = ""
15 | NULLKEY = "-null-"
16 |
17 | class Data:
18 | def __init__(self):
19 | self.MAX_SENTENCE_LENGTH = 5000
20 | self.MAX_WORD_LENGTH = -1
21 | self.number_normalized = True
22 | self.norm_word_emb = True
23 | self.norm_biword_emb = True
24 | self.norm_gaz_emb = False
25 | self.use_single = False
26 | self.word_alphabet = Alphabet('word')
27 | self.biword_alphabet = Alphabet('biword')
28 | self.char_alphabet = Alphabet('character')
29 | # self.word_alphabet.add(START)
30 | # self.word_alphabet.add(UNKNOWN)
31 | # self.char_alphabet.add(START)
32 | # self.char_alphabet.add(UNKNOWN)
33 | # self.char_alphabet.add(PADDING)
34 | self.label_alphabet = Alphabet('label', True)
35 | self.gaz_lower = False
36 | self.gaz = Gazetteer(self.gaz_lower, self.use_single)
37 | self.gaz_alphabet = Alphabet('gaz')
38 | self.HP_fix_gaz_emb = False
39 | self.HP_use_gaz = True
40 |
41 | self.tagScheme = "NoSeg"
42 | self.char_features = "LSTM"
43 |
44 | self.train_texts = []
45 | self.dev_texts = []
46 | self.test_texts = []
47 | self.raw_texts = []
48 |
49 | self.train_Ids = []
50 | self.dev_Ids = []
51 | self.test_Ids = []
52 | self.raw_Ids = []
53 | self.use_bigram = True
54 | self.word_emb_dim = 100
55 | self.biword_emb_dim = 50
56 | self.char_emb_dim = 30
57 | self.gaz_emb_dim = 50
58 | self.gaz_dropout = 0.5
59 | self.pretrain_word_embedding = None
60 | self.pretrain_biword_embedding = None
61 | self.pretrain_gaz_embedding = None
62 | self.label_size = 0
63 | self.word_alphabet_size = 0
64 | self.biword_alphabet_size = 0
65 | self.char_alphabet_size = 0
66 | self.label_alphabet_size = 0
67 | ### hyperparameters
68 | self.HP_iteration = 100
69 | self.HP_batch_size = 10
70 | self.HP_char_hidden_dim = 100
71 | self.HP_hidden_dim = 100
72 | self.HP_dropout = 0.5
73 | self.HP_lstm_layer = 1
74 | self.HP_bilstm = True
75 | self.HP_use_char = False
76 | self.HP_gpu = False
77 | self.HP_lr = 0.015
78 | self.HP_lr_decay = 0.05
79 | self.HP_clip = 5.0
80 | self.HP_momentum = 0
81 |
82 |
83 | def show_data_summary(self):
84 | print("DATA SUMMARY START:")
85 | print(" Tag scheme: %s"%(self.tagScheme))
86 | print(" MAX SENTENCE LENGTH: %s"%(self.MAX_SENTENCE_LENGTH))
87 | print(" MAX WORD LENGTH: %s"%(self.MAX_WORD_LENGTH))
88 | print(" Number normalized: %s"%(self.number_normalized))
89 | print(" Use bigram: %s"%(self.use_bigram))
90 | print(" Word alphabet size: %s"%(self.word_alphabet_size))
91 | print(" Biword alphabet size: %s"%(self.biword_alphabet_size))
92 | print(" Char alphabet size: %s"%(self.char_alphabet_size))
93 | print(" Gaz alphabet size: %s"%(self.gaz_alphabet.size()))
94 | print(" Label alphabet size: %s"%(self.label_alphabet_size))
95 | print(" Word embedding size: %s"%(self.word_emb_dim))
96 | print(" Biword embedding size: %s"%(self.biword_emb_dim))
97 | print(" Char embedding size: %s"%(self.char_emb_dim))
98 | print(" Gaz embedding size: %s"%(self.gaz_emb_dim))
99 | print(" Norm word emb: %s"%(self.norm_word_emb))
100 | print(" Norm biword emb: %s"%(self.norm_biword_emb))
101 | print(" Norm gaz emb: %s"%(self.norm_gaz_emb))
102 | print(" Norm gaz dropout: %s"%(self.gaz_dropout))
103 | print(" Train instance number: %s"%(len(self.train_texts)))
104 | print(" Dev instance number: %s"%(len(self.dev_texts)))
105 | print(" Test instance number: %s"%(len(self.test_texts)))
106 | print(" Raw instance number: %s"%(len(self.raw_texts)))
107 | print(" Hyperpara iteration: %s"%(self.HP_iteration))
108 | print(" Hyperpara batch size: %s"%(self.HP_batch_size))
109 | print(" Hyperpara lr: %s"%(self.HP_lr))
110 | print(" Hyperpara lr_decay: %s"%(self.HP_lr_decay))
111 | print(" Hyperpara HP_clip: %s"%(self.HP_clip))
112 | print(" Hyperpara momentum: %s"%(self.HP_momentum))
113 | print(" Hyperpara hidden_dim: %s"%(self.HP_hidden_dim))
114 | print(" Hyperpara dropout: %s"%(self.HP_dropout))
115 | print(" Hyperpara lstm_layer: %s"%(self.HP_lstm_layer))
116 | print(" Hyperpara bilstm: %s"%(self.HP_bilstm))
117 | print(" Hyperpara GPU: %s"%(self.HP_gpu))
118 | print(" Hyperpara use_gaz: %s"%(self.HP_use_gaz))
119 | print(" Hyperpara fix gaz emb: %s"%(self.HP_fix_gaz_emb))
120 | print(" Hyperpara use_char: %s"%(self.HP_use_char))
121 | if self.HP_use_char:
122 | print(" Char_features: %s"%(self.char_features))
123 | print("DATA SUMMARY END.")
124 | sys.stdout.flush()
125 |
126 | def refresh_label_alphabet(self, input_file):
127 | old_size = self.label_alphabet_size
128 | self.label_alphabet.clear(True)
129 | in_lines = open(input_file,'r', encoding="utf-8").readlines()
130 | for line in in_lines:
131 | if len(line) > 2:
132 | pairs = line.strip().split()
133 | label = pairs[-1]
134 | self.label_alphabet.add(label)
135 | self.label_alphabet_size = self.label_alphabet.size()
136 | startS = False
137 | startB = False
138 | for label,_ in self.label_alphabet.iteritems():
139 | if "S-" in label.upper():
140 | startS = True
141 | elif "B-" in label.upper():
142 | startB = True
143 | if startB:
144 | if startS:
145 | self.tagScheme = "BMES"
146 | else:
147 | self.tagScheme = "BIO"
148 | self.fix_alphabet()
149 | print("Refresh label alphabet finished: old:%s -> new:%s"%(old_size, self.label_alphabet_size))
150 |
151 |
152 |
153 | def build_alphabet(self, input_file):
154 | in_lines = open(input_file, 'r', encoding="utf-8").readlines()
155 | for idx in range(len(in_lines)):
156 | line = in_lines[idx]
157 | if len(line) > 2:
158 | pairs = line.strip().split()
159 | word = pairs[0]
160 | if self.number_normalized:
161 | word = normalize_word(word)
162 | label = pairs[-1]
163 | self.label_alphabet.add(label)
164 | self.word_alphabet.add(word)
165 | if idx < len(in_lines) - 1 and len(in_lines[idx+1]) > 2:
166 | biword = word + in_lines[idx+1].strip().split()[0]
167 | else:
168 | biword = word + NULLKEY
169 | self.biword_alphabet.add(biword)
170 | for char in word:
171 | self.char_alphabet.add(char)
172 | self.word_alphabet_size = self.word_alphabet.size()
173 | self.biword_alphabet_size = self.biword_alphabet.size()
174 | self.char_alphabet_size = self.char_alphabet.size()
175 | self.label_alphabet_size = self.label_alphabet.size()
176 | startS = False
177 | startB = False
178 | for label,_ in self.label_alphabet.iteritems():
179 | if "S-" in label.upper():
180 | startS = True
181 | elif "B-" in label.upper():
182 | startB = True
183 | if startB:
184 | if startS:
185 | self.tagScheme = "BMES"
186 | else:
187 | self.tagScheme = "BIO"
188 |
189 |
190 | def build_gaz_file(self, gaz_file):
191 | ## build the gaz dictionary from the gaz embedding file
192 | ## at this step we only collect the words; the embeddings themselves are read later
193 | if gaz_file:
194 | fins = open(gaz_file, 'r', encoding="utf-8").readlines()
195 | for fin in fins:
196 | fin = fin.strip().split()[0]
197 | if fin:
198 | self.gaz.insert(fin, "one_source")
199 | print("Load gaz file: ", gaz_file, " total size:", self.gaz.size())
200 | else:
201 | print("Gaz file is None, load nothing")
202 |
203 |
204 | def build_gaz_alphabet(self, input_file):
205 | """
206 | based on the train, dev and test files, we only save the sub-sequence words that may appear
207 | """
208 | in_lines = open(input_file,'r', encoding="utf-8").readlines()
209 | word_list = []
210 | for line in in_lines:
211 | if len(line) > 3:
212 | word = line.split()[0]
213 | if self.number_normalized:
214 | word = normalize_word(word)
215 | word_list.append(word)
216 | else:
217 | w_length = len(word_list)
218 | ## Traverse from [0: n], [1: n], ... to [n-1: n]
219 | for idx in range(w_length):
220 | matched_entity = self.gaz.enumerateMatchList(word_list[idx:])
221 | for entity in matched_entity:
222 | # print entity, self.gaz.searchId(entity),self.gaz.searchType(entity)
223 | self.gaz_alphabet.add(entity)
224 | word_list = []
225 | print("gaz alphabet size:", self.gaz_alphabet.size())
226 |
227 |
228 | def fix_alphabet(self):
229 | self.word_alphabet.close()
230 | self.biword_alphabet.close()
231 | self.char_alphabet.close()
232 | self.label_alphabet.close()
233 | self.gaz_alphabet.close()
234 |
235 |
236 | def build_word_pretrain_emb(self, emb_path):
237 | print("build word pretrain emb...")
238 | self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(emb_path, self.word_alphabet, self.word_emb_dim, self.norm_word_emb)
239 |
240 | def build_biword_pretrain_emb(self, emb_path):
241 | print("build biword pretrain emb...")
242 | self.pretrain_biword_embedding, self.biword_emb_dim = build_pretrain_embedding(emb_path, self.biword_alphabet, self.biword_emb_dim, self.norm_biword_emb)
243 |
244 | def build_gaz_pretrain_emb(self, emb_path):
245 | print("build gaz pretrain emb...")
246 | self.pretrain_gaz_embedding, self.gaz_emb_dim = build_pretrain_embedding(emb_path, self.gaz_alphabet, self.gaz_emb_dim, self.norm_gaz_emb)
247 |
248 |
249 | def generate_word_instance(self, input_file, name):
250 | """
251 | every instance includes: words, labels, word_ids, label_ids
252 | """
253 | self.fix_alphabet()
254 | if name == "train":
255 | self.train_texts, self.train_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
256 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
257 | elif name == "dev":
258 | self.dev_texts, self.dev_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
259 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
260 | elif name == "test":
261 | self.test_texts, self.test_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
262 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
263 | elif name == "raw":
264 | self.raw_texts, self.raw_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
265 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
266 | else:
267 | print("Error: you can only generate train/dev/test/raw instance! Illegal input:%s" % (name))
268 |
269 |
270 | def generate_instance(self, input_file, name):
271 | """
272 | every instance includes: words, biwords, chars, labels,
273 | word_ids, biword_ids, char_ids, label_ids
274 | """
275 | self.fix_alphabet()
276 | if name == "train":
277 | self.train_texts, self.train_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
278 | elif name == "dev":
279 | self.dev_texts, self.dev_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
280 | elif name == "test":
281 | self.test_texts, self.test_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
282 | elif name == "raw":
283 | self.raw_texts, self.raw_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
284 | else:
285 | print("Error: you can only generate train/dev/test/raw instance! Illegal input:%s"%(name))
286 |
287 |
288 | def generate_instance_with_gaz(self, input_file, name):
289 | """
290 | every instance includes: words, biwords, chars, labels, gazs,
291 | word_ids, biword_ids, char_ids, label_ids, gaz_ids
292 | """
293 | self.fix_alphabet()
294 | if name == "train":
295 | self.train_texts, self.train_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
296 | elif name == "dev":
297 | self.dev_texts, self.dev_Ids = read_instance_with_gaz(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
298 | elif name == "test":
299 | self.test_texts, self.test_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
300 | elif name == "raw":
301 | self.raw_texts, self.raw_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
302 | else:
303 | print("Error: you can only generate train/dev/test/raw instance! Illegal input:%s"%(name))
304 |
305 | def generate_instance_with_gaz_no_char(self, input_file, name):
306 | """
307 | every instance includes:
308 | words, biwords, gazs, labels
309 | word_Ids, biword_Ids, gazs_Ids, label_Ids
310 | """
311 | self.fix_alphabet()
312 | if name == "train":
313 | self.train_texts, self.train_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
314 | elif name == "dev":
315 | self.dev_texts, self.dev_Ids = read_instance_with_gaz_no_char(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
316 | elif name == "test":
317 | self.test_texts, self.test_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
318 | elif name == "raw":
319 | self.raw_texts, self.raw_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
320 | else:
321 | print("Error: you can only generate train/dev/test/raw instance! Illegal input:%s"%(name))
322 |
323 |
324 | def write_decoded_results(self, output_file, predict_results, name):
325 | fout = open(output_file,'w', encoding="utf-8")
326 | sent_num = len(predict_results)
327 | content_list = []
328 | if name == 'raw':
329 | content_list = self.raw_texts
330 | elif name == 'test':
331 | content_list = self.test_texts
332 | elif name == 'dev':
333 | content_list = self.dev_texts
334 | elif name == 'train':
335 | content_list = self.train_texts
336 | else:
337 | print("Error: illegal name during writing predict result, name should be within train/dev/test/raw !")
338 | assert(sent_num == len(content_list))
339 | for idx in range(sent_num):
340 | sent_length = len(predict_results[idx])
341 | for idy in range(sent_length):
342 | ## content_list[idx] is a list with [word, char, label]
343 | fout.write(content_list[idx][0][idy] + " " + predict_results[idx][idy] + '\n')
344 |
345 | fout.write('\n')
346 | fout.close()
347 | print("Predict %s result has been written into file. %s"%(name, output_file))
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
--------------------------------------------------------------------------------
/ner/utils/trie.py:
--------------------------------------------------------------------------------
1 | import collections
2 | class TrieNode:
3 | # Initialize your data structure here.
4 | def __init__(self):
5 | self.children = collections.defaultdict(TrieNode)
6 | self.is_word = False
7 |
8 | class Trie:
9 | """
10 | In fact, this Trie is a letter tree.
11 | root is a dummy node; its only function is to mark the beginning of a word.
12 |
13 | The first layer contains every possible first letter of a word; for example,
14 | the first letter of '中国' is '中'.
15 | The second layer contains every possible second letter,
16 | and so on.
17 | """
18 | def __init__(self, use_single):
19 | self.root = TrieNode()
20 | if use_single:
21 | self.min_len = 0
22 | else:
23 | self.min_len = 1
24 |
25 | def insert(self, word):
26 |
27 | current = self.root
28 | # Traverse all the letters in the Chinese word, up to the last letter
29 | for letter in word:
30 | current = current.children[letter]
31 | current.is_word = True
32 |
33 | def search(self, word):
34 | current = self.root
35 | for letter in word:
36 | current = current.children.get(letter)
37 |
38 | if current is None:
39 | return False
40 | return current.is_word
41 |
42 | def startsWith(self, prefix):
43 | current = self.root
44 | for letter in prefix:
45 | current = current.children.get(letter)
46 | if current is None:
47 | return False
48 | return True
49 |
50 |
51 | def enumerateMatch(self, word, space="_", backward=False):
52 | matched = []
53 | ## with min_len = 1 the single character itself is not kept; with min_len = 0 it is
54 | while len(word) > self.min_len:
55 | if self.search(word):
56 | matched.append(space.join(word[:]))
57 |
58 | # delete the last letter, one step at a time
59 | del word[-1]
60 | # matched holds every inserted word that is a prefix of the given sequence
61 | return matched
62 |
63 |
--------------------------------------------------------------------------------
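
As a quick illustration of the trie above (a hypothetical snippet, not part of the repository): `enumerateMatch` returns every inserted word that is a prefix of the character list it is given, and it consumes that list in place by deleting the last character on each step.

```python
from ner.utils.trie import Trie

trie = Trie(use_single=False)      # min_len = 1, so bare single characters are not matched
for w in ["中国", "中国人", "国人"]:
    trie.insert(w)

chars = list("中国人")                         # ['中', '国', '人']
print(trie.enumerateMatch(chars, space=""))    # ['中国人', '中国']
print(chars)                                   # ['中'] -- the input list has been consumed
```

This mirrors what the `build_gaz_alphabet` methods do through `self.gaz.enumerateMatchList(word_list[idx:])` for every start position `idx` in a sentence.
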
/ner/utils/weibo_data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import sys
4 | import numpy as np
5 | import pickle
6 | from ner.utils.alphabet import Alphabet
7 | from ner.utils.functions import *
8 | from ner.utils.gazetteer import Gazetteer
9 |
10 |
11 |
12 | START = ""
13 | UNKNOWN = ""
14 | PADDING = ""
15 | NULLKEY = "-null-"
16 |
17 | class Data:
18 | def __init__(self):
19 | self.MAX_SENTENCE_LENGTH = 250
20 | self.MAX_WORD_LENGTH = -1
21 | self.number_normalized = True
22 | self.norm_word_emb = True
23 | self.norm_biword_emb = True
24 | self.norm_gaz_emb = False
25 | self.use_single = True
26 | self.word_alphabet = Alphabet('word')
27 | self.biword_alphabet = Alphabet('biword')
28 | self.char_alphabet = Alphabet('character')
29 | # self.word_alphabet.add(START)
30 | # self.word_alphabet.add(UNKNOWN)
31 | # self.char_alphabet.add(START)
32 | # self.char_alphabet.add(UNKNOWN)
33 | # self.char_alphabet.add(PADDING)
34 | self.label_alphabet = Alphabet('label', True)
35 | self.gaz_lower = False
36 | self.gaz = Gazetteer(self.gaz_lower, self.use_single)
37 | self.gaz_alphabet = Alphabet('gaz')
38 | self.HP_fix_gaz_emb = False
39 | self.HP_use_gaz = True
40 |
41 | self.tagScheme = "NoSeg"
42 | self.char_features = "LSTM"
43 |
44 | self.train_texts = []
45 | self.dev_texts = []
46 | self.test_texts = []
47 | self.raw_texts = []
48 |
49 | self.train_Ids = []
50 | self.dev_Ids = []
51 | self.test_Ids = []
52 | self.raw_Ids = []
53 | self.use_bigram = True
54 | self.word_emb_dim = 100
55 | self.biword_emb_dim = 50
56 | self.char_emb_dim = 30
57 | self.gaz_emb_dim = 50
58 | self.gaz_dropout = 0.5
59 | self.pretrain_word_embedding = None
60 | self.pretrain_biword_embedding = None
61 | self.pretrain_gaz_embedding = None
62 | self.label_size = 0
63 | self.word_alphabet_size = 0
64 | self.biword_alphabet_size = 0
65 | self.char_alphabet_size = 0
66 | self.label_alphabet_size = 0
67 | ### hyperparameters
68 | self.HP_iteration = 100
69 | self.HP_batch_size = 10
70 | self.HP_char_hidden_dim = 100
71 | self.HP_hidden_dim = 100
72 | self.HP_dropout = 0.5
73 | self.HP_lstm_layer = 1
74 | self.HP_bilstm = True
75 | self.HP_use_char = False
76 | self.HP_gpu = False
77 | self.HP_lr = 0.015
78 | self.HP_lr_decay = 0.05
79 | self.HP_clip = 5.0
80 | self.HP_momentum = 0
81 |
82 |
83 | def show_data_summary(self):
84 | print("DATA SUMMARY START:")
85 | print(" Tag scheme: %s"%(self.tagScheme))
86 | print(" MAX SENTENCE LENGTH: %s"%(self.MAX_SENTENCE_LENGTH))
87 | print(" MAX WORD LENGTH: %s"%(self.MAX_WORD_LENGTH))
88 | print(" Number normalized: %s"%(self.number_normalized))
89 | print(" Use bigram: %s"%(self.use_bigram))
90 | print(" Word alphabet size: %s"%(self.word_alphabet_size))
91 | print(" Biword alphabet size: %s"%(self.biword_alphabet_size))
92 | print(" Char alphabet size: %s"%(self.char_alphabet_size))
93 | print(" Gaz alphabet size: %s"%(self.gaz_alphabet.size()))
94 | print(" Label alphabet size: %s"%(self.label_alphabet_size))
95 | print(" Word embedding size: %s"%(self.word_emb_dim))
96 | print(" Biword embedding size: %s"%(self.biword_emb_dim))
97 | print(" Char embedding size: %s"%(self.char_emb_dim))
98 | print(" Gaz embedding size: %s"%(self.gaz_emb_dim))
99 | print(" Norm word emb: %s"%(self.norm_word_emb))
100 | print(" Norm biword emb: %s"%(self.norm_biword_emb))
101 | print(" Norm gaz emb: %s"%(self.norm_gaz_emb))
102 | print(" Norm gaz dropout: %s"%(self.gaz_dropout))
103 | print(" Train instance number: %s"%(len(self.train_texts)))
104 | print(" Dev instance number: %s"%(len(self.dev_texts)))
105 | print(" Test instance number: %s"%(len(self.test_texts)))
106 | print(" Raw instance number: %s"%(len(self.raw_texts)))
107 | print(" Hyperpara iteration: %s"%(self.HP_iteration))
108 | print(" Hyperpara batch size: %s"%(self.HP_batch_size))
109 | print(" Hyperpara lr: %s"%(self.HP_lr))
110 | print(" Hyperpara lr_decay: %s"%(self.HP_lr_decay))
111 | print(" Hyperpara HP_clip: %s"%(self.HP_clip))
112 | print(" Hyperpara momentum: %s"%(self.HP_momentum))
113 | print(" Hyperpara hidden_dim: %s"%(self.HP_hidden_dim))
114 | print(" Hyperpara dropout: %s"%(self.HP_dropout))
115 | print(" Hyperpara lstm_layer: %s"%(self.HP_lstm_layer))
116 | print(" Hyperpara bilstm: %s"%(self.HP_bilstm))
117 | print(" Hyperpara GPU: %s"%(self.HP_gpu))
118 | print(" Hyperpara use_gaz: %s"%(self.HP_use_gaz))
119 | print(" Hyperpara fix gaz emb: %s"%(self.HP_fix_gaz_emb))
120 | print(" Hyperpara use_char: %s"%(self.HP_use_char))
121 | if self.HP_use_char:
122 | print(" Char_features: %s"%(self.char_features))
123 | print("DATA SUMMARY END.")
124 | sys.stdout.flush()
125 |
126 | def refresh_label_alphabet(self, input_file):
127 | old_size = self.label_alphabet_size
128 | self.label_alphabet.clear(True)
129 | in_lines = open(input_file,'r', encoding="utf-8").readlines()
130 | for line in in_lines:
131 | if len(line) > 2:
132 | pairs = line.strip().split()
133 | label = pairs[-1]
134 | self.label_alphabet.add(label)
135 | self.label_alphabet_size = self.label_alphabet.size()
136 | startS = False
137 | startB = False
138 | for label,_ in self.label_alphabet.iteritems():
139 | if "S-" in label.upper():
140 | startS = True
141 | elif "B-" in label.upper():
142 | startB = True
143 | if startB:
144 | if startS:
145 | self.tagScheme = "BMES"
146 | else:
147 | self.tagScheme = "BIO"
148 | self.fix_alphabet()
149 | print("Refresh label alphabet finished: old:%s -> new:%s"%(old_size, self.label_alphabet_size))
150 |
151 |
152 |
153 | def build_alphabet(self, input_file):
154 | in_lines = open(input_file, 'r', encoding="utf-8").readlines()
155 | for idx in range(len(in_lines)):
156 | line = in_lines[idx]
157 | if len(line) > 2:
158 | pairs = line.strip().split()
159 | word = pairs[0]
160 | if self.number_normalized:
161 | word = normalize_word(word)
162 | label = pairs[-1]
163 | self.label_alphabet.add(label)
164 | self.word_alphabet.add(word)
165 | if idx < len(in_lines) - 1 and len(in_lines[idx+1]) > 2:
166 | biword = word + in_lines[idx+1].strip().split()[0]
167 | else:
168 | biword = word + NULLKEY
169 | self.biword_alphabet.add(biword)
170 | for char in word:
171 | self.char_alphabet.add(char)
172 | self.word_alphabet_size = self.word_alphabet.size()
173 | self.biword_alphabet_size = self.biword_alphabet.size()
174 | self.char_alphabet_size = self.char_alphabet.size()
175 | self.label_alphabet_size = self.label_alphabet.size()
176 | startS = False
177 | startB = False
178 | for label,_ in self.label_alphabet.iteritems():
179 | if "S-" in label.upper():
180 | startS = True
181 | elif "B-" in label.upper():
182 | startB = True
183 | if startB:
184 | if startS:
185 | self.tagScheme = "BMES"
186 | else:
187 | self.tagScheme = "BIO"
188 |
189 |
190 | def build_gaz_file(self, gaz_file):
191 | ## build the gaz dictionary from the gaz embedding file
192 | ## at this step we only collect the words; the embeddings themselves are read later
193 | if gaz_file:
194 | fins = open(gaz_file, 'r', encoding="utf-8").readlines()
195 | for fin in fins:
196 | fin = fin.strip().split()[0]
197 | if fin:
198 | self.gaz.insert(fin, "one_source")
199 | print("Load gaz file: ", gaz_file, " total size:", self.gaz.size())
200 | else:
201 | print("Gaz file is None, load nothing")
202 |
203 |
204 | def build_gaz_alphabet(self, input_file):
205 | """
206 | based on the train, dev and test files, we only save the sub-sequence words that may appear
207 | """
208 | in_lines = open(input_file,'r', encoding="utf-8").readlines()
209 | word_list = []
210 | for line in in_lines:
211 | if len(line) > 3:
212 | word = line.split()[0]
213 | if self.number_normalized:
214 | word = normalize_word(word)
215 | word_list.append(word)
216 | else:
217 | w_length = len(word_list)
218 | ## Traverse from [0: n], [1: n], ... to [n-1: n]
219 | for idx in range(w_length):
220 | matched_entity = self.gaz.enumerateMatchList(word_list[idx:])
221 | for entity in matched_entity:
222 | # print entity, self.gaz.searchId(entity),self.gaz.searchType(entity)
223 | self.gaz_alphabet.add(entity)
224 | word_list = []
225 | print("gaz alphabet size:", self.gaz_alphabet.size())
226 |
227 |
228 | def fix_alphabet(self):
229 | self.word_alphabet.close()
230 | self.biword_alphabet.close()
231 | self.char_alphabet.close()
232 | self.label_alphabet.close()
233 | self.gaz_alphabet.close()
234 |
235 |
236 | def build_word_pretrain_emb(self, emb_path):
237 | print("build word pretrain emb...")
238 | self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(emb_path, self.word_alphabet, self.word_emb_dim, self.norm_word_emb)
239 |
240 | def build_biword_pretrain_emb(self, emb_path):
241 | print("build biword pretrain emb...")
242 | self.pretrain_biword_embedding, self.biword_emb_dim = build_pretrain_embedding(emb_path, self.biword_alphabet, self.biword_emb_dim, self.norm_biword_emb)
243 |
244 | def build_gaz_pretrain_emb(self, emb_path):
245 | print("build gaz pretrain emb...")
246 | self.pretrain_gaz_embedding, self.gaz_emb_dim = build_pretrain_embedding(emb_path, self.gaz_alphabet, self.gaz_emb_dim, self.norm_gaz_emb)
247 |
248 |
249 | def generate_word_instance(self, input_file, name):
250 | """
251 | every instance includes: words, labels, word_ids, label_ids
252 | """
253 | self.fix_alphabet()
254 | if name == "train":
255 | self.train_texts, self.train_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
256 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
257 | elif name == "dev":
258 | self.dev_texts, self.dev_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
259 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
260 | elif name == "test":
261 | self.test_texts, self.test_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
262 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
263 | elif name == "raw":
264 | self.raw_texts, self.raw_Ids = read_word_instance(input_file, self.word_alphabet, self.label_alphabet,
265 | self.number_normalized, self.MAX_SENTENCE_LENGTH)
266 | else:
267 | print("Error: you can only generate train/dev/test/raw instance! Illegal input:%s" % (name))
268 |
269 |
270 | def generate_instance(self, input_file, name):
271 | """
272 | every instance includes: words, biwords, chars, labels,
273 | word_ids, biword_ids, char_ids, label_ids
274 | """
275 | self.fix_alphabet()
276 | if name == "train":
277 | self.train_texts, self.train_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
278 | elif name == "dev":
279 | self.dev_texts, self.dev_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
280 | elif name == "test":
281 | self.test_texts, self.test_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
282 | elif name == "raw":
283 | self.raw_texts, self.raw_Ids = read_seg_instance(input_file, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
284 | else:
285 | print("Error: you can only generate train/dev/test/raw instance! Illegal input:%s"%(name))
286 |
287 |
288 | def generate_instance_with_gaz(self, input_file, name):
289 | """
290 | every instance includes: words, biwords, chars, labels, gazs,
291 | word_ids, biword_ids, char_ids, label_ids, gaz_ids
292 | """
293 | self.fix_alphabet()
294 | if name == "train":
295 | self.train_texts, self.train_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
296 | elif name == "dev":
297 | self.dev_texts, self.dev_Ids = read_instance_with_gaz(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
298 | elif name == "test":
299 | self.test_texts, self.test_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
300 | elif name == "raw":
301 | self.raw_texts, self.raw_Ids = read_instance_with_gaz(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.char_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
302 | else:
303 | print("Error: you can only generate train/dev/test/raw instance! Illegal input:%s"%(name))
304 |
305 | def generate_instance_with_gaz_no_char(self, input_file, name):
306 | """
307 | every instance includes:
308 | words, biwords, gazs, labels
309 | word_Ids, biword_Ids, gazs_Ids, label_Ids
310 | """
311 | self.fix_alphabet()
312 | if name == "train":
313 | self.train_texts, self.train_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
314 | elif name == "dev":
315 | self.dev_texts, self.dev_Ids = read_instance_with_gaz_no_char(input_file, self.gaz,self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
316 | elif name == "test":
317 | self.test_texts, self.test_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet, self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
318 | elif name == "raw":
319 | self.raw_texts, self.raw_Ids = read_instance_with_gaz_no_char(input_file, self.gaz, self.word_alphabet,self.biword_alphabet, self.gaz_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH, self.use_single)
320 | else:
321 | print("Error: you can only generate train/dev/test/raw instance! Illegal input:%s"%(name))
322 |
323 |
324 | def write_decoded_results(self, output_file, predict_results, name):
325 | fout = open(output_file,'w', encoding="utf-8")
326 | sent_num = len(predict_results)
327 | content_list = []
328 | if name == 'raw':
329 | content_list = self.raw_texts
330 | elif name == 'test':
331 | content_list = self.test_texts
332 | elif name == 'dev':
333 | content_list = self.dev_texts
334 | elif name == 'train':
335 | content_list = self.train_texts
336 | else:
337 | print("Error: illegal name during writing predict result, name should be within train/dev/test/raw !")
338 | assert(sent_num == len(content_list))
339 | for idx in range(sent_num):
340 | sent_length = len(predict_results[idx])
341 | for idy in range(sent_length):
342 | ## content_list[idx] is a list with [word, char, label]
343 | fout.write(content_list[idx][0][idy] + " " + predict_results[idx][idy] + '\n')
344 |
345 | fout.write('\n')
346 | fout.close()
347 | print("Predict %s result has been written into file. %s"%(name, output_file))
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
--------------------------------------------------------------------------------
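
A small worked example (hypothetical, not from the repository) of the biword construction in `build_alphabet` above: each character is paired with the next character of the same sentence, and the last character of a sentence is paired with `NULLKEY`.

```python
# Illustrative only; mirrors the biword logic of build_alphabet.
NULLKEY = "-null-"
chars = ["中", "国", "人"]   # one sentence from the CoNLL-style input, one character per line
biwords = [chars[i] + (chars[i + 1] if i + 1 < len(chars) else NULLKEY)
           for i in range(len(chars))]
print(biwords)   # ['中国', '国人', '人-null-']
```
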
/note4.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | __author__ = "liuwei"
3 |
4 |
5 | import time
6 | import sys
7 | import argparse
8 | import random
9 | import copy
10 | import torch
11 | import gc
12 | import pickle
13 | import torch.autograd as autograd
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | import torch.optim as optim
17 | import numpy as np
18 | import logging
19 | import re
20 | import os
21 |
22 | from ner.utils.metric import get_ner_fmeasure
23 | from ner.lw.cw_ner import CW_NER as SeqModel
24 | from ner.utils.note4_data import Data
25 | from tensorboardX import SummaryWriter
26 | from ner.lw.tbx_writer import TensorboardWriter
27 |
28 |
29 | seed_num = 100
30 | random.seed(seed_num)
31 | torch.manual_seed(seed_num)
32 | np.random.seed(seed_num)
33 |
34 | # set logger
35 | logger = logging.getLogger(__name__)
36 | logger.setLevel(logging.INFO)
37 | BASIC_FORMAT = "%(asctime)s:%(levelname)s: %(message)s"
38 | DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
39 | formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT)
40 | chlr = logging.StreamHandler()  # handler that writes to the console
41 | chlr.setFormatter(formatter)
42 | logger.addHandler(chlr)
43 |
44 |
45 | # for checkpoint
46 | max_model_num = 100
47 | old_model_paths = []
48 |
49 | # tensorboard writer
50 | log_dir = "data/log"
51 |
52 | train_log = SummaryWriter(os.path.join(log_dir, "train"))
53 | validation_log = SummaryWriter(os.path.join(log_dir, "validation"))
54 | test_log = SummaryWriter(os.path.join(log_dir, "test"))
55 | tensorboard = TensorboardWriter(train_log, validation_log, test_log)
56 |
57 |
58 | def data_initialization(data, gaz_file, train_file, dev_file, test_file):
59 | data.build_alphabet(train_file)
60 | data.build_alphabet(dev_file)
61 | data.build_alphabet(test_file)
62 | data.build_gaz_file(gaz_file)
63 | data.build_gaz_alphabet(train_file)
64 | data.build_gaz_alphabet(dev_file)
65 | data.build_gaz_alphabet(test_file)
66 | data.fix_alphabet()
67 | return data
68 |
69 |
70 | def predict_check(pred_variable, gold_variable, mask_variable):
71 | """
72 | input:
73 | pred_variable (batch_size, sent_len): pred tag result, in numpy format
74 | gold_variable (batch_size, sent_len): gold result variable
75 | mask_variable (batch_size, sent_len): mask variable
76 | """
77 | pred = pred_variable.cpu().data.numpy()
78 | gold = gold_variable.cpu().data.numpy()
79 | mask = mask_variable.cpu().data.numpy()
80 | overlaped = (pred == gold)
81 | right_token = np.sum(overlaped * mask)
82 | total_token = mask.sum()
83 | # print("right: %s, total: %s"%(right_token, total_token))
84 | return right_token, total_token
85 |
86 |
87 | def recover_label(pred_variable, gold_variable, mask_variable, label_alphabet, word_recover):
88 | """
89 | input:
90 | pred_variable (batch_size, sent_len): pred tag result
91 | gold_variable (batch_size, sent_len): gold result variable
92 | mask_variable (batch_size, sent_len): mask variable
93 | """
94 |
95 | pred_variable = pred_variable[word_recover]
96 | gold_variable = gold_variable[word_recover]
97 | mask_variable = mask_variable[word_recover]
98 | batch_size = gold_variable.size(0)
99 | seq_len = gold_variable.size(1)
100 | mask = mask_variable.cpu().data.numpy()
101 | pred_tag = pred_variable.cpu().data.numpy()
102 | gold_tag = gold_variable.cpu().data.numpy()
103 | batch_size = mask.shape[0]
104 | pred_label = []
105 | gold_label = []
106 | for idx in range(batch_size):
107 | pred = [label_alphabet.get_instance(pred_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0]
108 | gold = [label_alphabet.get_instance(gold_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0]
109 | # print "p:",pred, pred_tag.tolist()
110 | # print "g:", gold, gold_tag.tolist()
111 | assert (len(pred) == len(gold))
112 | pred_label.append(pred)
113 | gold_label.append(gold)
114 | return pred_label, gold_label
115 |
116 |
117 | def lr_decay(optimizer, epoch, decay_rate, init_lr):
118 | lr = init_lr * ((1 - decay_rate) ** epoch)
119 | print(" Learning rate is set to:", lr)
120 | for param_group in optimizer.param_groups:
121 | param_group['lr'] = lr
122 | return optimizer
123 |
124 |
125 | def save_model(epoch, state, models_dir):
126 | """
127 | save the model state
128 | Args:
129 | epoch: the number of epoch
130 | state: [model_state, training_state]
131 | models_dir: the dir to save model
132 | """
133 | if models_dir is not None:
134 | model_path = os.path.join(models_dir, "model_state_epoch_{}.th".format(epoch))
135 | train_path = os.path.join(models_dir, "training_state_epoch_{}.th".format(epoch))
136 |
137 | model_state, training_state = state
138 | torch.save(model_state, model_path)
139 | torch.save(training_state, train_path)
140 |
141 | if max_model_num > 0:
142 | old_model_paths.append([model_path, train_path])
143 | if len(old_model_paths) > max_model_num:
144 | paths_to_remove = old_model_paths.pop(0)
145 |
146 | for fname in paths_to_remove:
147 | os.remove(fname)
148 |
149 | def find_last_model(models_dir):
150 | """
151 | find the latest checkpoint file
152 | Args:
153 | models_dir: the dir that stores the saved models
154 | """
155 | epoch_num = 0
156 | # return models_dir + "/model_state_epoch_{}.th".format(epoch_num), models_dir + "/training_state_epoch_{}.th".format(epoch_num)
157 |
158 | if models_dir is None:
159 | return None
160 |
161 | saved_models_path = os.listdir(models_dir)
162 | saved_models_path = [x for x in saved_models_path if 'model_state_epoch' in x]
163 |
164 | if len(saved_models_path) == 0:
165 | return None
166 |
167 | found_epochs = [
168 | re.search("model_state_epoch_([0-9]+).th", x).group(1)
169 | for x in saved_models_path
170 | ]
171 | int_epochs = [int(epoch) for epoch in found_epochs]
172 | print("len: ", len(int_epochs))
173 | last_epoch = sorted(int_epochs, reverse=True)[0]
174 | epoch_to_load = "{}".format(last_epoch)
175 |
176 | model_path = os.path.join(models_dir, "model_state_epoch_{}.th".format(epoch_to_load))
177 | training_state_path = os.path.join(models_dir, "training_state_epoch_{}.th".format(epoch_to_load))
178 |
179 | return model_path, training_state_path
180 |
181 |
182 | def restore_model(models_dir):
183 | """
184 | restore the latest checkpoint file
185 | """
186 | latest_checkpoint = find_last_model(models_dir)
187 |
188 | if latest_checkpoint is None:
189 | return None
190 | else:
191 | model_path, training_state_path = latest_checkpoint
192 |
193 | model_state = torch.load(model_path)
194 | training_state = torch.load(training_state_path)
195 |
196 | return (model_state, training_state)
197 |
198 |
199 | def evaluate(data, model, name):
200 | if name == "train":
201 | instances = data.train_Ids
202 | elif name == "dev":
203 | instances = data.dev_Ids
204 | elif name == 'test':
205 | instances = data.test_Ids
206 | elif name == 'raw':
207 | instances = data.raw_Ids
208 | else:
209 | print("Error: wrong evaluate name,", name)
210 | right_token = 0
211 | whole_token = 0
212 | pred_results = []
213 | gold_results = []
214 | ## set model to eval mode
215 | model.eval()
216 | batch_size = 8
217 | start_time = time.time()
218 | train_num = len(instances)
219 | total_batch = train_num // batch_size + 1
220 | for batch_id in range(total_batch):
221 | start = batch_id * batch_size
222 | end = (batch_id + 1) * batch_size
223 | if end > train_num:
224 | end = train_num
225 | instance = instances[start:end]
226 | if not instance:
227 | continue
228 | gaz_list, reverse_gaz_list, batch_word, batch_biword, batch_wordlen, batch_wordrecover, batch_label, mask = batchify_with_label(instance, data.HP_gpu, True)
229 | tag_seq = model(gaz_list, reverse_gaz_list, batch_word, batch_wordlen, mask)
230 | pred_label, gold_label = recover_label(tag_seq, batch_label, mask, data.label_alphabet, batch_wordrecover)
231 | pred_results += pred_label
232 | gold_results += gold_label
233 | decode_time = time.time() - start_time
234 | speed = len(instances) / decode_time
235 | acc, p, r, f = get_ner_fmeasure(gold_results, pred_results, data.tagScheme)
236 | return speed, acc, p, r, f, pred_results
237 |
238 |
239 | def batchify_with_label(input_batch_list, gpu, volatile_flag=False):
240 | """
241 | input: list of words, chars and labels of various lengths. [[words,biwords,chars,gaz, labels],[words,biwords,chars,labels],...]
242 | words: word ids for one sentence. (batch_size, sent_len)
243 | chars: char ids for one sentence, of various lengths. (batch_size, sent_len, each_word_length)
244 | output:
245 | zero padding for word and char, with their batch length
246 | word_seq_tensor: (batch_size, max_sent_len) Variable
247 | word_seq_lengths: (batch_size,1) Tensor
248 | word_seq_recover: (batch_size,) recover the original sentence order
249 | label_seq_tensor: (batch_size, max_sent_len)
250 | mask: (batch_size, max_sent_len)
251 | """
252 | batch_size = len(input_batch_list)
253 | words = [sent[0] for sent in input_batch_list]
254 | biwords = [sent[1] for sent in input_batch_list]
255 | gazs = [sent[2] for sent in input_batch_list]
256 | reverse_gazs = [sent[3] for sent in input_batch_list]
257 | labels = [sent[4] for sent in input_batch_list]
258 | word_seq_lengths = torch.LongTensor(list(map(len, words)))
259 | max_seq_len = word_seq_lengths.max()
260 | word_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long()
261 | biword_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long()
262 | label_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long()
263 | mask = autograd.Variable(torch.zeros((batch_size, max_seq_len))).byte()
264 | for idx, (seq, biseq, label, seqlen) in enumerate(zip(words, biwords, labels, word_seq_lengths)):
265 | word_seq_tensor[idx, :seqlen] = torch.LongTensor(seq)
266 | biword_seq_tensor[idx, :seqlen] = torch.LongTensor(biseq)
267 | label_seq_tensor[idx, :seqlen] = torch.LongTensor(label)
268 | mask[idx, :seqlen] = torch.Tensor([1] * int(seqlen))
269 | word_seq_lengths, word_perm_idx = word_seq_lengths.sort(0, descending=True)
270 | word_seq_tensor = word_seq_tensor[word_perm_idx]
271 | biword_seq_tensor = biword_seq_tensor[word_perm_idx]
272 | label_seq_tensor = label_seq_tensor[word_perm_idx]
273 | mask = mask[word_perm_idx]
274 |
275 | _, word_seq_recover = word_perm_idx.sort(0, descending=False)
276 |
277 | gaz_list = [gazs[i] for i in word_perm_idx]
278 | reverse_gaz_list = [reverse_gazs[i] for i in word_perm_idx]
279 |
280 | if gpu:
281 | word_seq_tensor = word_seq_tensor.cuda()
282 | biword_seq_tensor = biword_seq_tensor.cuda()
283 | word_seq_lengths = word_seq_lengths.cuda()
284 | word_seq_recover = word_seq_recover.cuda()
285 | label_seq_tensor = label_seq_tensor.cuda()
286 | mask = mask.cuda()
287 |
288 | # gaz_seq_tensor = gaz_seq_tensor.cuda()
289 | # gaz_seq_length = gaz_seq_length.cuda()
290 | # gaz_mask_tensor = gaz_mask_tensor.cuda()
291 |
292 | return gaz_list, reverse_gaz_list, word_seq_tensor, biword_seq_tensor, word_seq_lengths, word_seq_recover, label_seq_tensor, mask
293 |
294 |
295 | def train(data, save_model_dir, seg=True):
296 | print("Training model...")
297 | data.show_data_summary()
298 | model = SeqModel(data, type=5)
299 | print("finished building model.")
300 | loss_function = nn.NLLLoss()
301 | parameters = filter(lambda p: p.requires_grad, model.parameters())
302 | optimizer = optim.SGD(parameters, lr=data.HP_lr, momentum=data.HP_momentum)
303 | best_dev = -1
304 | data.HP_iteration = 75
305 |
306 | ## here we should restore the model
307 | state = restore_model(save_model_dir)
308 | epoch = 0
309 | if state is not None:
310 | model_state = state[0]
311 | training_state = state[1]
312 |
313 | model.load_state_dict(model_state)
314 | optimizer.load_state_dict(training_state['optimizer'])
315 | epoch = int(training_state['epoch'])
316 |
317 | batch_size = 8
318 | now_batch_num = epoch * len(data.train_Ids) // batch_size
319 |
320 | ## start training
321 | for idx in range(epoch, data.HP_iteration):
322 | epoch_start = time.time()
323 | temp_start = epoch_start
324 | print("Epoch: %s/%s" % (idx, data.HP_iteration))
325 | optimizer = lr_decay(optimizer, idx, data.HP_lr_decay, data.HP_lr)
326 | instance_count = 0
327 | sample_id = 0
328 | sample_loss = 0
329 | batch_loss = 0
330 | total_loss = 0
331 | right_token = 0
332 | whole_token = 0
333 | random.shuffle(data.train_Ids)
334 | ## set model to train mode
335 | model.train()
336 | model.zero_grad()
337 | batch_size = 8
338 | batch_id = 0
339 | train_num = len(data.train_Ids)
340 | total_batch = train_num // batch_size + 1
341 | for batch_id in range(total_batch):
342 | start = batch_id * batch_size
343 | end = (batch_id + 1) * batch_size
344 | if end > train_num:
345 | end = train_num
346 | instance = data.train_Ids[start:end]
347 | if not instance:
348 | continue
349 | gaz_list, reverse_gaz_list, batch_word, batch_biword, batch_wordlen, batch_wordrecover, batch_label, mask = batchify_with_label(instance, data.HP_gpu)
350 | # print "gaz_list:",gaz_list
351 | # exit(0)
352 | instance_count += (end - start)
353 | loss, tag_seq = model.neg_log_likelihood_loss(gaz_list, reverse_gaz_list, batch_word, batch_wordlen, batch_label, mask)
354 | right, whole = predict_check(tag_seq, batch_label, mask)
355 | right_token += right
356 | whole_token += whole
357 | sample_loss += loss.item()
358 | total_loss += loss.item()
359 | batch_loss += loss
360 |
361 | if end % 4000 == 0:
362 | temp_time = time.time()
363 | temp_cost = temp_time - temp_start
364 | temp_start = temp_time
365 | print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % (end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token))
366 | sys.stdout.flush()
367 | sample_loss = 0
368 | if end % data.HP_batch_size == 0:
369 | batch_loss.backward()
370 | optimizer.step()
371 | model.zero_grad()
372 | batch_loss = 0
373 |
374 | tensorboard.add_train_scalar("train_loss", total_loss / (batch_id + 1), now_batch_num)
375 | tensorboard.add_train_scalar("train_accu", right_token / whole_token, now_batch_num)
376 | now_batch_num += 1
377 |
378 | temp_time = time.time()
379 | temp_cost = temp_time - temp_start
380 | print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % (end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token))
381 | epoch_finish = time.time()
382 | epoch_cost = epoch_finish - epoch_start
383 | print("Epoch: %s training finished. Time: %.2fs, speed: %.2fst/s, total loss: %s" % (
384 | idx, epoch_cost, train_num / epoch_cost, total_loss))
385 |         # exit(0)
386 |         # continue
387 |         speed, acc, p, r, f, _ = evaluate(data, model, "dev")
388 |         dev_finish = time.time()
389 |         dev_cost = dev_finish - epoch_finish
390 |
391 |         if seg:
392 |             current_score = f
393 |             print("Dev: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (dev_cost, speed, acc, p, r, f))
394 |
395 |             tensorboard.add_validation_scalar("dev_accu", acc, idx)
396 |             tensorboard.add_validation_scalar("dev_p", p, idx)
397 |             tensorboard.add_validation_scalar("dev_r", r, idx)
398 |             tensorboard.add_validation_scalar("dev_f", f, idx)
399 |
400 |         else:
401 |             current_score = acc
402 |             print("Dev: time: %.2fs speed: %.2fst/s; acc: %.4f" % (dev_cost, speed, acc))
403 |
404 |         if current_score > best_dev:
405 |             if seg:
406 |                 print("Exceeded previous best f score:", best_dev)
407 |             else:
408 |                 print("Exceeded previous best acc score:", best_dev)
409 |             # model_name = save_model_dir + '.' + str(idx) + ".model"
410 |             # torch.save(model.state_dict(), model_name)
411 |             best_dev = current_score
412 |         # ## decode test
413 |         speed, acc, p, r, f, _ = evaluate(data, model, "test")
414 |         test_finish = time.time()
415 |         test_cost = test_finish - dev_finish
416 |         if seg:
417 |             print("Test: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (test_cost, speed, acc, p, r, f))
418 |
419 |             tensorboard.add_test_scalar("test_accu", acc, idx)
420 |             tensorboard.add_test_scalar("test_p", p, idx)
421 |             tensorboard.add_test_scalar("test_r", r, idx)
422 |             tensorboard.add_test_scalar("test_f", f, idx)
423 |
424 |         else:
425 |             print("Test: time: %.2fs, speed: %.2fst/s; acc: %.4f" % (test_cost, speed, acc))
426 |         gc.collect()
427 |
428 |         ## save the model every epoch
429 |         model_state = model.state_dict()
430 |         training_state = {
431 |             "epoch": idx,
432 |             "optimizer": optimizer.state_dict()
433 |         }
434 |         save_model(idx, (model_state, training_state), save_model_dir)
435 |
436 |
437 | if __name__ == '__main__':
438 |     parser = argparse.ArgumentParser(description='Tuning with bi-directional LSTM-CRF')
439 |     parser.add_argument('--embedding', help='Embedding for words', default='None')
440 |     parser.add_argument('--status', choices=['train', 'test', 'decode'], help='running status', default='train')
441 |     parser.add_argument('--savemodel', default="data/model/note4")
442 |     parser.add_argument('--savedset', help='Dir of saved data setting', default="data/save.dset")
443 |     parser.add_argument('--train', default="data/note4/train.char.bmes")
444 |     parser.add_argument('--dev', default="data/note4/dev.char.bmes")
445 |     parser.add_argument('--test', default="data/note4/test.char.bmes")
446 |     parser.add_argument('--seg', default="True")
447 |     parser.add_argument('--extendalphabet', default="True")
448 |     parser.add_argument('--raw')
449 |     parser.add_argument('--loadmodel')
450 |     parser.add_argument('--output')
451 |     args = parser.parse_args()
452 |
453 |     train_file = args.train
454 |     dev_file = args.dev
455 |     test_file = args.test
456 |     raw_file = args.raw
457 |     model_dir = args.loadmodel
458 |     dset_dir = args.savedset
459 |     output_file = args.output
460 |     if args.seg.lower() == "true":
461 |         seg = True
462 |     else:
463 |         seg = False
464 |     status = args.status.lower()
465 |
466 |     save_model_dir = args.savemodel
467 |     gpu = torch.cuda.is_available()
468 |
469 |     char_emb = "data/gigaword_chn.all.a2b.uni.ite50.vec"
470 |     bichar_emb = None
471 |     gaz_file = "data/ctb.50d.vec"
472 |     # gaz_file = None
473 |     # char_emb = None
474 |     # bichar_emb = None
475 |
476 |     print("CuDNN:", torch.backends.cudnn.enabled)
477 |     # gpu = False
478 |     print("GPU available:", gpu)
479 |     print("Status:", status)
480 |     print("Seg: ", seg)
481 |     print("Train file:", train_file)
482 |     print("Dev file:", dev_file)
483 |     print("Test file:", test_file)
484 |     print("Raw file:", raw_file)
485 |     print("Char emb:", char_emb)
486 |     print("Bichar emb:", bichar_emb)
487 |     print("Gaz file:", gaz_file)
488 |     if status == 'train':
489 |         print("Model saved to:", save_model_dir)
490 |     sys.stdout.flush()
491 |
492 |     if status == 'train':
493 |         data = Data()
494 |         data.HP_gpu = gpu
495 |         data.HP_use_char = False
496 |         data.HP_batch_size = 1
497 |         data.use_bigram = False
498 |         data.gaz_dropout = 0.5
499 |         data.norm_gaz_emb = False
500 |         data.HP_fix_gaz_emb = False
501 |         data_initialization(data, gaz_file, train_file, dev_file, test_file)
502 |         data.generate_instance_with_gaz_no_char(train_file, 'train')
503 |         data.generate_instance_with_gaz_no_char(dev_file, 'dev')
504 |         data.generate_instance_with_gaz_no_char(test_file, 'test')
505 |
506 |         data.build_word_pretrain_emb(char_emb)
507 |         data.build_biword_pretrain_emb(bichar_emb)
508 |         data.build_gaz_pretrain_emb(gaz_file)
509 |
510 |         train(data, save_model_dir, seg)
511 |
512 |
513 |
514 |
--------------------------------------------------------------------------------
/resume.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | __author__ = "liuwei"
3 |
4 | import time
5 | import sys
6 | import argparse
7 | import random
8 | import copy
9 | import torch
10 | import gc
11 | import pickle
12 | import torch.autograd as autograd
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import torch.optim as optim
16 | import numpy as np
17 | import logging
18 | import re
19 | import os
20 |
21 | from ner.utils.metric import get_ner_fmeasure
22 | from ner.lw.cw_ner import CW_NER as SeqModel
23 | from ner.utils.resume_data import Data
24 | from tensorboardX import SummaryWriter
25 | from ner.lw.tbx_writer import TensorboardWriter
26 |
27 |
28 | seed_num = 100
29 | random.seed(seed_num)
30 | torch.manual_seed(seed_num)
31 | np.random.seed(seed_num)
32 |
33 | # set logger
34 | logger = logging.getLogger(__name__)
35 | logger.setLevel(logging.INFO)
36 | BASIC_FORMAT = "%(asctime)s:%(levelname)s: %(message)s"
37 | DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
38 | formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT)
39 | chlr = logging.StreamHandler()  # handler that writes log records to the console
40 | chlr.setFormatter(formatter)
41 | logger.addHandler(chlr)
42 |
43 |
44 | # for checkpoint
45 | max_model_num = 100
46 | old_model_paths = []
47 |
48 | # tensorboard writer
49 | log_dir = "data/log"
50 |
51 | train_log = SummaryWriter(os.path.join(log_dir, "train"))
52 | validation_log = SummaryWriter(os.path.join(log_dir, "validation"))
53 | test_log = SummaryWriter(os.path.join(log_dir, "test"))
54 | tensorboard = TensorboardWriter(train_log, validation_log, test_log)
55 |
56 |
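   | # NOTE: data_initialization below builds the vocabularies (alphabets) from all three data
   | # splits, loads the gazetteer lexicon, collects the gazetteer alphabet, and then calls
   | # fix_alphabet() so that no new entries can be added during later instance generation.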
57 | def data_initialization(data, gaz_file, train_file, dev_file, test_file):
58 | data.build_alphabet(train_file)
59 | data.build_alphabet(dev_file)
60 | data.build_alphabet(test_file)
61 | data.build_gaz_file(gaz_file)
62 | data.build_gaz_alphabet(train_file)
63 | data.build_gaz_alphabet(dev_file)
64 | data.build_gaz_alphabet(test_file)
65 | data.fix_alphabet()
66 | return data
67 |
68 |
69 | def predict_check(pred_variable, gold_variable, mask_variable):
70 | """
71 | input:
72 | pred_variable (batch_size, sent_len): pred tag result, in numpy format
73 | gold_variable (batch_size, sent_len): gold result variable
74 | mask_variable (batch_size, sent_len): mask variable
75 | """
76 | pred = pred_variable.cpu().data.numpy()
77 | gold = gold_variable.cpu().data.numpy()
78 | mask = mask_variable.cpu().data.numpy()
79 |     overlapped = (pred == gold)
80 |     right_token = np.sum(overlapped * mask)
81 | total_token = mask.sum()
82 | # print("right: %s, total: %s"%(right_token, total_token))
83 | return right_token, total_token
84 |
85 |
86 | def recover_label(pred_variable, gold_variable, mask_variable, label_alphabet, word_recover):
87 | """
88 | input:
89 | pred_variable (batch_size, sent_len): pred tag result
90 | gold_variable (batch_size, sent_len): gold result variable
91 | mask_variable (batch_size, sent_len): mask variable
92 | """
93 |
94 | pred_variable = pred_variable[word_recover]
95 | gold_variable = gold_variable[word_recover]
96 | mask_variable = mask_variable[word_recover]
97 | batch_size = gold_variable.size(0)
98 | seq_len = gold_variable.size(1)
99 | mask = mask_variable.cpu().data.numpy()
100 | pred_tag = pred_variable.cpu().data.numpy()
101 | gold_tag = gold_variable.cpu().data.numpy()
102 | batch_size = mask.shape[0]
103 | pred_label = []
104 | gold_label = []
105 | for idx in range(batch_size):
106 | pred = [label_alphabet.get_instance(pred_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0]
107 | gold = [label_alphabet.get_instance(gold_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0]
108 | # print "p:",pred, pred_tag.tolist()
109 | # print "g:", gold, gold_tag.tolist()
110 | assert (len(pred) == len(gold))
111 | pred_label.append(pred)
112 | gold_label.append(gold)
113 | return pred_label, gold_label
114 |
115 |
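    | # NOTE: lr_decay applies an exponential schedule: lr = init_lr * (1 - decay_rate) ** epoch.
    | # For example, with hypothetical values init_lr = 0.015 and decay_rate = 0.05, epoch 0 uses
    | # 0.015, epoch 1 uses 0.01425 and epoch 2 uses about 0.0135 (illustrative numbers only; the
    | # actual defaults come from the Data class fields HP_lr and HP_lr_decay).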
116 | def lr_decay(optimizer, epoch, decay_rate, init_lr):
117 | lr = init_lr * ((1 - decay_rate) ** epoch)
118 |     print(" Learning rate is set to:", lr)
119 | for param_group in optimizer.param_groups:
120 | param_group['lr'] = lr
121 | return optimizer
122 |
123 |
124 | def save_model(epoch, state, models_dir):
125 | """
126 | save the model state
127 | Args:
128 |         epoch: the epoch number
129 |         state: [model_state, training_state]
130 |         models_dir: the directory to save the model to
131 | """
132 | if models_dir is not None:
133 | model_path = os.path.join(models_dir, "model_state_epoch_{}.th".format(epoch))
134 | train_path = os.path.join(models_dir, "training_state_epoch_{}.th".format(epoch))
135 |
136 | model_state, training_state = state
137 | torch.save(model_state, model_path)
138 | torch.save(training_state, train_path)
139 |
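    |         # NOTE: simple checkpoint rotation -- once more than max_model_num checkpoint pairs
    |         # have been written, the oldest model/training state files are deleted from disk.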
140 | if max_model_num > 0:
141 | old_model_paths.append([model_path, train_path])
142 | if len(old_model_paths) > max_model_num:
143 | paths_to_remove = old_model_paths.pop(0)
144 |
145 | for fname in paths_to_remove:
146 | os.remove(fname)
147 |
148 | def find_last_model(models_dir):
149 | """
150 |     find the latest checkpoint file
151 |     Args:
152 |         models_dir: the directory where models are saved
153 | """
154 | epoch_num = 0
155 | # return models_dir + "/model_state_epoch_{}.th".format(epoch_num), models_dir + "/training_state_epoch_{}.th".format(epoch_num)
156 |
157 | if models_dir is None:
158 | return None
159 |
160 | saved_models_path = os.listdir(models_dir)
161 | saved_models_path = [x for x in saved_models_path if 'model_state_epoch' in x]
162 |
163 | if len(saved_models_path) == 0:
164 | return None
165 |
166 | found_epochs = [
167 |         re.search(r"model_state_epoch_([0-9]+)\.th", x).group(1)
168 | for x in saved_models_path
169 | ]
170 | int_epochs = [int(epoch) for epoch in found_epochs]
171 | print("len: ", len(int_epochs))
172 | last_epoch = sorted(int_epochs, reverse=True)[0]
173 | epoch_to_load = "{}".format(last_epoch)
174 |
175 | model_path = os.path.join(models_dir, "model_state_epoch_{}.th".format(epoch_to_load))
176 | training_state_path = os.path.join(models_dir, "training_state_epoch_{}.th".format(epoch_to_load))
177 |
178 | return model_path, training_state_path
179 |
180 |
181 | def restore_model(models_dir):
182 | """
183 |     restore the latest checkpoint file
184 | """
185 |     latest_checkpoint = find_last_model(models_dir)
186 |
187 |     if latest_checkpoint is None:
188 |         return None
189 |     else:
190 |         model_path, training_state_path = latest_checkpoint
191 |
192 | model_state = torch.load(model_path)
193 | training_state = torch.load(training_state_path)
194 |
195 | return (model_state, training_state)
196 |
197 |
198 | def evaluate(data, model, name):
199 | if name == "train":
200 | instances = data.train_Ids
201 | elif name == "dev":
202 | instances = data.dev_Ids
203 | elif name == 'test':
204 | instances = data.test_Ids
205 | elif name == 'raw':
206 | instances = data.raw_Ids
207 | else:
208 | print("Error: wrong evaluate name,", name)
209 | right_token = 0
210 | whole_token = 0
211 | pred_results = []
212 | gold_results = []
213 |     ## set the model in eval mode
214 | model.eval()
215 | batch_size = 8
216 | start_time = time.time()
217 | train_num = len(instances)
218 | total_batch = train_num // batch_size + 1
219 | for batch_id in range(total_batch):
220 | start = batch_id * batch_size
221 | end = (batch_id + 1) * batch_size
222 | if end > train_num:
223 | end = train_num
224 | instance = instances[start:end]
225 | if not instance:
226 | continue
227 | gaz_list, reverse_gaz_list, batch_word, batch_biword, batch_wordlen, batch_wordrecover, batch_label, mask = batchify_with_label(instance, data.HP_gpu, True)
228 | tag_seq = model(gaz_list, reverse_gaz_list, batch_word, batch_wordlen, mask)
229 | pred_label, gold_label = recover_label(tag_seq, batch_label, mask, data.label_alphabet, batch_wordrecover)
230 | pred_results += pred_label
231 | gold_results += gold_label
232 | decode_time = time.time() - start_time
233 | speed = len(instances) / decode_time
234 | acc, p, r, f = get_ner_fmeasure(gold_results, pred_results, data.tagScheme)
235 | return speed, acc, p, r, f, pred_results
236 |
237 |
238 | def batchify_with_label(input_batch_list, gpu, volatile_flag=False):
239 | """
240 |     input: list of instances of various lengths: [[words, biwords, gazs, reverse_gazs, labels], ...]
241 |         words: word (character) ids for one sentence
242 |         biwords: character-bigram ids for one sentence
243 |         gazs / reverse_gazs: matched gazetteer word ids in forward and reverse direction
244 |     output (sentences are sorted by length, descending):
245 |         gaz_list, reverse_gaz_list: gazetteer matches reordered to follow the sorted batch
246 |         word_seq_tensor, biword_seq_tensor: (batch_size, max_sent_len) zero-padded Variables
247 |         word_seq_lengths: (batch_size,) sentence lengths
248 |         word_seq_recover: permutation that restores the original sentence order
249 |         label_seq_tensor: (batch_size, max_sent_len); mask: (batch_size, max_sent_len)
250 | """
251 | batch_size = len(input_batch_list)
252 | words = [sent[0] for sent in input_batch_list]
253 | biwords = [sent[1] for sent in input_batch_list]
254 | gazs = [sent[2] for sent in input_batch_list]
255 | reverse_gazs = [sent[3] for sent in input_batch_list]
256 | labels = [sent[4] for sent in input_batch_list]
257 | word_seq_lengths = torch.LongTensor(list(map(len, words)))
258 | max_seq_len = word_seq_lengths.max()
259 | word_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long()
260 | biword_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long()
261 | label_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long()
262 | mask = autograd.Variable(torch.zeros((batch_size, max_seq_len))).byte()
263 | for idx, (seq, biseq, label, seqlen) in enumerate(zip(words, biwords, labels, word_seq_lengths)):
264 | word_seq_tensor[idx, :seqlen] = torch.LongTensor(seq)
265 | biword_seq_tensor[idx, :seqlen] = torch.LongTensor(biseq)
266 | label_seq_tensor[idx, :seqlen] = torch.LongTensor(label)
267 | mask[idx, :seqlen] = torch.Tensor([1] * int(seqlen))
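    |     # NOTE: sort sentences by length in descending order (the usual requirement for packing
    |     # padded sequences downstream); word_seq_recover is the inverse permutation that
    |     # recover_label later uses to put predictions back into the original sentence order.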
268 | word_seq_lengths, word_perm_idx = word_seq_lengths.sort(0, descending=True)
269 | word_seq_tensor = word_seq_tensor[word_perm_idx]
270 | biword_seq_tensor = biword_seq_tensor[word_perm_idx]
271 | label_seq_tensor = label_seq_tensor[word_perm_idx]
272 | mask = mask[word_perm_idx]
273 |
274 | _, word_seq_recover = word_perm_idx.sort(0, descending=False)
275 |
276 | gaz_list = [gazs[i] for i in word_perm_idx]
277 | reverse_gaz_list = [reverse_gazs[i] for i in word_perm_idx]
278 |
279 | if gpu:
280 | word_seq_tensor = word_seq_tensor.cuda()
281 | biword_seq_tensor = biword_seq_tensor.cuda()
282 | word_seq_lengths = word_seq_lengths.cuda()
283 | word_seq_recover = word_seq_recover.cuda()
284 | label_seq_tensor = label_seq_tensor.cuda()
285 | mask = mask.cuda()
286 |
287 | # gaz_seq_tensor = gaz_seq_tensor.cuda()
288 | # gaz_seq_length = gaz_seq_length.cuda()
289 | # gaz_mask_tensor = gaz_mask_tensor.cuda()
290 |
291 | return gaz_list, reverse_gaz_list, word_seq_tensor, biword_seq_tensor, word_seq_lengths, word_seq_recover, label_seq_tensor, mask
292 |
293 |
294 | def train(data, save_model_dir, seg=True):
295 | print("Training model...")
296 | data.show_data_summary()
297 | model = SeqModel(data, type=2)
298 |     print("Finished building model.")
299 | loss_function = nn.NLLLoss()
300 | parameters = filter(lambda p: p.requires_grad, model.parameters())
301 | optimizer = optim.SGD(parameters, lr=data.HP_lr, momentum=data.HP_momentum)
302 | best_dev = -1
303 | data.HP_iteration = 75
304 |
305 |     ## restore the latest checkpoint (if any) so that training resumes where it left off
306 | state = restore_model(save_model_dir)
307 | epoch = 0
308 | if state is not None:
309 | model_state = state[0]
310 | training_state = state[1]
311 |
312 | model.load_state_dict(model_state)
313 | optimizer.load_state_dict(training_state['optimizer'])
314 | epoch = int(training_state['epoch'])
315 |
316 |     batch_size = 1  ## the model can be trained with any batch size; here we use 1
317 | now_batch_num = epoch * len(data.train_Ids) // batch_size
318 |
319 | ## start training
320 | for idx in range(epoch, data.HP_iteration):
321 | epoch_start = time.time()
322 | temp_start = epoch_start
323 | print("Epoch: %s/%s" % (idx, data.HP_iteration))
324 | optimizer = lr_decay(optimizer, idx, data.HP_lr_decay, data.HP_lr)
325 | instance_count = 0
326 | sample_id = 0
327 | sample_loss = 0
328 | batch_loss = 0
329 | total_loss = 0
330 | right_token = 0
331 | whole_token = 0
332 | random.shuffle(data.train_Ids)
333 |         ## set the model in train mode
334 | model.train()
335 | model.zero_grad()
336 |         batch_size = 1  ## we can train with any batch size; here we use 1
337 | batch_id = 0
338 | train_num = len(data.train_Ids)
339 | total_batch = train_num // batch_size + 1
340 | for batch_id in range(total_batch):
341 | start = batch_id * batch_size
342 | end = (batch_id + 1) * batch_size
343 | if end > train_num:
344 | end = train_num
345 | instance = data.train_Ids[start:end]
346 | if not instance:
347 | continue
348 | gaz_list, reverse_gaz_list, batch_word, batch_biword, batch_wordlen, batch_wordrecover, batch_label, mask = batchify_with_label(instance, data.HP_gpu)
349 | # print "gaz_list:",gaz_list
350 | # exit(0)
351 | instance_count += (end - start)
352 | loss, tag_seq = model.neg_log_likelihood_loss(gaz_list, reverse_gaz_list, batch_word, batch_wordlen, batch_label, mask)
353 | right, whole = predict_check(tag_seq, batch_label, mask)
354 | right_token += right
355 | whole_token += whole
356 | sample_loss += loss.item()
357 | total_loss += loss.item()
358 | batch_loss += loss
359 |
360 | if end % 4000 == 0:
361 | temp_time = time.time()
362 | temp_cost = temp_time - temp_start
363 | temp_start = temp_time
364 | print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % (end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token))
365 | sys.stdout.flush()
366 | sample_loss = 0
367 | if end % data.HP_batch_size == 0:
368 | batch_loss.backward()
369 | optimizer.step()
370 | model.zero_grad()
371 | batch_loss = 0
372 |
373 | tensorboard.add_train_scalar("train_loss", total_loss / (batch_id + 1), now_batch_num)
374 | tensorboard.add_train_scalar("train_accu", right_token / whole_token, now_batch_num)
375 | now_batch_num += 1
376 |
377 | temp_time = time.time()
378 | temp_cost = temp_time - temp_start
379 | print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % (end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token))
380 | epoch_finish = time.time()
381 | epoch_cost = epoch_finish - epoch_start
382 | print("Epoch: %s training finished. Time: %.2fs, speed: %.2fst/s, total loss: %s" % (
383 | idx, epoch_cost, train_num / epoch_cost, total_loss))
384 | # exit(0)
385 | # continue
386 | speed, acc, p, r, f, _ = evaluate(data, model, "dev")
387 | dev_finish = time.time()
388 | dev_cost = dev_finish - epoch_finish
389 |
390 | if seg:
391 | current_score = f
392 | print("Dev: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (dev_cost, speed, acc, p, r, f))
393 |
394 | tensorboard.add_validation_scalar("dev_accu", acc, idx)
395 | tensorboard.add_validation_scalar("dev_p", p, idx)
396 | tensorboard.add_validation_scalar("dev_r", r, idx)
397 | tensorboard.add_validation_scalar("dev_f", f, idx)
398 |
399 | else:
400 | current_score = acc
401 | print("Dev: time: %.2fs speed: %.2fst/s; acc: %.4f" % (dev_cost, speed, acc))
402 |
403 | if current_score > best_dev:
404 | if seg:
405 |                 print("Exceeded previous best f score:", best_dev)
406 |             else:
407 |                 print("Exceeded previous best acc score:", best_dev)
408 | # model_name = save_model_dir + '.' + str(idx) + ".model"
409 | # torch.save(model.state_dict(), model_name)
410 | best_dev = current_score
411 | # ## decode test
412 | speed, acc, p, r, f, _ = evaluate(data, model, "test")
413 | test_finish = time.time()
414 | test_cost = test_finish - dev_finish
415 | if seg:
416 | print("Test: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (test_cost, speed, acc, p, r, f))
417 |
418 | tensorboard.add_test_scalar("test_accu", acc, idx)
419 | tensorboard.add_test_scalar("test_p", p, idx)
420 | tensorboard.add_test_scalar("test_r", r, idx)
421 | tensorboard.add_test_scalar("test_f", f, idx)
422 |
423 | else:
424 | print("Test: time: %.2fs, speed: %.2fst/s; acc: %.4f" % (test_cost, speed, acc))
425 | gc.collect()
426 |
427 | ## save the model every epoch
428 | model_state = model.state_dict()
429 | training_state = {
430 | "epoch": idx,
431 | "optimizer": optimizer.state_dict()
432 | }
433 | save_model(idx, (model_state, training_state), save_model_dir)
434 |
435 |
436 | if __name__ == '__main__':
437 | parser = argparse.ArgumentParser(description='Tuning with bi-directional LSTM-CRF')
438 | parser.add_argument('--embedding', help='Embedding for words', default='None')
439 |     parser.add_argument('--status', choices=['train', 'test', 'decode'], help='running status', default='train')
440 | parser.add_argument('--savemodel', default="data/model/resume")
441 | parser.add_argument('--savedset', help='Dir of saved data setting', default="data/save.dset")
442 | parser.add_argument('--train', default="data/resume/train.char.bmes")
443 | parser.add_argument('--dev', default="data/resume/dev.char.bmes")
444 | parser.add_argument('--test', default="data/resume/test.char.bmes")
445 | parser.add_argument('--seg', default="True")
446 | parser.add_argument('--extendalphabet', default="True")
447 | parser.add_argument('--raw')
448 | parser.add_argument('--loadmodel')
449 | parser.add_argument('--output')
450 | args = parser.parse_args()
451 |
452 | train_file = args.train
453 | dev_file = args.dev
454 | test_file = args.test
455 | raw_file = args.raw
456 | model_dir = args.loadmodel
457 | dset_dir = args.savedset
458 | output_file = args.output
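    |     # NOTE: --seg toggles the evaluation metric: with seg=True the dev/test scores are the
    |     # precision/recall/F values from get_ner_fmeasure and F is used for model selection;
    |     # otherwise only token accuracy is reported and used.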
459 | if args.seg.lower() == "true":
460 | seg = True
461 | else:
462 | seg = False
463 | status = args.status.lower()
464 |
465 | save_model_dir = args.savemodel
466 | gpu = torch.cuda.is_available()
467 |
468 | char_emb = "data/gigaword_chn.all.a2b.uni.ite50.vec"
469 | bichar_emb = None
470 | gaz_file = "data/ctb.50d.vec"
471 | # gaz_file = None
472 | # char_emb = None
473 | # bichar_emb = None
474 |
475 | print("CuDNN:", torch.backends.cudnn.enabled)
476 | # gpu = False
477 | print("GPU available:", gpu)
478 | print("Status:", status)
479 | print("Seg: ", seg)
480 | print("Train file:", train_file)
481 | print("Dev file:", dev_file)
482 | print("Test file:", test_file)
483 | print("Raw file:", raw_file)
484 | print("Char emb:", char_emb)
485 | print("Bichar emb:", bichar_emb)
486 | print("Gaz file:", gaz_file)
487 | if status == 'train':
488 | print("Model saved to:", save_model_dir)
489 | sys.stdout.flush()
490 |
491 | if status == 'train':
492 | data = Data()
493 | data.HP_gpu = gpu
494 | data.HP_use_char = False
495 | data.HP_batch_size = 1
496 | data.use_bigram = False
497 | data.gaz_dropout = 0.5
498 | data.norm_gaz_emb = False
499 | data.HP_fix_gaz_emb = False
500 | data_initialization(data, gaz_file, train_file, dev_file, test_file)
501 | data.generate_instance_with_gaz_no_char(train_file, 'train')
502 | data.generate_instance_with_gaz_no_char(dev_file, 'dev')
503 | data.generate_instance_with_gaz_no_char(test_file, 'test')
504 |
505 | data.build_word_pretrain_emb(char_emb)
506 | data.build_biword_pretrain_emb(bichar_emb)
507 | data.build_gaz_pretrain_emb(gaz_file)
508 |
509 | train(data, save_model_dir, seg)
510 |
511 |
512 |
513 |
--------------------------------------------------------------------------------