├── README.md
├── data
│   ├── dataset
│   │   └── README.md
│   └── embeddings
│       └── README.md
├── layer
│   ├── crf.py
│   └── gatlayer.py
├── main.py
├── model
│   └── bilstm_gat_crf.py
├── paper
│   └── Leverage Lexical Knowledge for Chinese Named Entity Recognition via Collaborative Graph Network.pdf
├── run_main.sh
└── utils
    ├── alphabet.py
    ├── batchify.py
    ├── config.py
    ├── data.py
    ├── functions.py
    ├── gazetter.py
    ├── graph_generator.py
    ├── metric.py
    └── trie.py

/README.md:
--------------------------------------------------------------------------------
1 | # Leverage Lexical Knowledge for Chinese Named Entity Recognition via Collaborative Graph Network
2 | [![GitHub stars](https://img.shields.io/github/stars/DianboWork/Graph4CNER?style=flat-square)](https://github.com/DianboWork/Graph4CNER/stargazers)
3 | [![GitHub forks](https://img.shields.io/github/forks/DianboWork/Graph4CNER?style=flat-square&color=blueviolet)](https://github.com/DianboWork/Graph4CNER/network/members)
4 | 
5 | Source code for [Leverage Lexical Knowledge for Chinese Named Entity Recognition via Collaborative Graph Network](https://www.aclweb.org/anthology/D19-1396.pdf), published at EMNLP 2019. If you use this code or our results in your research, we would appreciate it if you cite our paper as follows:
6 | 
7 | 
8 | ```
9 | @article{Sui2019Graph4CNER,
10 |     title = {Leverage Lexical Knowledge for Chinese Named Entity Recognition via Collaborative Graph Network},
11 |     author = {Sui, Dianbo and Chen, Yubo and Liu, Kang and Zhao, Jun and Liu, Shengping},
12 |     journal = {Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing},
13 |     year = {2019}
14 | }
15 | ```
16 | Requirements:
17 | ======
18 | Python: 3.7
19 | PyTorch: 1.1.0
20 | 
21 | Input format:
22 | ======
23 | Input is in CoNLL format with the BIO tag scheme: each character and its label are on one line, and sentences are separated by a blank line (a small helper sketch for producing this format appears at the end of this README).
24 | 
25 | 叶 B-PER
26 | 嘉 I-PER
27 | 莹 I-PER
28 | 先 O
29 | 生 O
30 | 获 O
31 | 聘 O
32 | 南 B-ORG
33 | 开 I-ORG
34 | 大 I-ORG
35 | 学 I-ORG
36 | 终 O
37 | 身 O
38 | 校 O
39 | 董 O
40 | 。 O
41 | 
42 | Pretrained Embeddings:
43 | ====
44 | Character embeddings (gigaword_chn.all.a2b.uni.ite50.vec) can be downloaded from [Google Drive](https://drive.google.com/file/d/1_Zlf0OAZKVdydk7loUpkzD2KPEotUE8u/view?usp=sharing) or [Baidu Pan](https://pan.baidu.com/s/1pLO6T9D).
45 | 
46 | Word embeddings (sgns.merge.word) can be downloaded from [Google Drive](https://drive.google.com/file/d/1Zh9ZCEu8_eSQ-qkYVQufQDNKPC4mtEKR/view) or
47 | [Baidu Pan](https://pan.baidu.com/s/1luy-GlTdqqvJ3j-A4FcIOw).
48 | 
49 | Usage:
50 | ====
51 | :one: Download the character embeddings and word embeddings and put them in the `data/embeddings` folder.
52 | 
53 | :two: Modify `run_main.sh` by filling in your train/dev/test file paths.
54 | 
55 | :three: Run `sh run_main.sh`. Note that the default hyperparameters may not be optimal, so you may need to tune them.
56 | 
57 | :four: Enjoy it! :smile:
58 | 
59 | Result:
60 | ====
61 | For the WeiboNER dataset, the default hyperparameters in `run_main.sh` achieve state-of-the-art results (Test F1: 66.66%). Model parameters can be downloaded from [Baidu Pan](https://pan.baidu.com/s/1ysy_eNF0oYJwjXiy4x7gtQ) (key: bg3q) :sunglasses:
62 | 
63 | Speed:
64 | ===
65 | I have optimized the code, and this version is faster than the one used in our paper. :muscle:
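For reference, the CoNLL-style input described above can be produced with a short script such as the one below. This is an illustrative sketch rather than part of the repository: the helper name `write_bio`, the `sentences` structure, and the demo output path are all made up.

```python
def write_bio(sentences, path):
    """Write (characters, labels) pairs in the character-per-line BIO format."""
    with open(path, "w", encoding="utf-8") as f:
        for chars, labels in sentences:
            for ch, lab in zip(chars, labels):
                f.write(f"{ch} {lab}\n")   # one character and its label per line
            f.write("\n")                  # a blank line separates sentences

write_bio([(list("叶嘉莹先生"), ["B-PER", "I-PER", "I-PER", "O", "O"])], "data/dataset/demo.train")
```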
66 | 
--------------------------------------------------------------------------------
/data/dataset/README.md:
--------------------------------------------------------------------------------
1 | Please download the [WeiboNER dataset](https://github.com/hltcoe/golden-horse/tree/master/data). Due to copyright restrictions, we cannot provide the MSRA and OntoNotes datasets. I am sorry for that.
2 | 
--------------------------------------------------------------------------------
/data/embeddings/README.md:
--------------------------------------------------------------------------------
1 | A folder for the embedding files.
2 | 
--------------------------------------------------------------------------------
/layer/crf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.autograd as autograd
3 | import torch.nn as nn
4 | import pdb
5 | import torch.nn.functional as F
6 | import numpy as np
7 | 
8 | START_TAG = -2
9 | STOP_TAG = -1
10 | 
11 | 
12 | # Compute log sum exp in a numerically stable way for the forward algorithm
13 | def log_sum_exp(vec, m_size):
14 |     """
15 |     Calculate log(sum(exp(vec))) along dim 1 in a numerically stable way (the per-row max is subtracted before exponentiating).
16 |     args:
17 |         vec (batch_size, vanishing_dim, hidden_dim) : input tensor
18 |         m_size : hidden_dim
19 |     return:
20 |         batch_size, hidden_dim
21 |     """
22 |     _, idx = torch.max(vec, 1)  # B * 1 * M
23 |     max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size)  # B * M
24 |     return max_score.view(-1, m_size) + torch.log(torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size)  # B * M
25 | 
26 | 
27 | 
28 | class CRF(nn.Module):
29 | 
30 |     def __init__(self, tagset_size, gpu):
31 |         super(CRF, self).__init__()
32 |         print("build batched crf...")
33 | 
34 |         self.gpu = gpu
35 |         # Matrix of transition parameters. Entry i, j is the score of transitioning *from* i *to* j.
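        # When True, neg_log_likelihood_loss averages the loss over the batch instead of summing it.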
36 | self.average_batch = False 37 | self.tagset_size = tagset_size 38 | # # We add 2 here, because of START_TAG and STOP_TAG 39 | # # transitions (f_tag_size, t_tag_size), transition value from f_tag to t_tag 40 | init_transitions = torch.zeros(self.tagset_size + 2, self.tagset_size + 2) 41 | # init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2) 42 | # init_transitions[:,START_TAG] = -1000.0 43 | # init_transitions[STOP_TAG,:] = -1000.0 44 | # init_transitions[:,0] = -1000.0 45 | # init_transitions[0,:] = -1000.0 46 | if self.gpu: 47 | init_transitions = init_transitions.cuda() 48 | self.transitions = nn.Parameter(init_transitions) 49 | 50 | # self.transitions = nn.Parameter(torch.Tensor(self.tagset_size+2, self.tagset_size+2)) 51 | # self.transitions.data.zero_() 52 | 53 | def _calculate_PZ(self, feats, mask): 54 | """ 55 | input: 56 | feats: (batch, seq_len, self.tag_size+2) 57 | masks: (batch, seq_len) 58 | """ 59 | batch_size = feats.size(0) 60 | seq_len = feats.size(1) 61 | tag_size = feats.size(2) 62 | # print feats.view(seq_len, tag_size) 63 | assert (tag_size == self.tagset_size + 2) 64 | mask = mask.transpose(1, 0).contiguous() 65 | ins_num = seq_len * batch_size 66 | ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1) 67 | feats = feats.transpose(1, 0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) 68 | ## need to consider start 69 | scores = feats + self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size) 70 | scores = scores.view(seq_len, batch_size, tag_size, tag_size) 71 | # build iter 72 | seq_iter = enumerate(scores) 73 | _, inivalues = seq_iter.__next__() # bat_size * from_target_size * to_target_size 74 | # only need start from start_tag 75 | partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size 76 | 77 | ## add start score (from start to all tag, duplicate to batch_size) 78 | # partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1) 79 | # iter over last scores 80 | for idx, cur_values in seq_iter: 81 | # previous to_target is current from_target 82 | # partition: previous results log(exp(from_target)), #(batch_size * from_target) 83 | # cur_values: bat_size * from_target * to_target 84 | 85 | cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, 86 | tag_size) 87 | cur_partition = log_sum_exp(cur_values, tag_size) 88 | # print cur_partition.data 89 | 90 | # (bat_size * from_target * to_target) -> (bat_size * to_target) 91 | # partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1) 92 | mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size) 93 | 94 | ## effective updated partition part, only keep the partition value of mask value = 1 95 | masked_cur_partition = cur_partition.masked_select(mask_idx) 96 | ## let mask_idx broadcastable, to disable warning 97 | mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1) 98 | 99 | ## replace the partition where the maskvalue=1, other partition value keeps the same 100 | partition.masked_scatter_(mask_idx, masked_cur_partition) 101 | # until the last state, add transition score for all partition (and do log_sum_exp) then select the value in STOP_TAG 102 | cur_values = self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, 103 | tag_size) + 
partition.contiguous().view( 104 | batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 105 | cur_partition = log_sum_exp(cur_values, tag_size) 106 | final_partition = cur_partition[:, STOP_TAG] 107 | return final_partition.sum(), scores 108 | 109 | def _viterbi_decode(self, feats, mask): 110 | """ 111 | input: 112 | feats: (batch, seq_len, self.tag_size+2) 113 | mask: (batch, seq_len) 114 | output: 115 | decode_idx: (batch, seq_len) decoded sequence 116 | path_score: (batch, 1) corresponding score for each sequence (to be implementated) 117 | """ 118 | batch_size = feats.size(0) 119 | seq_len = feats.size(1) 120 | tag_size = feats.size(2) 121 | assert (tag_size == self.tagset_size + 2) 122 | ## calculate sentence length for each sentence 123 | length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long() 124 | ## mask to (seq_len, batch_size) 125 | mask = mask.transpose(1, 0).contiguous() 126 | ins_num = seq_len * batch_size 127 | ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1) 128 | feats = feats.transpose(1, 0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) 129 | ## need to consider start 130 | scores = feats + self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size) 131 | scores = scores.view(seq_len, batch_size, tag_size, tag_size) 132 | 133 | # build iter 134 | seq_iter = enumerate(scores) 135 | ## record the position of best score 136 | back_points = list() 137 | partition_history = list() 138 | 139 | ## reverse mask (bug for mask = 1- mask, use this as alternative choice) 140 | # mask = 1 + (-1)*mask 141 | mask = (1 - mask.long()).byte() 142 | _, inivalues = seq_iter.__next__() # bat_size * from_target_size * to_target_size 143 | # only need start from start_tag 144 | partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size) # bat_size * to_target_size 145 | partition_history.append(partition) 146 | # iter over last scores 147 | for idx, cur_values in seq_iter: 148 | # previous to_target is current from_target 149 | # partition: previous results log(exp(from_target)), #(batch_size * from_target) 150 | # cur_values: batch_size * from_target * to_target 151 | cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, 152 | tag_size) 153 | ## forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1) # do not consider START_TAG/STOP_TAG 154 | partition, cur_bp = torch.max(cur_values, 1) 155 | partition_history.append(partition) 156 | ## cur_bp: (batch_size, tag_size) max source score position in current tag 157 | ## set padded label as 0, which will be filtered in post processing 158 | cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0) 159 | back_points.append(cur_bp) 160 | ### add score to final STOP_TAG 161 | partition_history = torch.cat(partition_history, dim=0).view(seq_len, batch_size, -1).transpose(1, 162 | 0).contiguous() ## (batch_size, seq_len. 
tag_size) 163 | ### get the last position for each setences, and select the last partitions using gather() 164 | last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1 165 | last_partition = torch.gather(partition_history, 1, last_position).view(batch_size, tag_size, 1) 166 | ### calculate the score from last partition to end state (and then select the STOP_TAG from it) 167 | last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1, tag_size, 168 | tag_size).expand( 169 | batch_size, tag_size, tag_size) 170 | _, last_bp = torch.max(last_values, 1) 171 | pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size)).long() 172 | if self.gpu: 173 | pad_zero = pad_zero.cuda() 174 | back_points.append(pad_zero) 175 | back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size) 176 | 177 | ## select end ids in STOP_TAG 178 | pointer = last_bp[:, STOP_TAG] 179 | insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size) 180 | back_points = back_points.transpose(1, 0).contiguous() 181 | ## move the end ids(expand to tag_size) to the corresponding position of back_points to replace the 0 values 182 | # print "lp:",last_position 183 | # print "il:",insert_last 184 | back_points.scatter_(1, last_position, insert_last) 185 | # print "bp:",back_points 186 | # exit(0) 187 | back_points = back_points.transpose(1, 0).contiguous() 188 | ## decode from the end, padded position ids are 0, which will be filtered if following evaluation 189 | decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size)) 190 | if self.gpu: 191 | decode_idx = decode_idx.cuda() 192 | decode_idx[-1] = pointer.data 193 | for idx in range(len(back_points) - 2, -1, -1): 194 | pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) 195 | decode_idx[idx] = pointer.squeeze() 196 | path_score = None 197 | decode_idx = decode_idx.transpose(1, 0) 198 | return path_score, decode_idx 199 | 200 | def forward(self, feats): 201 | path_score, best_path = self.viterbi_decode(feats) 202 | return path_score, best_path 203 | 204 | def _score_sentence(self, scores, mask, tags): 205 | """ 206 | input: 207 | scores: variable (seq_len, batch, tag_size, tag_size) 208 | mask: (batch, seq_len) 209 | tags: tensor (batch, seq_len) 210 | output: 211 | score: sum of score for gold sequences within whole batch 212 | """ 213 | # Gives the score of a provided tag sequence 214 | batch_size = scores.size(1) 215 | seq_len = scores.size(0) 216 | tag_size = scores.size(2) 217 | ## convert tag value into a new format, recorded label bigram information to index 218 | new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len)) 219 | if self.gpu: 220 | new_tags = new_tags.cuda() 221 | for idx in range(seq_len): 222 | if idx == 0: 223 | ## start -> first score 224 | new_tags[:, 0] = (tag_size - 2) * tag_size + tags[:, 0] 225 | 226 | else: 227 | new_tags[:, idx] = tags[:, idx - 1] * tag_size + tags[:, idx] 228 | 229 | ## transition for label to STOP_TAG 230 | end_transition = self.transitions[:, STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size) 231 | ## length for batch, last word position = length - 1 232 | length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long() 233 | ## index the label id of last word 234 | end_ids = torch.gather(tags, 1, length_mask - 1) 235 | 236 | ## index the transition score for end_id to STOP_TAG 237 | end_energy = torch.gather(end_transition, 1, end_ids) 
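        # new_tags encodes each gold (previous_tag, current_tag) pair as previous_tag * tag_size + current_tag,
        # so a single gather over scores.view(seq_len, batch_size, -1) below picks the combined
        # emission + transition score of the gold bigram at every position; end_energy above adds the
        # transition from each sentence's last gold tag to STOP_TAG.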
238 | 239 | ## convert tag as (seq_len, batch_size, 1) 240 | new_tags = new_tags.transpose(1, 0).contiguous().view(seq_len, batch_size, 1) 241 | ### need convert tags id to search from 400 positions of scores 242 | tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, 243 | batch_size) # seq_len * bat_size 244 | ## mask transpose to (seq_len, batch_size) 245 | tg_energy = tg_energy.masked_select(mask.transpose(1, 0)) 246 | 247 | # ## calculate the score from START_TAG to first label 248 | # start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size) 249 | # start_energy = torch.gather(start_transition, 1, tags[0,:]) 250 | 251 | ## add all score together 252 | # gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum() 253 | gold_score = tg_energy.sum() + end_energy.sum() 254 | return gold_score 255 | 256 | def neg_log_likelihood_loss(self, feats, mask, tags): 257 | # nonegative log likelihood 258 | batch_size = feats.size(0) 259 | forward_score, scores = self._calculate_PZ(feats, mask) 260 | gold_score = self._score_sentence(scores, mask, tags) 261 | # print "batch, f:", forward_score.data[0], " g:", gold_score.data[0], " dis:", forward_score.data[0] - gold_score.data[0] 262 | # exit(0) 263 | if self.average_batch: 264 | return (forward_score - gold_score) / batch_size 265 | else: 266 | return forward_score - gold_score -------------------------------------------------------------------------------- /layer/gatlayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GraphAttentionLayer(nn.Module): 7 | def __init__(self, in_features, out_features, dropout, alpha, concat=True): 8 | super(GraphAttentionLayer, self).__init__() 9 | self.dropout = dropout 10 | self.in_features = in_features 11 | self.out_features = out_features 12 | self.alpha = alpha 13 | self.concat = concat 14 | self.W = nn.Linear(in_features, out_features, bias=False) 15 | # self.W = nn.Parameter(torch.zeros(size=(in_features, out_features))) 16 | nn.init.xavier_uniform_(self.W.weight, gain=1.414) 17 | # self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1))) 18 | self.a1 = nn.Parameter(torch.zeros(size=(out_features, 1))) 19 | self.a2 = nn.Parameter(torch.zeros(size=(out_features, 1))) 20 | nn.init.xavier_uniform_(self.a1.data, gain=1.414) 21 | nn.init.xavier_uniform_(self.a2.data, gain=1.414) 22 | self.leakyrelu = nn.LeakyReLU(self.alpha) 23 | 24 | def forward(self, input, adj): 25 | h = self.W(input) 26 | # [batch_size, N, out_features] 27 | batch_size, N, _ = h.size() 28 | middle_result1 = torch.matmul(h, self.a1).expand(-1, -1, N) 29 | middle_result2 = torch.matmul(h, self.a2).expand(-1, -1, N).transpose(1, 2) 30 | e = self.leakyrelu(middle_result1 + middle_result2) 31 | attention = e.masked_fill(adj == 0, -1e9) 32 | attention = F.softmax(attention, dim=2) 33 | attention = F.dropout(attention, self.dropout, training=self.training) 34 | h_prime = torch.matmul(attention, h) 35 | if self.concat: 36 | return F.elu(h_prime) 37 | else: 38 | return h_prime 39 | 40 | def __repr__(self): 41 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 42 | 43 | 44 | class GAT(nn.Module): 45 | def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads, layer): 46 | super(GAT, self).__init__() 47 | self.dropout = dropout 48 | self.layer = layer 49 | if self.layer == 1: 50 
| self.attentions = [GraphAttentionLayer(nfeat, nclass, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)] 51 | else: 52 | self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)] 53 | self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False) 54 | for i, attention in enumerate(self.attentions): 55 | self.add_module('attention_{}'.format(i), attention) 56 | 57 | def forward(self, x, adj): 58 | x = F.dropout(x, self.dropout, training=self.training) 59 | if self.layer == 1: 60 | x = torch.stack([att(x, adj) for att in self.attentions], dim=2) 61 | x = x.sum(2) 62 | x = F.dropout(x, self.dropout, training=self.training) 63 | return F.log_softmax(x, dim=2) 64 | else: 65 | x = torch.cat([att(x, adj) for att in self.attentions], dim=2) 66 | x = F.dropout(x, self.dropout, training=self.training) 67 | x = F.elu(self.out_att(x, adj)) 68 | return F.log_softmax(x, dim=2) 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from utils.data import Data 2 | from utils.batchify import batchify 3 | from utils.config import get_args 4 | from utils.metric import get_ner_fmeasure 5 | from model.bilstm_gat_crf import BLSTM_GAT_CRF 6 | import os 7 | import numpy as np 8 | import copy 9 | import pickle 10 | import torch 11 | import torch.optim as optim 12 | import time 13 | import random 14 | import sys 15 | import gc 16 | 17 | 18 | def data_initialization(args): 19 | data_stored_directory = args.data_stored_directory 20 | file = data_stored_directory + args.dataset_name + "_dataset.dset" 21 | if os.path.exists(file) and not args.refresh: 22 | data = load_data_setting(data_stored_directory, args.dataset_name) 23 | else: 24 | data = Data() 25 | data.dataset_name = args.dataset_name 26 | data.norm_char_emb = args.norm_char_emb 27 | data.norm_gaz_emb = args.norm_gaz_emb 28 | data.number_normalized = args.number_normalized 29 | data.max_sentence_length = args.max_sentence_length 30 | data.build_gaz_file(args.gaz_file) 31 | data.generate_instance(args.train_file, "train", False) 32 | data.generate_instance(args.dev_file, "dev") 33 | data.generate_instance(args.test_file, "test") 34 | data.build_char_pretrain_emb(args.char_embedding_path) 35 | data.build_gaz_pretrain_emb(args.gaz_file) 36 | data.fix_alphabet() 37 | data.get_tag_scheme() 38 | save_data_setting(data, data_stored_directory) 39 | return data 40 | 41 | 42 | def save_data_setting(data, data_stored_directory): 43 | new_data = copy.deepcopy(data) 44 | data.show_data_summary() 45 | if not os.path.exists(data_stored_directory): 46 | os.makedirs(data_stored_directory) 47 | dataset_saved_name = data_stored_directory + data.dataset_name +"_dataset.dset" 48 | with open(dataset_saved_name, 'wb') as fp: 49 | pickle.dump(new_data, fp) 50 | print("Data setting saved to file: ", dataset_saved_name) 51 | 52 | 53 | def load_data_setting(data_stored_directory, name): 54 | dataset_saved_name = data_stored_directory + name + "_dataset.dset" 55 | with open(dataset_saved_name, 'rb') as fp: 56 | data = pickle.load(fp) 57 | print("Data setting loaded from file: ", dataset_saved_name) 58 | data.show_data_summary() 59 | return data 60 | 61 | 62 | def lr_decay(optimizer, epoch, decay_rate, init_lr): 63 | lr = init_lr * ((1-decay_rate)**epoch) 64 | print(" Learning rate is setted as:", lr) 65 | for param_group in 
optimizer.param_groups: 66 | param_group['lr'] = lr 67 | return optimizer 68 | 69 | 70 | def predict_check(pred_variable, gold_variable, mask_variable): 71 | """ 72 | input: 73 | pred_variable (batch_size, sent_len): pred tag result, in numpy format 74 | gold_variable (batch_size, sent_len): gold result variable 75 | mask_variable (batch_size, sent_len): mask variable 76 | """ 77 | pred = pred_variable.cpu().data.numpy() 78 | gold = gold_variable.cpu().data.numpy() 79 | mask = mask_variable.cpu().data.numpy() 80 | overlaped = (pred == gold) 81 | right_token = np.sum(overlaped * mask) 82 | total_token = mask.sum() 83 | # print("right: %s, total: %s"%(right_token, total_token)) 84 | return right_token, total_token 85 | 86 | 87 | def recover_label(pred_variable, gold_variable, mask_variable, label_alphabet, word_recover): 88 | """ 89 | input: 90 | pred_variable (batch_size, sent_len): pred tag result 91 | gold_variable (batch_size, sent_len): gold result variable 92 | mask_variable (batch_size, sent_len): mask variable 93 | """ 94 | pred_variable = pred_variable[word_recover] 95 | gold_variable = gold_variable[word_recover] 96 | mask_variable = mask_variable[word_recover] 97 | seq_len = gold_variable.size(1) 98 | mask = mask_variable.cpu().data.numpy() 99 | pred_tag = pred_variable.cpu().data.numpy() 100 | gold_tag = gold_variable.cpu().data.numpy() 101 | batch_size = mask.shape[0] 102 | pred_label = [] 103 | gold_label = [] 104 | for idx in range(batch_size): 105 | pred = [label_alphabet.get_instance(pred_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0] 106 | gold = [label_alphabet.get_instance(gold_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0] 107 | assert (len(pred) == len(gold)) 108 | pred_label.append(pred) 109 | gold_label.append(gold) 110 | return pred_label, gold_label 111 | 112 | 113 | def evaluate(data, model, args, name): 114 | if name == "train": 115 | instances = data.train_ids 116 | elif name == "dev": 117 | instances = data.dev_ids 118 | elif name == 'test': 119 | instances = data.test_ids 120 | else: 121 | print("Error: wrong evaluate name,", name) 122 | pred_results = [] 123 | gold_results = [] 124 | model.eval() 125 | batch_size = args.batch_size 126 | start_time = time.time() 127 | train_num = len(instances) 128 | total_batch = train_num//batch_size+1 129 | for batch_id in range(total_batch): 130 | start = batch_id*batch_size 131 | end = (batch_id+1)*batch_size 132 | if end > train_num: 133 | end = train_num 134 | instance = instances[start:end] 135 | if not instance: 136 | continue 137 | char, c_len, gazs, mask, label, recover, t_graph, c_graph, l_graph = batchify(instance, args.use_gpu) 138 | tag_seq = model(char, c_len, gazs, t_graph, c_graph, l_graph, mask) 139 | pred_label, gold_label = recover_label(tag_seq, label, mask, data.label_alphabet, recover) 140 | pred_results += pred_label 141 | gold_results += gold_label 142 | decode_time = time.time() - start_time 143 | speed = len(instances)/decode_time 144 | acc, p, r, f = get_ner_fmeasure(gold_results, pred_results, data.tagscheme) 145 | return speed, acc, p, r, f, pred_results 146 | 147 | 148 | def train(data, model, args): 149 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 150 | if args.optimizer == "Adam": 151 | optimizer = optim.Adam(parameters, lr=args.lr, weight_decay=args.l2_penalty) 152 | elif args.optimizer == "SGD": 153 | optimizer = optim.SGD(parameters, lr=args.lr, weight_decay=args.l2_penalty) 154 | best_dev = -1 155 | for idx in 
range(args.max_epoch): 156 | epoch_start = time.time() 157 | temp_start = epoch_start 158 | print("Epoch: %s/%s" % (idx, args.max_epoch)) 159 | optimizer = lr_decay(optimizer, idx, args.lr_decay, args.lr) 160 | instance_count = 0 161 | sample_loss = 0 162 | total_loss = 0 163 | random.shuffle(data.train_ids) 164 | model.train() 165 | model.zero_grad() 166 | batch_size = args.batch_size 167 | train_num = len(data.train_ids) 168 | total_batch = train_num // batch_size + 1 169 | for batch_id in range(total_batch): 170 | start = batch_id*batch_size 171 | end = (batch_id+1)*batch_size 172 | if end > train_num: 173 | end = train_num 174 | instance = data.train_ids[start:end] 175 | if not instance: 176 | continue 177 | model.zero_grad() 178 | char, c_len, gazs, mask, label, recover, t_graph, c_graph, l_graph = batchify(instance, args.use_gpu) 179 | loss = model.neg_log_likelihood(char, c_len, gazs, t_graph, c_graph, l_graph, mask, label) 180 | instance_count += 1 181 | sample_loss += loss.item() 182 | total_loss += loss.item() 183 | loss.backward() 184 | if args.use_clip: 185 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 186 | optimizer.step() 187 | model.zero_grad() 188 | if end % 500 == 0: 189 | temp_time = time.time() 190 | temp_cost = temp_time - temp_start 191 | temp_start = temp_time 192 | print(" Instance: %s; Time: %.2fs; loss: %.4f" % ( 193 | end, temp_cost, sample_loss)) 194 | sys.stdout.flush() 195 | sample_loss = 0 196 | temp_time = time.time() 197 | temp_cost = temp_time - temp_start 198 | print(" Instance: %s; Time: %.2fs; loss: %.4f" % (end, temp_cost, sample_loss)) 199 | epoch_finish = time.time() 200 | epoch_cost = epoch_finish - epoch_start 201 | print("Epoch: %s training finished. Time: %.2fs, speed: %.2fst/s, total loss: %s"%(idx, epoch_cost, train_num/epoch_cost, total_loss)) 202 | speed, acc, p, r, f, _ = evaluate(data, model, args, "dev") 203 | dev_finish = time.time() 204 | dev_cost = dev_finish - epoch_finish 205 | current_score = f 206 | print( 207 | "Dev: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (dev_cost, speed, acc, p, r, f)) 208 | if current_score > best_dev: 209 | print("Exceed previous best f score:", best_dev) 210 | if not os.path.exists(args.param_stored_directory + args.dataset_name + "_param"): 211 | os.makedirs(args.param_stored_directory + args.dataset_name + "_param") 212 | model_name = "{}epoch_{}_f1_{}.model".format(args.param_stored_directory + args.dataset_name + "_param/", idx, current_score) 213 | torch.save(model.state_dict(), model_name) 214 | best_dev = current_score 215 | gc.collect() 216 | 217 | 218 | if __name__ == '__main__': 219 | args, unparsed = get_args() 220 | for arg in vars(args): 221 | print(arg, ":", getattr(args, arg)) 222 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.visible_gpu) 223 | seed = args.random_seed 224 | torch.manual_seed(seed) 225 | torch.cuda.manual_seed_all(seed) 226 | np.random.seed(seed) 227 | random.seed(seed) 228 | torch.backends.cudnn.deterministic = True 229 | data = data_initialization(args) 230 | model = BLSTM_GAT_CRF(data, args) 231 | train(data, model, args) 232 | 233 | 234 | -------------------------------------------------------------------------------- /model/bilstm_gat_crf.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from layer.crf import CRF 6 | from layer.gatlayer import GAT 7 | from torch.nn.utils.rnn import 
pad_packed_sequence, pack_padded_sequence 8 | 9 | 10 | class BLSTM_GAT_CRF(nn.Module): 11 | def __init__(self, data, args): 12 | super(BLSTM_GAT_CRF, self).__init__() 13 | print("build BLSTM_GAT_CRF model...") 14 | self.name = "BLSTM_GAT_CRF" 15 | self.strategy = args.strategy 16 | self.char_emb_dim = data.char_emb_dim 17 | self.gaz_emb_dim = data.gaz_emb_dim 18 | self.gaz_embeddings = nn.Embedding(data.gaz_alphabet.size(), self.gaz_emb_dim) 19 | self.char_embeddings = nn.Embedding(data.char_alphabet.size(), self.char_emb_dim) 20 | if data.pretrain_char_embedding is not None: 21 | self.char_embeddings.weight.data.copy_(torch.from_numpy(data.pretrain_char_embedding)) 22 | else: 23 | self.char_embeddings.weight.data.copy_( 24 | torch.from_numpy(self.random_embedding(data.char_alphabet.size(), self.char_emb_dim))) 25 | if data.pretrain_gaz_embedding is not None: 26 | self.gaz_embeddings.weight.data.copy_(torch.from_numpy(data.pretrain_gaz_embedding)) 27 | else: 28 | self.gaz_embeddings.weight.data.copy_( 29 | torch.from_numpy(self.random_embedding(data.gaz_alphabet.size(), self.gaz_emb_dim))) 30 | if args.fix_gaz_emb: 31 | self.gaz_embeddings.weight.requires_grad = False 32 | else: 33 | self.gaz_embeddings.weight.requires_grad = True 34 | self.hidden_dim = self.gaz_emb_dim 35 | self.bilstm_flag = args.bilstm_flag 36 | self.lstm_layer = args.lstm_layer 37 | if self.bilstm_flag: 38 | lstm_hidden = self.hidden_dim // 2 39 | else: 40 | lstm_hidden = self.hidden_dim 41 | crf_input_dim = data.label_alphabet.size()+1 42 | self.lstm = nn.LSTM(self.char_emb_dim, lstm_hidden, num_layers=self.lstm_layer, batch_first=True, bidirectional=self.bilstm_flag) 43 | self.hidden2hidden = nn.Linear(self.hidden_dim, crf_input_dim) 44 | self.gat_1 = GAT(self.hidden_dim, args.gat_nhidden, crf_input_dim, args.dropgat, args.alpha, args.gat_nhead, args.gat_layer) 45 | self.gat_2 = GAT(self.hidden_dim, args.gat_nhidden, crf_input_dim, args.dropgat, args.alpha, args.gat_nhead, args.gat_layer) 46 | self.gat_3 = GAT(self.hidden_dim, args.gat_nhidden, crf_input_dim, args.dropgat, args.alpha, args.gat_nhead, args.gat_layer) 47 | self.crf = CRF(data.label_alphabet.size()-1, args.use_gpu) 48 | if self.strategy == "v": 49 | self.weight1 = nn.Parameter(torch.ones(crf_input_dim)) 50 | self.weight2 = nn.Parameter(torch.ones(crf_input_dim)) 51 | self.weight3 = nn.Parameter(torch.ones(crf_input_dim)) 52 | self.weight4 = nn.Parameter(torch.ones(crf_input_dim)) 53 | elif self.strategy == "n": 54 | self.weight1 = nn.Parameter(torch.ones(1)) 55 | self.weight2 = nn.Parameter(torch.ones(1)) 56 | self.weight3 = nn.Parameter(torch.ones(1)) 57 | self.weight4 = nn.Parameter(torch.ones(1)) 58 | else: 59 | self.weight = nn.Linear(crf_input_dim*4, crf_input_dim) 60 | self.dropout = nn.Dropout(args.dropout) 61 | self.droplstm = nn.Dropout(args.droplstm) 62 | self.gaz_dropout = nn.Dropout(args.gaz_dropout) 63 | self.reset_parameters() 64 | if args.use_gpu: 65 | self.to_cuda() 66 | 67 | def to_cuda(self): 68 | self.char_embeddings = self.char_embeddings.cuda() 69 | self.gaz_embeddings = self.gaz_embeddings.cuda() 70 | self.lstm = self.lstm.cuda() 71 | self.gat_1 = self.gat_1.cuda() 72 | self.gat_2 = self.gat_2.cuda() 73 | self.gat_3 = self.gat_3.cuda() 74 | self.hidden2hidden = self.hidden2hidden.cuda() 75 | self.gaz_dropout = self.gaz_dropout.cuda() 76 | self.dropout = self.dropout.cuda() 77 | self.droplstm = self.droplstm.cuda() 78 | self.gaz_dropout = self.gaz_dropout.cuda() 79 | if self.strategy in ["v", "n"]: 80 | self.weight1.data = 
self.weight1.data.cuda() 81 | self.weight2.data = self.weight2.data.cuda() 82 | self.weight3.data = self.weight3.data.cuda() 83 | self.weight4.data = self.weight4.data.cuda() 84 | else: 85 | self.weight = self.weight.cuda() 86 | 87 | def reset_parameters(self): 88 | nn.init.orthogonal_(self.lstm.weight_ih_l0) 89 | nn.init.orthogonal_(self.lstm.weight_hh_l0) 90 | nn.init.orthogonal_(self.lstm.weight_hh_l0_reverse) 91 | nn.init.orthogonal_(self.lstm.weight_ih_l0_reverse) 92 | nn.init.orthogonal_(self.hidden2hidden.weight) 93 | nn.init.constant_(self.hidden2hidden.bias, 0) 94 | 95 | def random_embedding(self, vocab_size, embedding_dim): 96 | pretrain_emb = np.empty([vocab_size, embedding_dim]) 97 | scale = np.sqrt(3.0 / embedding_dim) 98 | for index in range(vocab_size): 99 | pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, embedding_dim]) 100 | return pretrain_emb 101 | 102 | def _get_lstm_features(self, batch_char, batch_len): 103 | embeds = self.char_embeddings(batch_char) 104 | embeds = self.dropout(embeds) 105 | embeds_pack = pack_padded_sequence(embeds, batch_len, batch_first=True) 106 | out_packed, (_, _) = self.lstm(embeds_pack) 107 | lstm_feature, _ = pad_packed_sequence(out_packed, batch_first=True) 108 | lstm_feature = self.droplstm(lstm_feature) 109 | return lstm_feature 110 | 111 | def _get_crf_feature(self, batch_char, batch_len, gaz_list, t_graph, c_graph, l_graph): 112 | gaz_feature = self.gaz_embeddings(gaz_list) 113 | gaz_feature = self.gaz_dropout(gaz_feature) 114 | lstm_feature = self._get_lstm_features(batch_char, batch_len) 115 | max_seq_len = lstm_feature.size()[1] 116 | gat_input = torch.cat((lstm_feature, gaz_feature), dim=1) 117 | gat_feature_1 = self.gat_1(gat_input, t_graph) 118 | gat_feature_1 = gat_feature_1[:, :max_seq_len, :] 119 | gat_feature_2 = self.gat_2(gat_input, c_graph) 120 | gat_feature_2 = gat_feature_2[:, :max_seq_len, :] 121 | gat_feature_3 = self.gat_3(gat_input, l_graph) 122 | gat_feature_3 = gat_feature_3[:, :max_seq_len, :] 123 | lstm_feature = self.hidden2hidden(lstm_feature) 124 | if self.strategy == "m": 125 | crf_feature = torch.cat((lstm_feature, gat_feature_1, gat_feature_2, gat_feature_3), dim=2) 126 | crf_feature = self.weight(crf_feature) 127 | elif self.strategy == "v": 128 | crf_feature = torch.mul(lstm_feature, self.weight1) + torch.mul(gat_feature_1, self.weight2) + torch.mul( 129 | gat_feature_2, self.weight3) + torch.mul(gat_feature_3, self.weight4) 130 | else: 131 | crf_feature = self.weight1 * lstm_feature + self.weight2 * gat_feature_1 + self.weight3 * gat_feature_2 + self.weight4 * gat_feature_3 132 | return crf_feature 133 | 134 | def neg_log_likelihood(self, batch_char, batch_len, gaz_list, t_graph, c_graph, l_graph, mask, batch_label): 135 | crf_feature = self._get_crf_feature(batch_char, batch_len, gaz_list, t_graph, c_graph, l_graph) 136 | total_loss = self.crf.neg_log_likelihood_loss(crf_feature, mask, batch_label) 137 | return total_loss 138 | 139 | def forward(self, batch_char, batch_len, gaz_list, t_graph, c_graph, l_graph, mask): 140 | crf_feature = self._get_crf_feature(batch_char, batch_len, gaz_list, t_graph, c_graph, l_graph) 141 | _, best_path = self.crf._viterbi_decode(crf_feature, mask) 142 | return best_path 143 | 144 | -------------------------------------------------------------------------------- /paper/Leverage Lexical Knowledge for Chinese Named Entity Recognition via Collaborative Graph Network.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DianboWork/Graph4CNER/d2ba4eaf7fd6138467446eea216bc914de37306e/paper/Leverage Lexical Knowledge for Chinese Named Entity Recognition via Collaborative Graph Network.pdf -------------------------------------------------------------------------------- /run_main.sh: -------------------------------------------------------------------------------- 1 | # For WeiboNER 2 | python -m main --train_file Weibo_TRAIN_FILE --dev_file Weibo_DEV_FILE --test_file Weibo_TEST_FILE --gat_nhead 5 --gat_layer 2 --strategy n --batch_size 10 --lr 0.001 --lr_decay 0.01 --use_clip False --optimizer SGD --droplstm 0 --dropout 0.6 --dropgat 0 --gaz_dropout 0.4 --norm_char_emb True --norm_gaz_emb False --random_seed 100 3 | -------------------------------------------------------------------------------- /utils/alphabet.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | 5 | class Alphabet: 6 | def __init__(self, name, padflag=True, unkflag=True, keep_growing=True): 7 | self.name = name 8 | self.PAD = "" 9 | self.UNKNOWN = "" 10 | self.padflag = padflag 11 | self.unkflag = unkflag 12 | self.instance2index = {} 13 | self.instances = [] 14 | self.keep_growing = keep_growing 15 | self.next_index = 0 16 | if self.padflag: 17 | self.add(self.PAD) 18 | if self.unkflag: 19 | self.add(self.UNKNOWN) 20 | 21 | def clear(self, keep_growing=True): 22 | self.instance2index = {} 23 | self.instances = [] 24 | self.keep_growing = keep_growing 25 | self.next_index = 0 26 | 27 | def add(self, instance): 28 | if instance not in self.instance2index: 29 | self.instances.append(instance) 30 | self.instance2index[instance] = self.next_index 31 | self.next_index += 1 32 | 33 | def get_index(self, instance): 34 | try: 35 | return self.instance2index[instance] 36 | except KeyError: 37 | if self.keep_growing: 38 | index = self.next_index 39 | self.add(instance) 40 | return index 41 | else: 42 | if self.UNKNOWN in self.instance2index: 43 | return self.instance2index[self.UNKNOWN] 44 | else: 45 | print(self.name + " get_index raise wrong, return 0. Please check it") 46 | return 0 47 | 48 | def get_instance(self, index): 49 | if index == 0: 50 | if self.padflag: 51 | print(self.name + " get_instance of , wrong?") 52 | if not self.padflag and self.unkflag: 53 | print(self.name + " get_instance of , wrong?") 54 | return self.instances[index] 55 | try: 56 | return self.instances[index] 57 | except IndexError: 58 | print('WARNING: '+ self.name + ' Alphabet get_instance, unknown instance, return the label.') 59 | return '' 60 | 61 | def size(self): 62 | return len(self.instances) 63 | 64 | def iteritems(self): 65 | return self.instance2index.items() 66 | 67 | def enumerate_items(self, start=1): 68 | if start < 1 or start >= self.size(): 69 | raise IndexError("Enumerate is allowed between [1 : size of the alphabet)") 70 | return zip(range(start, len(self.instances) + 1), self.instances[start - 1:]) 71 | 72 | def close(self): 73 | self.keep_growing = False 74 | 75 | def open(self): 76 | self.keep_growing = True 77 | 78 | def get_content(self): 79 | return {'instance2index': self.instance2index, 'instances': self.instances} 80 | 81 | def from_json(self, data): 82 | self.instances = data["instances"] 83 | self.instance2index = data["instance2index"] 84 | 85 | def save(self, output_directory, name=None): 86 | """ 87 | Save both alhpabet records to the given directory. 88 | :param output_directory: Directory to save model and weights. 
89 | :param name: The alphabet saving name, optional. 90 | :return: 91 | """ 92 | saving_name = name if name else self.__name 93 | try: 94 | json.dump(self.get_content(), open(os.path.join(output_directory, saving_name + ".json"), 'w')) 95 | except Exception as e: 96 | print("Exception: Alphabet is not saved: " % repr(e)) 97 | 98 | def load(self, input_directory, name=None): 99 | """ 100 | Load model architecture and weights from the give directory. This allow we use old models even the structure 101 | changes. 102 | :param input_directory: Directory to save model and weights 103 | :return: 104 | """ 105 | loading_name = name if name else self.__name 106 | self.from_json(json.load(open(os.path.join(input_directory, loading_name + ".json")))) 107 | 108 | 109 | -------------------------------------------------------------------------------- /utils/batchify.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.graph_generator import * 3 | 4 | 5 | def batchify(input_batch_list, gpu): 6 | batch_size = len(input_batch_list) 7 | words = [sent[0] for sent in input_batch_list] 8 | gazs = [sent[1] for sent in input_batch_list] 9 | labels = [sent[2] for sent in input_batch_list] 10 | word_seq_lengths = list(map(len, words)) 11 | max_seq_len = max(word_seq_lengths) 12 | gazs_list, gaz_lens, max_gaz_len = seq_gaz(gazs) 13 | tmp_matrix = list(map(graph_generator, [(max_gaz_len, max_seq_len, gaz) for gaz in gazs])) 14 | batch_t_matrix = torch.ByteTensor([ele[0] for ele in tmp_matrix]) 15 | batch_c_matrix = torch.ByteTensor([ele[1] for ele in tmp_matrix]) 16 | batch_l_matrix = torch.ByteTensor([ele[2] for ele in tmp_matrix]) 17 | gazs_tensor = torch.zeros((batch_size, max_gaz_len), requires_grad=False).long() 18 | word_seq_tensor = torch.zeros((batch_size, max_seq_len), requires_grad=False).long() 19 | label_seq_tensor = torch.zeros((batch_size, max_seq_len), requires_grad=False).long() 20 | mask = torch.zeros((batch_size, max_seq_len), requires_grad=False).byte() 21 | for idx, (seq, gaz, gaz_len, label, seqlen) in enumerate(zip(words, gazs_list, gaz_lens, labels, word_seq_lengths)): 22 | word_seq_tensor[idx, :seqlen] = torch.LongTensor(seq) 23 | gazs_tensor[idx, :gaz_len] = torch.LongTensor(gaz) 24 | label_seq_tensor[idx, :seqlen] = torch.LongTensor(label) 25 | mask[idx, :seqlen] = torch.Tensor([1]*seqlen) 26 | word_seq_lengths = torch.LongTensor(word_seq_lengths) 27 | word_seq_lengths, word_perm_idx = word_seq_lengths.sort(0, descending=True) 28 | word_seq_tensor = word_seq_tensor[word_perm_idx] 29 | label_seq_tensor = label_seq_tensor[word_perm_idx] 30 | gazs_tensor = gazs_tensor[word_perm_idx] 31 | mask = mask[word_perm_idx] 32 | batch_t_matrix = batch_t_matrix[word_perm_idx] 33 | batch_c_matrix = batch_c_matrix[word_perm_idx] 34 | batch_l_matrix = batch_l_matrix[word_perm_idx] 35 | _, word_seq_recover = word_perm_idx.sort(0, descending=False) 36 | if gpu: 37 | word_seq_tensor = word_seq_tensor.cuda() 38 | word_seq_lengths = word_seq_lengths.cuda() 39 | word_seq_recover = word_seq_recover.cuda() 40 | label_seq_tensor = label_seq_tensor.cuda() 41 | mask = mask.cuda() 42 | batch_t_matrix = batch_t_matrix.cuda() 43 | gazs_tensor = gazs_tensor.cuda() 44 | batch_c_matrix = batch_c_matrix.cuda() 45 | batch_l_matrix = batch_l_matrix.cuda() 46 | return word_seq_tensor, word_seq_lengths, gazs_tensor, mask, label_seq_tensor, word_seq_recover, batch_t_matrix, batch_c_matrix, batch_l_matrix 47 | 48 | 49 | 50 | 
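The batching and graph construction above can be exercised in isolation. Below is a minimal sketch (not from the repository) with made-up character, lexicon-word, and label ids; in the real pipeline each instance `[char_ids, gaz_ids, label_ids]` is produced by `Data.generate_instance`, and `gaz_ids[i]` is either `[]` or `[matched_word_ids, matched_word_lengths]` for the lexicon words starting at character `i`.

```python
from utils.batchify import batchify

# Two toy sentences: 3 and 2 characters, each with one matched 2-character lexicon word
# starting at character 0 (all ids are arbitrary placeholders).
toy_batch = [
    [[5, 8, 3], [[[11], [2]], [], []], [1, 2, 0]],
    [[7, 4],    [[[12], [2]], []],     [3, 3]],
]
(chars, lens, gazs, mask, labels, recover,
 t_graph, c_graph, l_graph) = batchify(toy_batch, gpu=False)

print(chars.shape)    # torch.Size([2, 3])    characters padded to the longest sentence
print(gazs.shape)     # torch.Size([2, 1])    one matched lexicon word per sentence
print(t_graph.shape)  # torch.Size([2, 4, 4]) 3 character nodes + 1 word node per graph
```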
-------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser() 3 | 4 | 5 | def str2bool(v): 6 | return v.lower() in ('true') 7 | 8 | 9 | def add_argument_group(name): 10 | arg = parser.add_argument_group(name) 11 | return arg 12 | 13 | net_arg = add_argument_group('Network') 14 | net_arg.add_argument('--fix_gaz_emb', type=str2bool, default=True) 15 | net_arg.add_argument('--lstm_layer', type=int, default=1) 16 | net_arg.add_argument('--bilstm_flag', type=str2bool, default=True) 17 | net_arg.add_argument('--gat_nhidden', type=int, default=30) 18 | net_arg.add_argument('--gat_nhead', type=int, default=3) 19 | net_arg.add_argument('--gat_layer', type=int, default=2, choices=[1, 2]) 20 | net_arg.add_argument('--strategy', type=str, default="m", choices=['v', 'n', 'm']) 21 | net_arg.add_argument("--alpha", type=float, default=0.1) 22 | net_arg.add_argument('--dropout', type=float, default=0.5) 23 | net_arg.add_argument('--droplstm', type=float, default=0.5) 24 | net_arg.add_argument('--dropgat', type=float, default=0.5) 25 | net_arg.add_argument('--gaz_dropout', type=float, default=0.5) 26 | 27 | # Data 28 | data_arg = add_argument_group('Data') 29 | data_arg.add_argument('--dataset_name', type=str, default='Data') 30 | data_arg.add_argument('--train_file', type=str, help="train file") 31 | data_arg.add_argument('--test_file', type=str, help="test file") 32 | data_arg.add_argument('--dev_file', type=str, help="dev file") 33 | data_arg.add_argument('--gaz_file', type=str, default="./data/embeddings/sgns.merge.word", help="lexical embeddings file") 34 | data_arg.add_argument('--char_embedding_path', type=str, default="./data/embeddings/gigaword_chn.all.a2b.uni.ite50.vec",help="characher embeddings file") 35 | data_arg.add_argument('--data_stored_directory', type=str, default="./data/generated_data/") 36 | data_arg.add_argument('--param_stored_directory', type=str, default="./data/model_param/") 37 | 38 | preprocess_arg = add_argument_group('Preprocess') 39 | preprocess_arg.add_argument('--norm_char_emb', type=str2bool, default=False) 40 | preprocess_arg.add_argument('--norm_gaz_emb', type=str2bool, default=True) 41 | preprocess_arg.add_argument('--number_normalized', type=str2bool, default=True) 42 | preprocess_arg.add_argument('--max_sentence_length', type=int, default=250) 43 | 44 | learn_arg = add_argument_group('Learning') 45 | learn_arg.add_argument('--batch_size', type=int, default=20) 46 | learn_arg.add_argument('--max_epoch', type=int, default=150) 47 | learn_arg.add_argument('--lr', type=float, default=0.001) 48 | learn_arg.add_argument('--lr_decay', type=float, default=0.01) 49 | learn_arg.add_argument('--use_clip', type=str2bool, default=False) 50 | learn_arg.add_argument('--clip', type=float, default=5.0) 51 | learn_arg.add_argument("--optimizer", type=str, default="Adam", choices=['Adam', 'SGD']) 52 | learn_arg.add_argument("--l2_penalty", type=float, default=0.00000005) 53 | # Misc 54 | misc_arg = add_argument_group('Misc') 55 | misc_arg.add_argument('--refresh', type=str2bool, default=False) 56 | misc_arg.add_argument('--use_gpu', type=str2bool, default=True) 57 | misc_arg.add_argument('--visible_gpu', type=int, default=0) 58 | misc_arg.add_argument('--random_seed', type=int, default=1) 59 | 60 | 61 | def get_args(): 62 | args, unparsed = parser.parse_known_args() 63 | if len(unparsed) > 1: 64 | 
print("Unparsed args: {}".format(unparsed)) 65 | return args, unparsed 66 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | from utils.alphabet import Alphabet 4 | from utils.functions import * 5 | from utils.gazetter import Gazetteer 6 | 7 | 8 | class Data: 9 | def __init__(self): 10 | self.max_sentence_length = 200 11 | self.number_normalized = True 12 | self.norm_char_emb = True 13 | self.norm_gaz_emb = True 14 | self.dataset_name = 'msra' 15 | self.tagscheme = "NoSeg" 16 | self.char_alphabet = Alphabet('character') 17 | self.label_alphabet = Alphabet('label', unkflag=False) 18 | self.gaz_lower = False 19 | self.gaz = Gazetteer(self.gaz_lower) 20 | self.gaz_alphabet = Alphabet('gaz') 21 | self.train_ids = [] 22 | self.dev_ids = [] 23 | self.test_ids = [] 24 | self.train_texts = [] 25 | self.dev_texts = [] 26 | self.test_texts = [] 27 | self.char_emb_dim = 100 28 | self.gaz_emb_dim = 100 29 | self.pretrain_char_embedding = None 30 | self.pretrain_gaz_embedding = None 31 | self.dev_cut_num = 0 32 | self.train_cut_num = 0 33 | self.test_cut_num = 0 34 | self.cut_num = 0 35 | 36 | def show_data_summary(self): 37 | print("DATA SUMMARY START:") 38 | print(" Dataset name: %s" % self.dataset_name) 39 | print(" Tag scheme: %s" % (self.tagscheme)) 40 | print(" Max Sentence Length: %s" % self.max_sentence_length) 41 | print(" Char alphabet size: %s" % self.char_alphabet.size()) 42 | print(" Gaz alphabet size: %s" % self.gaz_alphabet.size()) 43 | print(" Label alphabet size: %s" % self.label_alphabet.size()) 44 | print(" Char embedding size: %s" % self.char_emb_dim) 45 | print(" Gaz embedding size: %s" % self.gaz_emb_dim) 46 | print(" Number normalized: %s" % self.number_normalized) 47 | print(" Norm char emb: %s" % self.norm_char_emb) 48 | print(" Norm gaz emb: %s" % self.norm_gaz_emb) 49 | print(" Train instance number: %s" % (len(self.train_ids))) 50 | print(" Dev instance number: %s" % (len(self.dev_ids))) 51 | print(" Test instance number: %s" % (len(self.test_ids))) 52 | if self.cut_num != 0: 53 | print(" Train&Dev cut number: %s" % self.cut_num) 54 | else: 55 | print(" Train cut number: %s" % self.train_cut_num) 56 | print(" Dev cut number: %s" % self.dev_cut_num) 57 | print(" Test cut number: %s" % self.test_cut_num) 58 | print("DATA SUMMARY END.") 59 | sys.stdout.flush() 60 | 61 | def build_gaz_file(self, gaz_file, skip_first_row=False, separator=" "): 62 | ## build gaz file,initial read gaz embedding file 63 | if gaz_file: 64 | with open(gaz_file, 'r') as file: 65 | i = 0 66 | for line in tqdm(file): 67 | if i == 0: 68 | i = i + 1 69 | if skip_first_row: 70 | _ = line.strip() 71 | continue 72 | fin = line.strip().split(separator)[0] 73 | if fin: 74 | self.gaz.insert(fin, "one_source") 75 | print("Load gaz file: ", gaz_file, " total size:", self.gaz.size()) 76 | else: 77 | print("Gaz file is None, load nothing") 78 | 79 | def fix_alphabet(self): 80 | self.char_alphabet.close() 81 | self.label_alphabet.close() 82 | self.gaz_alphabet.close() 83 | 84 | def build_char_pretrain_emb(self, emb_path, skip_first_row=False, separator=" "): 85 | print("build char pretrain emb...") 86 | self.pretrain_char_embedding, self.char_emb_dim = \ 87 | build_pretrain_embedding(emb_path, self.char_alphabet, skip_first_row, separator, 88 | self.char_emb_dim, 89 | self.norm_char_emb) 90 | 91 | def build_gaz_pretrain_emb(self, emb_path, 
skip_first_row=True, separator=" "): 92 | print("build gaz pretrain emb...") 93 | self.pretrain_gaz_embedding, self.gaz_emb_dim = build_pretrain_embedding(emb_path, self.gaz_alphabet, skip_first_row, separator, 94 | self.gaz_emb_dim, 95 | self.norm_gaz_emb) 96 | 97 | def generate_instance(self, input_file, name, random_split=False): 98 | texts, ids, cut_num = read_instance(input_file, self.gaz, self.char_alphabet, self.label_alphabet, self.gaz_alphabet, self.number_normalized, self.max_sentence_length) 99 | if name == "train": 100 | if random_split: 101 | random.seed(1) 102 | ix = [i for i in range(len(ids))] 103 | train_ix = random.sample(ix, int(len(ids) * 0.9)) 104 | dev_ix = list(set(ix).difference(set(train_ix))) 105 | self.train_ids = [ids[ele] for ele in train_ix] 106 | self.dev_ids = [ids[ele] for ele in dev_ix] 107 | self.train_texts = [texts[ele] for ele in train_ix] 108 | self.dev_texts = [texts[ele] for ele in dev_ix] 109 | self.cut_num = cut_num 110 | else: 111 | self.train_ids = ids 112 | self.train_texts = texts 113 | self.train_cut_num = cut_num 114 | elif name == "dev": 115 | self.dev_ids = ids 116 | self.dev_texts = texts 117 | self.dev_cut_num = cut_num 118 | elif name == "test": 119 | self.test_ids = ids 120 | self.test_texts = texts 121 | self.test_cut_num = cut_num 122 | else: 123 | print("Error: you can only generate train/dev/test instance! Illegal input:%s" % name) 124 | 125 | def get_tag_scheme(self): 126 | startS = False 127 | startB = False 128 | for label, _ in self.label_alphabet.iteritems(): 129 | if "S-" in label.upper(): 130 | startS = True 131 | elif "B-" in label.upper(): 132 | startB = True 133 | if startB: 134 | if startS: 135 | self.tagscheme = "BMES" 136 | else: 137 | self.tagscheme = "BIO" 138 | 139 | -------------------------------------------------------------------------------- /utils/functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | 4 | 5 | def normalize_word(word): 6 | new_word = "" 7 | for char in word: 8 | if char.isdigit(): 9 | new_word += '0' 10 | else: 11 | new_word += char 12 | return new_word 13 | 14 | 15 | def read_instance(input_file, gaz, char_alphabet, label_alphabet, gaz_alphabet, number_normalized, max_sent_length): 16 | in_lines = open(input_file, 'r').readlines() 17 | instance_texts = [] 18 | instance_ids = [] 19 | chars = [] 20 | labels = [] 21 | char_ids = [] 22 | label_ids = [] 23 | cut_num = 0 24 | for idx in range(len(in_lines)): 25 | line = in_lines[idx] 26 | if len(line) > 2: 27 | pairs = line.strip().split() 28 | char = pairs[0] 29 | if number_normalized: 30 | char = normalize_word(char) 31 | label = pairs[-1] 32 | chars.append(char) 33 | labels.append(label) 34 | char_ids.append(char_alphabet.get_index(char)) 35 | label_ids.append(label_alphabet.get_index(label)) 36 | else: 37 | if ((max_sent_length < 0) or (len(chars) < max_sent_length)) and (len(chars) > 0): 38 | gazs = [] 39 | gaz_ids = [] 40 | s_length = len(chars) 41 | for idx in range(s_length): 42 | matched_list = gaz.enumerateMatchList(chars[idx:]) 43 | matched_length = [len(a) for a in matched_list] 44 | gazs.append(matched_list) 45 | matched_id = [gaz_alphabet.get_index(entity) for entity in matched_list] 46 | if matched_id: 47 | gaz_ids.append([matched_id, matched_length]) 48 | else: 49 | gaz_ids.append([]) 50 | instance_texts.append([chars, gazs, labels]) 51 | instance_ids.append([char_ids, gaz_ids, label_ids]) 52 | elif len(chars) < max_sent_length: 53 | cut_num 
+= 1 54 | chars = [] 55 | labels = [] 56 | char_ids = [] 57 | label_ids = [] 58 | gazs = [] 59 | gaz_ids = [] 60 | return instance_texts, instance_ids, cut_num 61 | 62 | 63 | def build_pretrain_embedding(embedding_path, alphabet, skip_first_row=False, separator=" ", embedd_dim=100, norm=True): 64 | embedd_dict = dict() 65 | if embedding_path != None: 66 | embedd_dict, embedd_dim = load_pretrain_emb(embedding_path, skip_first_row, separator) 67 | scale = np.sqrt(3.0 / embedd_dim) 68 | pretrain_emb = np.empty([alphabet.size(), embedd_dim]) 69 | perfect_match = 0 70 | case_match = 0 71 | not_match = 0 72 | for alph, index in alphabet.iteritems(): 73 | if alph in embedd_dict: 74 | if norm: 75 | pretrain_emb[index, :] = norm2one(embedd_dict[alph]) 76 | else: 77 | pretrain_emb[index, :] = embedd_dict[alph] 78 | perfect_match += 1 79 | elif alph.lower() in embedd_dict: 80 | if norm: 81 | pretrain_emb[index, :] = norm2one(embedd_dict[alph.lower()]) 82 | else: 83 | pretrain_emb[index, :] = embedd_dict[alph.lower()] 84 | case_match += 1 85 | else: 86 | pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, embedd_dim]) 87 | not_match += 1 88 | pretrained_size = len(embedd_dict) 89 | print("Embedding: %s\n pretrain num:%s, prefect match:%s, case_match:%s, oov:%s, oov%%:%s" % ( 90 | embedding_path, pretrained_size, perfect_match, case_match, not_match, (not_match + 0.) / alphabet.size())) 91 | return pretrain_emb, embedd_dim 92 | 93 | 94 | def norm2one(vec): 95 | root_sum_square = np.sqrt(np.sum(np.square(vec))) 96 | return vec / root_sum_square 97 | 98 | 99 | def load_pretrain_emb(embedding_path, skip_first_row=False, separator=" "): 100 | embedd_dim = -1 101 | embedd_dict = dict() 102 | with open(embedding_path, 'r') as file: 103 | i = 0 104 | j = 0 105 | for line in file: 106 | if i == 0: 107 | i = i + 1 108 | if skip_first_row: 109 | _ = line.strip() 110 | continue 111 | j = j+1 112 | line = line.strip() 113 | if len(line) == 0: 114 | continue 115 | tokens = line.split(separator) 116 | if embedd_dim < 0: 117 | embedd_dim = len(tokens) - 1 118 | else: 119 | if embedd_dim + 1 == len(tokens): 120 | embedd = np.empty([1, embedd_dim]) 121 | embedd[:] = tokens[1:] 122 | embedd_dict[tokens[0]] = embedd 123 | else: 124 | continue 125 | return embedd_dict, embedd_dim 126 | -------------------------------------------------------------------------------- /utils/gazetter.py: -------------------------------------------------------------------------------- 1 | from utils.trie import Trie 2 | 3 | class Gazetteer: 4 | def __init__(self, lower):#lower = False 5 | self.trie = Trie() 6 | self.ent2type = {} ## word list to type 7 | self.ent2id = {"":0} ## word list to id 8 | self.lower = lower 9 | self.space = "" 10 | 11 | def enumerateMatchList(self, word_list): 12 | if self.lower: 13 | word_list = [word.lower() for word in word_list] 14 | match_list = self.trie.enumerateMatch(word_list, self.space) 15 | return match_list 16 | 17 | def insert(self, word_list, source): 18 | if self.lower: 19 | word_list = [word.lower() for word in word_list] 20 | self.trie.insert(word_list) 21 | string = self.space.join(word_list) 22 | if string not in self.ent2type: 23 | self.ent2type[string] = source 24 | if string not in self.ent2id: 25 | self.ent2id[string] = len(self.ent2id) 26 | 27 | def searchId(self, word_list): 28 | if self.lower: 29 | word_list = [word.lower() for word in word_list] 30 | string = self.space.join(word_list) 31 | if string in self.ent2id: 32 | return self.ent2id[string] 33 | return self.ent2id[""] 34 | 
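    # searchType mirrors searchId but returns the stored source tag for the word; unlike
    # searchId (which falls back to the default entry with id 0), a miss here is treated as
    # a fatal error and the program exits.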
35 | def searchType(self, word_list): 36 | if self.lower: 37 | word_list = [word.lower() for word in word_list] 38 | string = self.space.join(word_list) 39 | if string in self.ent2type: 40 | return self.ent2type[string] 41 | print("Error in finding entity type at gazetteer.py, exit program! String:", string) 42 | exit(0) 43 | 44 | def size(self): 45 | return len(self.ent2type) 46 | 47 | def clean(self): 48 | self.trie = Trie() 49 | self.ent2type = {} 50 | self.ent2id = {"": 0} 51 | self.space = "" 52 | 53 | -------------------------------------------------------------------------------- /utils/graph_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def seq_gaz(batch_gaz_ids): 5 | gaz_len = [] 6 | gaz_list = [] 7 | for gaz_id in batch_gaz_ids: 8 | gaz = [] 9 | length = 0 10 | for ele in gaz_id: 11 | if ele: 12 | length = length + len(ele[0]) 13 | for j in range(len(ele[0])): 14 | gaz.append(ele[0][j]) 15 | gaz_list.append(gaz) 16 | gaz_len.append(length) 17 | return gaz_list, gaz_len, max(gaz_len) 18 | 19 | 20 | def graph_generator(input): 21 | max_gaz_len, max_seq_len, gaz_ids = input 22 | gaz_seq = [] 23 | sentence_len = len(gaz_ids) 24 | gaz_len = 0 25 | for ele in gaz_ids: 26 | if ele: 27 | gaz_len += len(ele[0]) 28 | matrix_size = max_gaz_len + max_seq_len 29 | t_matrix = np.eye(matrix_size, dtype=int) 30 | l_matrix = np.eye(matrix_size, dtype=int) 31 | c_matrix = np.eye(matrix_size, dtype=int) 32 | add_matrix1 = np.zeros((matrix_size, matrix_size), dtype=int) 33 | add_matrix2 = np.zeros((matrix_size, matrix_size), dtype=int) 34 | add_matrix1[:sentence_len, :sentence_len] = np.eye(sentence_len, k=1, dtype=int) 35 | add_matrix2[:sentence_len, :sentence_len] = np.eye(sentence_len, k=-1, dtype=int) 36 | t_matrix = t_matrix + add_matrix1 + add_matrix2 37 | l_matrix = l_matrix + add_matrix1 + add_matrix2 38 | # give word a index 39 | word_id = [[]] * sentence_len 40 | index = max_seq_len 41 | for i in range(sentence_len): 42 | if gaz_ids[i]: 43 | word_id[i] = [0] * len(gaz_ids[i][1]) 44 | for j in range(len(gaz_ids[i][1])): 45 | word_id[i][j] = index 46 | index = index + 1 47 | index_gaz = max_seq_len 48 | index_char = 0 49 | for k in range(len(gaz_ids)): 50 | ele = gaz_ids[k] 51 | if ele: 52 | for i in range(len(ele[0])): 53 | gaz_seq.append(ele[0][i]) 54 | l_matrix[index_gaz, index_char] = 1 55 | l_matrix[index_char, index_gaz] = 1 56 | l_matrix[index_gaz, index_char + ele[1][i] - 1] = 1 57 | l_matrix[index_char + ele[1][i] - 1, index_gaz] = 1 58 | for m in range(ele[1][i]): 59 | c_matrix[index_gaz, index_char + m] = 1 60 | c_matrix[index_char + m, index_gaz] = 1 61 | # char and word connection 62 | if index_char > 0: 63 | t_matrix[index_gaz, index_char - 1] = 1 64 | t_matrix[index_char - 1, index_gaz] = 1 65 | 66 | if index_char + ele[1][i] < sentence_len: 67 | t_matrix[index_gaz, index_char + ele[1][i]] = 1 68 | t_matrix[index_char + ele[1][i], index_gaz] = 1 69 | else: 70 | t_matrix[index_gaz, index_char + ele[1][i]] = 1 71 | t_matrix[index_char + ele[1][i], index_gaz] = 1 72 | # word and word connection 73 | if index_char + ele[1][i] < sentence_len: 74 | if gaz_ids[index_char + ele[1][i]]: 75 | for p in range(len(gaz_ids[index_char + ele[1][i]][1])): 76 | q = word_id[index_char + ele[1][i]][p] 77 | t_matrix[index_gaz, q] = 1 78 | t_matrix[q, index_gaz] = 1 79 | index_gaz = index_gaz + 1 80 | index_char = index_char + 1 81 | return (t_matrix, c_matrix, l_matrix) 82 | 
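# --- Illustrative usage (added note; not part of the original file) ---
# graph_generator builds three (max_gaz_len + max_seq_len)-square adjacency matrices:
# a transition graph (t_matrix), a word-character containing graph (c_matrix) and a
# lattice/boundary graph (l_matrix). Characters occupy indices 0..max_seq_len-1 and the
# matched lexicon words follow them. The word id 10 and the 4-character toy sentence
# below are made up purely for illustration.
if __name__ == "__main__":
    toy_gaz_ids = [[[10], [3]], [], [], []]   # word 10 covers characters 0..2
    t_graph, c_graph, l_graph = graph_generator((1, 4, toy_gaz_ids))
    print(c_graph[4, :4])   # [1 1 1 0] -> word node 4 contains characters 0, 1 and 2
    print(l_graph[4, :4])   # [1 0 1 0] -> it is linked only to its boundary characters 0 and 2
    print(t_graph[4, :4])   # [0 0 0 1] -> transition edge to the character right after the word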
-------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import sys 4 | import os 5 | 6 | 7 | ## input as sentence level labels 8 | def get_ner_fmeasure(golden_lists, predict_lists, label_type="BMES"): 9 | sent_num = len(golden_lists) 10 | golden_full = [] 11 | predict_full = [] 12 | right_full = [] 13 | right_tag = 0 14 | all_tag = 0 15 | for idx in range(0, sent_num): 16 | # word_list = sentence_lists[idx] 17 | golden_list = golden_lists[idx] 18 | predict_list = predict_lists[idx] 19 | for idy in range(len(golden_list)): 20 | if golden_list[idy] == predict_list[idy]: 21 | right_tag += 1 22 | all_tag += len(golden_list) 23 | if label_type == "BMES": 24 | gold_matrix = get_ner_BMES(golden_list) 25 | pred_matrix = get_ner_BMES(predict_list) 26 | else: 27 | gold_matrix = get_ner_BIO(golden_list) 28 | pred_matrix = get_ner_BIO(predict_list) 29 | # print "gold", gold_matrix 30 | # print "pred", pred_matrix 31 | right_ner = list(set(gold_matrix).intersection(set(pred_matrix))) 32 | golden_full += gold_matrix 33 | predict_full += pred_matrix 34 | right_full += right_ner 35 | right_num = len(right_full) 36 | golden_num = len(golden_full) 37 | predict_num = len(predict_full) 38 | if predict_num == 0: 39 | precision = -1 40 | else: 41 | precision = (right_num + 0.0) / predict_num 42 | if golden_num == 0: 43 | recall = -1 44 | else: 45 | recall = (right_num + 0.0) / golden_num 46 | if (precision == -1) or (recall == -1) or (precision + recall) <= 0.: 47 | f_measure = -1 48 | else: 49 | f_measure = 2 * precision * recall / (precision + recall) 50 | accuracy = (right_tag + 0.0) / all_tag 51 | # print "Accuracy: ", right_tag,"/",all_tag,"=",accuracy 52 | print("gold_num = ", golden_num, " pred_num = ", predict_num, " right_num = ", right_num) 53 | return accuracy, precision, recall, f_measure 54 | 55 | 56 | def reverse_style(input_string): 57 | target_position = input_string.index('[') 58 | input_len = len(input_string) 59 | output_string = input_string[target_position:input_len] + input_string[0:target_position] 60 | return output_string 61 | 62 | 63 | def get_ner_BMES(label_list): 64 | # list_len = len(word_list) 65 | # assert(list_len == len(label_list)), "word list size unmatch with label list" 66 | list_len = len(label_list) 67 | begin_label = 'B-' 68 | end_label = 'E-' 69 | single_label = 'S-' 70 | whole_tag = '' 71 | index_tag = '' 72 | tag_list = [] 73 | stand_matrix = [] 74 | for i in range(0, list_len): 75 | # wordlabel = word_list[i] 76 | current_label = label_list[i].upper() 77 | if begin_label in current_label: 78 | if index_tag != '': 79 | tag_list.append(whole_tag + ',' + str(i - 1)) 80 | whole_tag = current_label.replace(begin_label, "", 1) + '[' + str(i) 81 | index_tag = current_label.replace(begin_label, "", 1) 82 | 83 | elif single_label in current_label: 84 | if index_tag != '': 85 | tag_list.append(whole_tag + ',' + str(i - 1)) 86 | whole_tag = current_label.replace(single_label, "", 1) + '[' + str(i) 87 | tag_list.append(whole_tag) 88 | whole_tag = "" 89 | index_tag = "" 90 | elif end_label in current_label: 91 | if index_tag != '': 92 | tag_list.append(whole_tag + ',' + str(i)) 93 | whole_tag = '' 94 | index_tag = '' 95 | else: 96 | continue 97 | if (whole_tag != '') & (index_tag != ''): 98 | tag_list.append(whole_tag) 99 | tag_list_len = len(tag_list) 100 | 101 | for i in range(0, tag_list_len): 102 | if 
len(tag_list[i]) > 0: 103 | tag_list[i] = tag_list[i] + ']' 104 | insert_list = reverse_style(tag_list[i]) 105 | stand_matrix.append(insert_list) 106 | # print stand_matrix 107 | return stand_matrix 108 | 109 | 110 | def get_ner_BIO(label_list): 111 | # list_len = len(word_list) 112 | # assert(list_len == len(label_list)), "word list size unmatch with label list" 113 | list_len = len(label_list) 114 | begin_label = 'B-' 115 | inside_label = 'I-' 116 | whole_tag = '' 117 | index_tag = '' 118 | tag_list = [] 119 | stand_matrix = [] 120 | for i in range(0, list_len): 121 | # wordlabel = word_list[i] 122 | current_label = label_list[i].upper() 123 | if begin_label in current_label: 124 | if index_tag == '': 125 | whole_tag = current_label.replace(begin_label, "", 1) + '[' + str(i) 126 | index_tag = current_label.replace(begin_label, "", 1) 127 | else: 128 | tag_list.append(whole_tag + ',' + str(i - 1)) 129 | whole_tag = current_label.replace(begin_label, "", 1) + '[' + str(i) 130 | index_tag = current_label.replace(begin_label, "", 1) 131 | 132 | elif inside_label in current_label: 133 | if current_label.replace(inside_label, "", 1) == index_tag: 134 | whole_tag = whole_tag 135 | else: 136 | if (whole_tag != '') & (index_tag != ''): 137 | tag_list.append(whole_tag + ',' + str(i - 1)) 138 | whole_tag = '' 139 | index_tag = '' 140 | else: 141 | if (whole_tag != '') & (index_tag != ''): 142 | tag_list.append(whole_tag + ',' + str(i - 1)) 143 | whole_tag = '' 144 | index_tag = '' 145 | 146 | if (whole_tag != '') & (index_tag != ''): 147 | tag_list.append(whole_tag) 148 | tag_list_len = len(tag_list) 149 | 150 | for i in range(0, tag_list_len): 151 | if len(tag_list[i]) > 0: 152 | tag_list[i] = tag_list[i] + ']' 153 | insert_list = reverse_style(tag_list[i]) 154 | stand_matrix.append(insert_list) 155 | return stand_matrix 156 | 157 | 158 | def readSentence(input_file): 159 | in_lines = open(input_file, 'r').readlines() 160 | sentences = [] 161 | labels = [] 162 | sentence = [] 163 | label = [] 164 | for line in in_lines: 165 | if len(line) < 2: 166 | sentences.append(sentence) 167 | labels.append(label) 168 | sentence = [] 169 | label = [] 170 | else: 171 | pair = line.strip('\n').split(' ') 172 | sentence.append(pair[0]) 173 | label.append(pair[-1]) 174 | return sentences, labels 175 | 176 | 177 | def readTwoLabelSentence(input_file, pred_col=-1): 178 | in_lines = open(input_file, 'r').readlines() 179 | sentences = [] 180 | predict_labels = [] 181 | golden_labels = [] 182 | sentence = [] 183 | predict_label = [] 184 | golden_label = [] 185 | for line in in_lines: 186 | if "##score##" in line: 187 | continue 188 | if len(line) < 2: 189 | sentences.append(sentence) 190 | golden_labels.append(golden_label) 191 | predict_labels.append(predict_label) 192 | sentence = [] 193 | golden_label = [] 194 | predict_label = [] 195 | else: 196 | pair = line.strip('\n').split(' ') 197 | sentence.append(pair[0]) 198 | golden_label.append(pair[1]) 199 | predict_label.append(pair[pred_col]) 200 | 201 | return sentences, golden_labels, predict_labels 202 | 203 | 204 | def fmeasure_from_file(golden_file, predict_file, label_type="BMES"): 205 | print("Get f measure from file:", golden_file, predict_file) 206 | print("Label format:", label_type) 207 | golden_sent, golden_labels = readSentence(golden_file) 208 | predict_sent, predict_labels = readSentence(predict_file) 209 | acc, P, R, F = get_ner_fmeasure(golden_labels, predict_labels, label_type) 210 | print("Acc:%s, P:%s R:%s, F:%s" % (acc, P, R, F)) 211 | 212 | 213 | 
def fmeasure_from_singlefile(twolabel_file, label_type="BMES", pred_col=-1): 214 | sent, golden_labels, predict_labels = readTwoLabelSentence(twolabel_file, pred_col) 215 | acc, P, R, F = get_ner_fmeasure(golden_labels, predict_labels, label_type) 216 | print("Acc:%s, P:%s, R:%s, F:%s" % (acc, P, R, F)) 217 | 218 | 219 | if __name__ == '__main__': 220 | # print "sys:",len(sys.argv) 221 | if len(sys.argv) == 3: 222 | fmeasure_from_singlefile(sys.argv[1], "BMES", int(sys.argv[2])) 223 | else: 224 | fmeasure_from_singlefile(sys.argv[1], "BMES") 225 | 226 | -------------------------------------------------------------------------------- /utils/trie.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | 4 | class TrieNode: 5 | # Initialize your data structure here. 6 | def __init__(self): 7 | self.children = collections.defaultdict(TrieNode) 8 | self.is_word = False 9 | 10 | 11 | class Trie: 12 | def __init__(self): 13 | self.root = TrieNode() 14 | 15 | def insert(self, word): 16 | 17 | current = self.root 18 | for letter in word: 19 | current = current.children[letter] 20 | current.is_word = True 21 | 22 | def search(self, word): 23 | current = self.root 24 | for letter in word: 25 | current = current.children.get(letter) 26 | 27 | if current is None: 28 | return False 29 | return current.is_word 30 | 31 | def startsWith(self, prefix): 32 | current = self.root 33 | for letter in prefix: 34 | current = current.children.get(letter) 35 | if current is None: 36 | return False 37 | return True 38 | 39 | def enumerateMatch(self, word, space="_", backward=False): 40 | matched = [] 41 | ## `while len(word) > 1` does not keep the single character itself as a match; `while word` would keep it 42 | while len(word) > 1: 43 | if self.search(word): 44 | matched.append(space.join(word[:])) 45 | del word[-1] 46 | return matched 47 | 48 | --------------------------------------------------------------------------------
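For reference, the evaluation entry point `get_ner_fmeasure` in `utils/metric.py` can also be called directly on sentence-level label lists. Below is a minimal sketch with made-up BIO labels; it assumes the repository root is on `PYTHONPATH`, and the numbers in the comments follow from this toy input, not from any dataset.

```python
from utils.metric import get_ner_fmeasure

# One toy sentence in the BIO scheme: gold has a PER span (0-1) and an ORG span (3-4);
# the prediction truncates the ORG span, so only the PER entity counts as correct.
gold = [["B-PER", "I-PER", "O", "B-ORG", "I-ORG", "O"]]
pred = [["B-PER", "I-PER", "O", "B-ORG", "O",     "O"]]

acc, p, r, f = get_ner_fmeasure(gold, pred, label_type="BIO")
print(acc, p, r, f)   # token accuracy 5/6, entity-level P = R = F1 = 0.5
```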