├── .gitignore ├── KB ├── kb_processing.ipynb └── kb_processing_modify.ipynb ├── NER ├── bert.py ├── bert_main.py ├── ccks_bert.cfg ├── ccks_run.sh ├── crf.py ├── data.py ├── data │ ├── questions_ws.txt │ └── train_sp_embed ├── eval.py ├── word_conll ├── ws.py ├── ws_old.py └── 实体识别的优化与实体链接.ipynb ├── PathRanking ├── model │ ├── args.py │ ├── bert_function.py │ ├── data.py │ ├── field.py │ ├── main.py │ ├── model.py │ ├── process_test.py │ └── train.sh ├── predict.py ├── predict_stage1.sh ├── predict_stage2.sh ├── search_path_stage2.sh └── utils │ ├── ans_tools.py │ └── search_ans.py ├── PreScreen ├── data │ ├── ans_tools.py │ ├── count.py │ ├── merge_path.py │ ├── mix_paths.py │ ├── onehop_path.py │ ├── search_ans.py │ └── search_ans.sh ├── modules │ ├── charlstm.py │ └── model.py ├── preprocess │ ├── cand_relations_50000.txt │ ├── check.py │ ├── data.ipynb │ └── func.py └── utils │ └── corpus.py ├── Preprocess.ipynb ├── Question_classification ├── BERT_LSTM_char │ ├── bert.py │ └── main.py ├── BERT_LSTM_word │ ├── args.py │ ├── bert.py │ ├── compare.py │ ├── main.py │ ├── run.sh │ └── test_acc.py └── data │ └── convert_data.py ├── README.md ├── evaluation_answer.ipynb ├── question_classes.png ├── results.png ├── test.json ├── train.json ├── utils.py └── valid.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /NER/bert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pytorch_pretrained_bert.modeling import BertModel 4 | from pytorch_pretrained_bert.tokenization import BertTokenizer 5 | from torch.nn.utils.rnn import pad_sequence 6 | from crf import CRF 7 | import unicodedata 8 | from torch.nn.utils.rnn import pack_padded_sequence as pack 9 | from torch.nn.utils.rnn import pad_packed_sequence as pad 10 | 11 | def _is_whitespace(char): 12 | """Checks whether `chars` is a whitespace character.""" 13 | # \t, \n, and \r are technically contorl characters but we treat them 14 | # as whitespace since they are generally considered as such. 15 | if char == " " or char == "\t" or char == "\n" or char == "\r": 16 | return True 17 | cat = unicodedata.category(char) 18 | if cat == "Zs": 19 | return True 20 | return False 21 | 22 | 23 | def _is_control(char): 24 | """Checks whether `chars` is a control character.""" 25 | # These are technically control characters but we count them as whitespace 26 | # characters. 27 | if char == "\t" or char == "\n" or char == "\r": 28 | return False 29 | cat = unicodedata.category(char) 30 | if cat.startswith("C"): 31 | return True 32 | return False 33 | 34 | 35 | def _is_punctuation(char): 36 | """Checks whether `chars` is a punctuation character.""" 37 | cp = ord(char) 38 | # We treat all non-letter/number ASCII as punctuation. 39 | # Characters such as "^", "$", and "`" are not in the Unicode 40 | # Punctuation class but we treat them as punctuation anyways, for 41 | # consistency. 
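# The ranges below cover the four ASCII punctuation blocks:
# 33-47 is !"#$%&'()*+,-./, 58-64 is :;<=>?@, 91-96 is [\]^_`, and 123-126 is {|}~.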
42 | if ( 43 | (cp >= 33 and cp <= 47) 44 | or (cp >= 58 and cp <= 64) 45 | or (cp >= 91 and cp <= 96) 46 | or (cp >= 123 and cp <= 126) 47 | ): 48 | return True 49 | cat = unicodedata.category(char) 50 | if cat.startswith("P"): 51 | return True 52 | return False 53 | 54 | 55 | def _clean_text(text): 56 | output = [] 57 | for char in text: 58 | cp = ord(char) 59 | if cp == 0 or cp == 0xFFFD or _is_control(char): 60 | continue 61 | if _is_whitespace(char): 62 | output.append(" ") 63 | else: 64 | output.append(char) 65 | return "".join(output) 66 | 67 | 68 | def judge_ignore(word): 69 | if len(_clean_text(word)) == 0: 70 | return True 71 | for char in word: 72 | cp = ord(char) 73 | if cp == 0 or cp == 0xFFFD or _is_control(char): 74 | return True 75 | return False 76 | 77 | def flatten(list_of_lists): 78 | for list in list_of_lists: 79 | for item in list: 80 | yield item 81 | 82 | class Vocab(object): 83 | def __init__(self, bert_vocab_path): 84 | self.tokenizer = BertTokenizer.from_pretrained( 85 | bert_vocab_path, do_lower_case=False 86 | ) 87 | 88 | def convert_tokens_to_ids(self, tokens): 89 | token_ids = self.tokenizer.convert_tokens_to_ids(tokens) 90 | ids = torch.tensor(token_ids, dtype=torch.long) 91 | mask = torch.ones(len(ids), dtype=torch.long) 92 | return ids, mask 93 | 94 | def subword_tokenize(self, tokens): 95 | subwords = list(map(self.tokenizer.tokenize, tokens)) 96 | subword_lengths = [1] + list(map(len, subwords)) + [1] 97 | subwords = ["[CLS]"] + list(flatten(subwords)) + ["[SEP]"] 98 | token_start_idxs = torch.cumsum(torch.tensor([0] + subword_lengths[:-1]), dim=0) 99 | return subwords, token_start_idxs 100 | 101 | def subword_tokenize_to_ids(self, tokens): 102 | tokens = ["[PAD]" if judge_ignore(t) else t for t in tokens] 103 | subwords, token_start_idxs = self.subword_tokenize(tokens) 104 | subword_ids, mask = self.convert_tokens_to_ids(subwords) 105 | token_starts = torch.zeros(len(subword_ids), dtype=torch.uint8) 106 | token_starts[token_start_idxs] = 1 107 | return subword_ids, mask, token_starts 108 | 109 | def tokenize(self, tokens): 110 | subwords = list(map(self.tokenizer.tokenize, tokens)) 111 | subword_lengths = [1] + list(map(len, subwords)) + [1] 112 | subwords = ["[CLS]"] + list(flatten(subwords)) + ["[SEP]"] 113 | return subwords 114 | 115 | class Bert_Tagger(nn.Module): 116 | def __init__(self,data): 117 | super(Bert_Tagger,self).__init__() 118 | self.bert = BertModel.from_pretrained(data.bert_path) 119 | self.args = data 120 | self.lstm = nn.LSTM(data.bert_embedding_size, data.hidden_dim, num_layers=data.lstm_layer,batch_first=True, bidirectional=data.bilstm) 121 | if data.use_crf: 122 | self.linear = nn.Linear(data.bert_embedding_size,data.label_alphabet_size+2) 123 | self.crf = CRF(data.label_alphabet_size, True) 124 | else: 125 | #self.linear = nn.Linear(data.bert_embedding_size, data.label_alphabet_size) 126 | if data.bilstm: 127 | self.linear = nn.Linear(data.hidden_dim*2, data.label_alphabet_size) 128 | else: 129 | self.linear = nn.Linear(data.hidden_dim,data.label_alphabet_size) 130 | self.dropout = nn.Dropout(data.dropout) 131 | 132 | def forward(self,subword_idxs,subword_masks,token_start,batch_label): 133 | #self.args.use_crf = True 134 | bert_outs, _ = self.bert( 135 | subword_idxs, 136 | token_type_ids=None, 137 | attention_mask=subword_masks, 138 | output_all_encoded_layers=False, 139 | ) 140 | lens = token_start.sum(dim=1) 141 | bert_outs = torch.split(bert_outs[token_start], lens.tolist()) 142 | bert_outs = pad_sequence(bert_outs, 
batch_first=True) 143 | max_len = bert_outs.size(1) 144 | mask = torch.arange(max_len).cuda() < lens.unsqueeze(-1) 145 | 146 | 147 | # add lstm after bert 148 | sorted_lens, sorted_idx = torch.sort(lens, dim=0, descending=True) 149 | reverse_idx = torch.sort(sorted_idx, dim=0)[1] 150 | bert_outs = bert_outs[sorted_idx] 151 | bert_outs = pack(bert_outs, sorted_lens, batch_first=True) 152 | bert_outs, hidden = self.lstm(bert_outs) 153 | bert_outs, _ = pad(bert_outs, batch_first=True) 154 | bert_outs = bert_outs[reverse_idx] 155 | 156 | 157 | out = self.linear(torch.tanh(bert_outs)) 158 | if self.args.use_crf: 159 | score, seq = self.crf.viterbi_decode(out, mask) 160 | else: 161 | batch_size = out.size(0) 162 | seq_len = out.size(1) 163 | out = out.view(-1, out.size(2)) 164 | _, seq = torch.max(out, 1) 165 | seq = seq.view(batch_size, seq_len) 166 | seq = mask.long() * seq 167 | return seq 168 | 169 | def neg_log_likehood(self,subword_idxs,subword_masks,token_start,batch_label): 170 | #self.args.use_crf = False 171 | bert_outs, _ = self.bert( 172 | subword_idxs, 173 | token_type_ids=None, 174 | attention_mask=subword_masks, 175 | output_all_encoded_layers=False, 176 | ) 177 | lens = token_start.sum(dim=1) 178 | 179 | #x = bert_outs[token_start] 180 | bert_outs = torch.split(bert_outs[token_start], lens.tolist()) 181 | bert_outs = pad_sequence(bert_outs, batch_first=True) 182 | max_len = bert_outs.size(1) 183 | mask = torch.arange(max_len).cuda() < lens.unsqueeze(-1) 184 | 185 | 186 | # add lstm after bert 187 | sorted_lens, sorted_idx = torch.sort(lens, dim=0, descending=True) 188 | reverse_idx = torch.sort(sorted_idx, dim=0)[1] 189 | bert_outs = bert_outs[sorted_idx] 190 | bert_outs = pack(bert_outs, sorted_lens, batch_first=True) 191 | bert_outs, hidden = self.lstm(bert_outs) 192 | bert_outs, _ = pad(bert_outs, batch_first=True) 193 | bert_outs = bert_outs[reverse_idx] 194 | 195 | 196 | out = self.linear(bert_outs) 197 | #out = self.dropout(out) 198 | if self.args.use_crf: 199 | loss = self.crf(out, mask, batch_label) 200 | score, seq = self.crf.viterbi_decode(out, mask) 201 | else: 202 | batch_size = out.size(0) 203 | seq_len = out.size(1) 204 | out = out.view(-1, out.size(2)) 205 | score = torch.nn.functional.log_softmax(out, 1) 206 | loss_function = nn.NLLLoss(ignore_index=0, reduction="sum") 207 | loss = loss_function(score, batch_label.view(-1)) 208 | _, seq = torch.max(score, 1) 209 | seq = seq.view(batch_size, seq_len) 210 | if self.args.average_loss: 211 | loss = loss / mask.float().sum() 212 | return loss,seq 213 | 214 | def extract_feature(self,subword_idxs,subword_masks,token_start,batch_label,layers): 215 | out_layers,outs = [],[] 216 | bert_outs, _ = self.bert( 217 | subword_idxs, 218 | token_type_ids=None, 219 | attention_mask=subword_masks, 220 | output_all_encoded_layers=True, 221 | ) 222 | lens = token_start.sum(dim=1) 223 | #x = bert_outs[token_start] 224 | #bert_outs = torch.split(bert_outs[token_start].cpu(), lens.tolist()) 225 | for layer in layers: 226 | out_layers.append(torch.split(bert_outs[layer][token_start].cpu(), lens.tolist())) 227 | batch_size = subword_idxs.size(0) 228 | for idx in range(batch_size): 229 | items = [] 230 | for idy,item in enumerate(out_layers): 231 | items.append(item[idx].unsqueeze(1)) 232 | outs.append(torch.cat(items,dim=1)) 233 | return outs 234 | -------------------------------------------------------------------------------- /NER/ccks_bert.cfg: -------------------------------------------------------------------------------- 1 | ###train 
config file 2 | #file path and dir 3 | train_file=data/train_bert_ner_input.txt 4 | dev_file=data/valid_bert_ner_input.txt 5 | #test_file=data/valid_bert_ner_input.txt 6 | test_file=data/test_bert_ner_input.txt 7 | #oov_file=../data/ali_7k/test_oov_pos_bi 8 | word_embed_path= 9 | word_embed_save=data/ccks/train_sp_embed 10 | char_embed_path= 11 | char_embed_save= 12 | bert_dim=768 13 | model_save_dir=snapshot/modelbest.pkl 14 | result_save_dir=result_shot/ 15 | model_path=snapshot/modelbest.pkl 16 | 17 | #hyperparameters 18 | use_char=False 19 | use_cuda=True 20 | use_crf=False 21 | use_elmo=False 22 | use_bert=True 23 | pretrain=True 24 | word_embed_dim=100 25 | char_embed_dim=30 26 | optimizer=BertAdam 27 | hidden_dim=150 28 | fine_tune=True 29 | elmo_fine_tune=True 30 | char_hidden_dim=150 31 | lstm_layer=1 32 | bilstm=True 33 | cnn_layer=1 34 | dropout=0.5 35 | lr=1e-5 36 | lr_decay=0.05 37 | momentum=0 38 | weight_decay=0 39 | iter=50 40 | batch_size=10 41 | attention=False 42 | lstm_attention=False 43 | attention_dim=300 44 | average_loss=True 45 | norm_word_emb=False 46 | norm_char_emb=False 47 | tag_scheme=BIOES 48 | number_normalized=True 49 | threshold=0 50 | max_sent_len=250 51 | entity_mask=False 52 | mask_percent=0.1 53 | stopwords= 54 | #feature=POS embed_dim=20 55 | hyperlstm=False 56 | hyper_hidden_dim=50 57 | hyper_emb_dim=512 58 | 59 | #other_config 60 | dataset=conll2003 61 | status=tag 62 | #status=train 63 | word_seq_feature=LSTM 64 | char_seq_feature=LSTM 65 | 66 | #bert 67 | bert_path=../data/bert-base-chinese.tar.gz 68 | bert_embedding_size=768 69 | bert_vocab_path=../data/bert-base-chinese-vocab.txt -------------------------------------------------------------------------------- /NER/ccks_run.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=1 2 | export CUDA_VISIBLE_DEVICES=1 3 | 4 | nohup python -u bert_main.py --config ccks_bert.cfg >ccks_log/log_1 2>&1 & 5 | tail -f ccks_log/log_1 6 | 7 | -------------------------------------------------------------------------------- /NER/crf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.autograd as autograd 4 | import torch.nn.functional as F 5 | 6 | START_TAG = -2 7 | STOP_TAG = -1 8 | 9 | 10 | class CRF(nn.Module): 11 | def __init__(self, target_size, gpu): 12 | super(CRF, self).__init__() 13 | self.target_size = target_size 14 | self.gpu = gpu 15 | transition_mat = torch.zeros(self.target_size + 2, self.target_size + 2) 16 | transition_mat[:, START_TAG] = -10000.0 17 | transition_mat[STOP_TAG, :] = -10000.0 18 | transition_mat[:, 0] = -10000.0 # pad index 19 | transition_mat[0, :] = -10000.0 20 | if self.gpu: 21 | transition_mat = transition_mat.cuda() 22 | self.transitions = nn.Parameter(transition_mat) 23 | 24 | def forward(self, input, mask, tags): 25 | forward_scores, score = self.cal_forward_score(input, mask) 26 | gold_score = self.cal_gold_score(score, mask, tags) 27 | loss = forward_scores - gold_score 28 | return loss 29 | 30 | def cal_forward_score(self, input, mask): 31 | batch_size = input.size(0) 32 | seq_len = input.size(1) 33 | tag_size = input.size(2) 34 | assert (tag_size == self.target_size+2) 35 | mask = mask.transpose(1, 0).contiguous() 36 | ins_num = seq_len * batch_size 37 | inputs = input.transpose(1, 0).contiguous().view(ins_num, 1, 38 | tag_size).expand(ins_num, tag_size, tag_size) 39 | # emit score + T 40 | scores = inputs + 
self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size) 41 | scores = scores.view(seq_len, batch_size, tag_size, tag_size) 42 | seq_iter = enumerate(scores) 43 | _, inivalues = next(seq_iter) 44 | # from start tag to word 45 | partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) 46 | 47 | for idx, cur_values in seq_iter: 48 | cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, 49 | tag_size) 50 | cur_partition = log_sum_up(cur_values, tag_size) 51 | mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size) 52 | masked_cur_partition = cur_partition.masked_select(mask_idx) 53 | mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1) 54 | partition.masked_scatter_(mask_idx, masked_cur_partition) 55 | 56 | # from last word to stop tag 57 | cur_values = self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, 58 | tag_size) + partition.contiguous().view( 59 | batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) 60 | cur_partition = log_sum_up(cur_values, tag_size) 61 | final_partition = cur_partition[:, STOP_TAG] 62 | return final_partition.sum(), scores 63 | 64 | def cal_gold_score(self, scores, mask, tags): 65 | batch_size = scores.size(1) 66 | seq_len = scores.size(0) 67 | tag_size = scores.size(2) 68 | new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len)) 69 | if self.gpu: 70 | new_tags = new_tags.cuda() 71 | for idx in range(seq_len): 72 | if idx == 0: 73 | new_tags[:, 0] = (tag_size - 2) * tag_size + tags[:, 0] 74 | else: 75 | new_tags[:, idx] = tags[:, idx - 1] * tag_size + tags[:, idx] 76 | end_transition = self.transitions[:, STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size) 77 | length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long() 78 | # cal end word id,length-1 79 | end_ids = torch.gather(tags, 1, length_mask - 1) 80 | end_energy = torch.gather(end_transition, 1, end_ids) 81 | new_tags = new_tags.transpose(1, 0).contiguous().view(seq_len, batch_size, 1) 82 | tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size) 83 | tg_energy = tg_energy.masked_select(mask.transpose(1, 0)) 84 | 85 | return tg_energy.sum() + end_energy.sum() 86 | 87 | def viterbi_decode(self, feats, mask): 88 | batch_size = feats.size(0) 89 | seq_len = feats.size(1) 90 | tag_size = feats.size(2) 91 | assert (tag_size == self.target_size + 2) 92 | length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long() 93 | mask = mask.transpose(1, 0).contiguous() 94 | ins_num = seq_len * batch_size 95 | feats = feats.transpose(1, 0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) 96 | scores = feats + self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size) 97 | scores = scores.view(seq_len, batch_size, tag_size, tag_size) 98 | seq_iter = enumerate(scores) 99 | back_points = list() 100 | partition_history = list() 101 | mask = (1 - mask.long()).byte() 102 | _, inivalues = next(seq_iter) 103 | partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size) 104 | partition_history.append(partition) 105 | for idx, cur_values in seq_iter: 106 | cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, 107 | tag_size) 108 | partition, cur_bp = torch.max(cur_values, 1) 109 | partition_history.append(partition) 110 | cur_bp.masked_fill_(mask[idx].view(batch_size, 
1).expand(batch_size, tag_size), 0) 111 | back_points.append(cur_bp) 112 | partition_history = torch.cat(partition_history, 0).view(seq_len, batch_size, -1).transpose(1, 0).contiguous() 113 | last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1 114 | last_partition = torch.gather(partition_history, 1, last_position).view(batch_size, tag_size, 1) 115 | last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1, tag_size, 116 | tag_size).expand( 117 | batch_size, tag_size, tag_size) 118 | _, last_bp = torch.max(last_values, 1) 119 | pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size)).long() 120 | if self.gpu: 121 | pad_zero = pad_zero.cuda() 122 | back_points.append(pad_zero) 123 | back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size) 124 | 125 | pointer = last_bp[:, STOP_TAG] 126 | insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size) 127 | back_points = back_points.transpose(1, 0).contiguous() 128 | back_points.scatter_(1, last_position, insert_last) 129 | back_points = back_points.transpose(1, 0).contiguous() 130 | decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size)) 131 | if self.gpu: 132 | decode_idx = decode_idx.cuda() 133 | decode_idx[-1] = pointer.data 134 | for idx in range(len(back_points) - 2, -1, -1): 135 | pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) 136 | decode_idx[idx] = pointer.detach().view(batch_size) 137 | path_score = None 138 | decode_idx = decode_idx.transpose(1, 0) 139 | return path_score, decode_idx 140 | 141 | 142 | def log_sum_up(vec, m_size): 143 | max_score, idx = torch.max(vec, 1) 144 | max_score_broadcast = max_score.unsqueeze(1).view(-1, 1, m_size).expand(-1, m_size, m_size) 145 | return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast), 1)) 146 | -------------------------------------------------------------------------------- /NER/data/train_sp_embed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThisIsSoMe/CCKS2019-CKBQA/142169a07b285d147cb70c2d118b2d7df3cd7836/NER/data/train_sp_embed -------------------------------------------------------------------------------- /NER/eval.py: -------------------------------------------------------------------------------- 1 | def seq_eval(data, pred, gold, mask, recover): 2 | pred_list = [] 3 | gold_list = [] 4 | pred = pred[recover] 5 | gold = gold[recover] 6 | mask = mask[recover] 7 | batch_size = gold.size(0) 8 | seq_len = gold.size(1) 9 | pred_tag = pred.cpu().data.numpy() 10 | gold_tag = gold.cpu().data.numpy() 11 | mask = mask.cpu().data.numpy() 12 | for idx in range(batch_size): 13 | pred = [data.label_alphabet.get_instance(pred_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0] 14 | gold = [data.label_alphabet.get_instance(gold_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0] 15 | pred_list.append(pred) 16 | gold_list.append(gold) 17 | return gold_list, pred_list 18 | 19 | def bert_eval(data, pred, gold, mask): 20 | pred_list = [] 21 | gold_list = [] 22 | batch_size = gold.size(0) 23 | seq_len = gold.size(1) 24 | pred_tag = pred.cpu().data.numpy() 25 | gold_tag = gold.cpu().data.numpy() 26 | mask = mask.cpu().data.numpy() 27 | for idx in range(batch_size): 28 | pred = [data.label_alphabet.get_instance(pred_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0] 29 | gold = 
[data.label_alphabet.get_instance(gold_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0] 30 | pred_list.append(pred) 31 | gold_list.append(gold) 32 | return gold_list, pred_list 33 | 34 | 35 | def get_ner_measure(pred, gold, scheme): 36 | sen_num = len(pred) 37 | predict_num = correct_num = gold_num = 0 38 | for idx in range(sen_num): 39 | if scheme == "BIO": 40 | gold_entity = get_entity(gold[idx]) 41 | pred_entity = get_entity(pred[idx]) 42 | predict_num += len(pred_entity) 43 | gold_num += len(gold_entity) 44 | correct_num += len(list(set(gold_entity).intersection(set(pred_entity)))) 45 | elif scheme == "BIOES": # "BMES" 46 | gold_entity = get_BIOES_entity(gold[idx]) 47 | pred_entity = get_BIOES_entity(pred[idx]) 48 | predict_num += len(pred_entity) 49 | gold_num += len(gold_entity) 50 | correct_num += len(list(set(gold_entity).intersection(set(pred_entity)))) 51 | else: 52 | raise RuntimeError("Scheme Error") 53 | return predict_num, correct_num, gold_num 54 | 55 | def get_BIOES_entity(label_list): 56 | sen_len = len(label_list) 57 | entity_list = [] 58 | entity = None 59 | entity_index = None 60 | for idx, current in enumerate(label_list): 61 | if "B-" in current: 62 | if entity is not None: 63 | entity_list.append("[" + entity_index + "]" + entity) 64 | entity = current.split("-")[1] 65 | entity_index = str(idx) 66 | elif "S-" in current: 67 | if entity is not None: 68 | entity_list.append("[" + entity_index + "]" + entity) 69 | entity = current.split("-")[1] 70 | entity_index = str(idx) 71 | elif "I-" in current or "E-" in current: 72 | if entity is not None: 73 | entity_index += str(idx) 74 | else: 75 | # print('single I start') 76 | continue 77 | entity = current.split("-")[1] 78 | entity_index = str(idx) 79 | elif "O" in current: 80 | if entity is not None: 81 | entity_list.append("[" + entity_index + "]" + entity) 82 | entity = None 83 | entity_index = None 84 | else: 85 | print("Label Error. 
current:{}".format(current)) 86 | if entity is not None: 87 | entity_list.append("[" + entity_index + "]" + entity) 88 | return entity_list 89 | 90 | def get_entity(label_list): 91 | entity_list = [] 92 | entity = None 93 | entity_index = None 94 | for idx, current in enumerate(label_list): 95 | if "B-" in current: 96 | if entity is not None: 97 | entity_list.append("[" + entity_index + "]" + entity) 98 | entity = current.split("-")[1] 99 | entity_index = str(idx) 100 | elif "I-" in current: 101 | if entity is not None: 102 | entity_index += str(idx) 103 | else: 104 | #print('single I start') 105 | continue 106 | entity = current.split("-")[1] 107 | entity_index = str(idx) 108 | else: 109 | # if current != 'O': 110 | # print(current) 111 | if entity is not None: 112 | entity_list.append("[" + entity_index + "]" + entity) 113 | entity = None 114 | entity_index = None 115 | if entity is not None: 116 | entity_list.append("[" + entity_index + "]" + entity) 117 | return entity_list 118 | 119 | 120 | def output_result(texts,pred_list,result_dir,info): 121 | with open(result_dir+'result_'+info,'w',encoding='utf-8') as fout: 122 | for idx,text in enumerate(texts): 123 | for idy,t in enumerate(text[0]): 124 | fout.write(t+'\t'+pred_list[idx][idy]+'\n') 125 | fout.write('\n') 126 | -------------------------------------------------------------------------------- /NER/ws.py: -------------------------------------------------------------------------------- 1 | import pymysql 2 | import pandas as pd 3 | import gc 4 | import jieba 5 | import re 6 | from sqlalchemy import create_engine 7 | import json 8 | 9 | read_con = pymysql.connect(host="192.168.126.143",port = 3337, user='root', password='pjzhang', database='ccks_2019',charset='utf8') 10 | save_con = create_engine('mysql+pymysql://root:pjzhang@192.168.126.143:3337/ccks_2019?charset=utf8') 11 | cur = read_con.cursor() 12 | 13 | END = '\n' 14 | 15 | # 去除‘的’字 16 | def delete_de(line): 17 | cut_line = jieba.lcut(line) 18 | for i in range(0, len(cut_line)): 19 | if(cut_line[i] == '的'): 20 | cut_line[i] = '\n' 21 | return ''.join(cut_line) 22 | 23 | # 去除标点,用‘\n’代替 24 | def delete_punc(line): 25 | # line = delete_de(line) 26 | punc = ['《', '》', '\"', '\'', '<', '>', '?', '?', ',', ',', ':'] 27 | # punc = [] 28 | res = '' 29 | for ch in line: 30 | if(ch not in punc): 31 | res += ch 32 | else: 33 | res += END 34 | return res 35 | 36 | # 根据三元组表pkubase查找实体,逆向最大匹配算法 37 | def search_entity_pkubase_backward(line, max_string = True): 38 | end = len(line)+1 39 | entity = [] 40 | new_line = '' 41 | while end > 0: 42 | begin = 0 43 | while (end > begin): 44 | # print('begin', begin, end = ' ') 45 | word = line[begin: end] 46 | # print('w', word) 47 | sql = "select count(*) from `pkubase` where `entry`='%s' or `entry`='<%s>' or `entry`='\"%s\"'" % (word, word, word) 48 | cur.execute(sql) 49 | data = cur.fetchall() 50 | begin += 1 51 | if (data[0][0] > 0): 52 | # print(word) 53 | entity.append(word) 54 | # print('data', data[0:3]) 55 | if (max_string): # 是否最大匹配 56 | end = begin 57 | break 58 | sql = "select count(*) from `pkubase` where `value`='%s' or `value`='<%s>' or `value`='\"%s\"'" % (word, word, word) 59 | cur.execute(sql) 60 | data = cur.fetchall() 61 | if (data[0][0] > 0): 62 | # print(word) 63 | entity.append(word) 64 | # print('data', data[0:3]) 65 | if (max_string): # 是否最大匹配 66 | end = begin 67 | break 68 | sql = "select count(*) from `pkubase` where `prop`='%s' or `prop`='<%s>' or `prop`='\"%s\"'" % (word, word, word) 69 | cur.execute(sql) 70 | data = 
cur.fetchall() 71 | if(data[0][0] > 0): 72 | entity.append(word) 73 | if (max_string): # 是否最大匹配 74 | end = begin 75 | break 76 | sql = "select count(*) from `pkuorder` where `entry`='%s' or `entry`='<%s>' or `entry`='\"%s\"'" % (word, word, word) 77 | cur.execute(sql) 78 | data = cur.fetchall() 79 | if (data[0][0] > 0): 80 | entity.append(word) 81 | if (max_string): # 是否最大匹配 82 | end = begin 83 | break 84 | end -= 1 85 | entity.reverse() 86 | return entity 87 | 88 | # 根据三元组表pkubase查找实体,正向最大匹配算法 89 | # def search_entity_pkubase_forward(line, max_string = True): 90 | # begin = 0 91 | # end = len(line)+1 92 | # entity = [] 93 | # while begin < len(line): 94 | # # print('begin', begin) 95 | # while(end >= begin): 96 | # word = line[begin: end] 97 | # # print('end', end, end = ' ') 98 | # sql = "select count(*) from `pkubase` where `entry`='%s' or `entry`='<%s>' or `entry`='\"%s\"'" % (word, word, word) 99 | # cur.execute(sql) 100 | # data = cur.fetchall() 101 | # end -= 1 102 | # if(data[0][0] > 0): 103 | # # print('word', word) 104 | # entity.append(word) 105 | # if (max_string): # 是否最大匹配 106 | # begin = end 107 | # continue 108 | # sql = "select count(*) from `pkubase` where `value`='%s' or `value`='<%s>' or `value`='\"%s\"'" % (word, word, word) 109 | # cur.execute(sql) 110 | # data = cur.fetchall() 111 | # if (data[0][0] > 0): 112 | # # print('word', word) 113 | # entity.append(word) 114 | # if (max_string): # 是否最大匹配 115 | # begin = end 116 | # continue 117 | # sql = "select count(*) from `pkubase` where `prop`='%s' or `prop`='<%s>' or `prop`='\"%s\"'" % (word, word, word) 118 | # cur.execute(sql) 119 | # data = cur.fetchall() 120 | # if (data[0][0] > 0): 121 | # entity.append(word) 122 | # if (max_string): # 是否最大匹配 123 | # begin = end 124 | # continue 125 | # sql = "select count(*) from `pkuorder` where `entry`='%s' or `entry`='<%s>' or `entry`='\"%s\"'" % ( 126 | # word, word, word) 127 | # cur.execute(sql) 128 | # data = cur.fetchall() 129 | # if (data[0][0] > 0): 130 | # entity.append(word) 131 | # if (max_string): # 是否最大匹配 132 | # begin = end 133 | # continue 134 | # begin += 1 135 | # return entity 136 | 137 | def search_entity_pkubase_forward(line, max_string = True): 138 | begin = 0 139 | end = len(line)+1 140 | entity = [] 141 | while begin < len(line): 142 | # print('begin', begin) 143 | while(end > begin): 144 | word = line[begin: end] 145 | # print('end', end, end = ' ') 146 | sql = "select count(*) from `pkubase` where `entry`='%s' or `entry`='<%s>' or `entry`='\"%s\"'" % (word, word, word) 147 | cur.execute(sql) 148 | data = cur.fetchall() 149 | #end -= 1 150 | if(data[0][0] > 0): 151 | entity.append(word) 152 | if (max_string): # 是否最大匹配 153 | begin = end -1 154 | end = len(line)+1 155 | break 156 | sql = "select count(*) from `pkubase` where `value`='%s' or `value`='<%s>' or `value`='\"%s\"'" % (word, word, word) 157 | cur.execute(sql) 158 | data = cur.fetchall() 159 | if (data[0][0] > 0): 160 | entity.append(word) 161 | if (max_string): # 是否最大匹配 162 | begin = end -1 163 | end = len(line)+1 164 | break 165 | sql = "select count(*) from `pkubase` where `prop`='%s' or `prop`='<%s>' or `prop`='\"%s\"'" % (word, word, word) 166 | cur.execute(sql) 167 | data = cur.fetchall() 168 | if (data[0][0] > 0): 169 | entity.append(word) 170 | if (max_string): # 是否最大匹配 171 | begin = end-1 172 | end = len(line)+1 173 | break 174 | sql = "select count(*) from `pkuorder` where `entry`='%s' or `entry`='<%s>' or `entry`='\"%s\"'" % (word, word, word) 175 | cur.execute(sql) 176 | data = 
cur.fetchall() 177 | if (data[0][0] > 0): 178 | entity.append(word) 179 | if (max_string): # 是否最大匹配 180 | begin = end-1 181 | end = len(line)+1 182 | break 183 | end -= 1 184 | begin += 1 185 | return entity 186 | 187 | # 选择频率较高的分词 188 | def select_cut_word(entity1, entity2): 189 | num1 = 0 190 | num2 = 0 191 | if(len(entity1) * 2 < len(entity2)): 192 | return entity1 193 | elif len(entity1) > len(entity2) * 2: 194 | return entity2 195 | for word in entity1: 196 | if(len(word) > 1): 197 | sql = "select count(*) from `pkubase` where `entry` = '<%s>' or `entry` = '%s' or `entry` = '\"%s\"' or `value` = '%s' or `value` = '<%s>' or `value` = '\"%s\"'" % (word, word, word, word, word, word) 198 | cur.execute(sql) 199 | data = cur.fetchall() 200 | num1 += data[0][0] 201 | sql = "select count(*) from `pkuorder` where `entry` = '%s'" % (word) 202 | cur.execute(sql) 203 | data = cur.fetchall() 204 | num1 += data[0][0] * 10000 205 | for word in entity2: 206 | if(len(word) > 1): 207 | sql = "select count(*) from `pkubase` where `entry` = '<%s>' or `entry` = '%s' or `entry` = '\"%s\"' or `value` = '%s' or `value` = '<%s>' or `value` = '\"%s\"'" % (word, word, word, word, word, word) 208 | cur.execute(sql) 209 | data = cur.fetchall() 210 | num2 += data[0][0] 211 | sql = "select count(*) from `pkuorder` where `entry` = '%s'" % (word) 212 | cur.execute(sql) 213 | data = cur.fetchall() 214 | num2 += data[0][0] * 10000 215 | if (num1 * 1.0 / len(entity1) > num2 * 1.0 / len(entity2) * 2):#差距比较大时才会直接做出选择 216 | print(num1, num1 * 1.0 / len(entity1), num2, num2 * 1.0 / len(entity2), entity1) 217 | return entity1 218 | elif (num1 * 1.0 / len(entity1) * 2 < num2 * 1.0 / len(entity2)): 219 | print(num1, num1 * 1.0 / len(entity1), num2, num2 * 1.0 / len(entity2), entity2) 220 | return entity2 221 | if '的' in entity1 and '的' not in entity2: 222 | return entity1 223 | elif '的' not in entity1 and '的' in entity2: 224 | return entity2 225 | if(num1 * 1.0 / len(entity1) >= num2 * 1.0 / len(entity2)): 226 | # print(num1, num1 * 1.0 / len(entity1), num2, num2 * 1.0 / len(entity2), entity1) 227 | return entity1 228 | else: 229 | # print(num1, num1 * 1.0 / len(entity1), num2, num2 * 1.0 / len(entity2), entity2) 230 | return entity2 231 | 232 | # 选择和结巴分词相近的分词结果 233 | def select_cut_word_by_jieba(entity1, entity2): 234 | # if '的' in entity1 and '的' not in entity2: 235 | # return entity1 236 | # elif '的' not in entity1 and '的' in entity2: 237 | # return entity2 238 | sen = ''.join(entity1) 239 | seg_list = jieba.lcut(sen) 240 | num1 = 0 241 | num2 = 0 242 | for word in entity1: 243 | if(word in seg_list): 244 | num1 += 1 245 | for word in entity2: 246 | if(word in seg_list): 247 | num2 += 1 248 | if(num1 > num2): 249 | return entity1 250 | elif num1 == num2: 251 | return select_cut_word(entity1, entity2) 252 | else: 253 | return entity2 254 | # 根据jieba分词后的个数比较 255 | def select_cut_word_by_jieba_num(entity1, entity2): 256 | num1 = 0 257 | num2 = 0 258 | max1 = 0 259 | max2 = 0 260 | for word in entity1: 261 | num1 += len(jieba.lcut(word)) 262 | if(len(word) > max1): 263 | max1 = len(word) 264 | for word in entity2: 265 | num2 += len(jieba.lcut(word)) 266 | if (len(word) > max2): 267 | max2 = len(word) 268 | if(num1 > num2): 269 | return entity2 270 | elif num1 < num2: 271 | return entity1 272 | elif(max1 > max2): 273 | return entity1 274 | elif(max1 < max2): 275 | return entity2 276 | else: 277 | return select_cut_word_by_jieba(entity1, entity2) 278 | 279 | 280 | #将正向最大匹配算法和逆向最大匹配算法结合的方法,即相互补充的思想 281 | def 
combine_forward_backward(entity1, entity2): 282 | i = 0 283 | j = 0 284 | temp1 = '' 285 | temp1_list = [] 286 | temp2 = '' 287 | temp2_list = [] 288 | entity_final = [] 289 | while(i < len(entity1)): 290 | temp1 += entity1[i] 291 | temp1_list.append(entity1[i]) 292 | i += 1 293 | if(len(temp1) == len(temp2)): 294 | if ('\t'.join(temp1_list) == '\t'.join(temp2_list)): 295 | entity_final.extend(temp1_list) 296 | temp1 = '' 297 | temp2 = '' 298 | temp1_list = [] 299 | temp2_list = [] 300 | else: 301 | # print('选择频率高的') 302 | # entity_final.extend(select_cut_word(temp1_list, temp2_list)) 303 | entity_final.extend(select_cut_word_by_jieba_num(temp1_list, temp2_list)) 304 | # entity_final.extend(select_cut_word_by_jieba(temp1_list, temp2_list)) 305 | temp1 ='' 306 | temp2 ='' 307 | temp1_list = [] 308 | temp2_list = [] 309 | elif(len(temp1) < len(temp2)): 310 | continue 311 | else: 312 | while(j < len(entity2)): 313 | # print('j', j) 314 | temp2 += entity2[j] 315 | temp2_list.append(entity2[j]) 316 | j += 1 317 | if(len(temp1) == len(temp2)): 318 | if('\t'.join(temp1_list) == '\t'.join(temp2_list)): 319 | entity_final.extend(temp1_list) 320 | temp1 = '' 321 | temp2 = '' 322 | temp1_list = [] 323 | temp2_list = [] 324 | else: 325 | # print('选择频率高的') 326 | # entity_final.extend(select_cut_word(temp1_list, temp2_list)) 327 | entity_final.extend(select_cut_word_by_jieba_num(temp1_list, temp2_list)) 328 | # entity_final.extend(select_cut_word_by_jieba(temp1_list, temp2_list)) 329 | temp1 ='' 330 | temp2 = '' 331 | temp1_list = [] 332 | temp2_list = [] 333 | elif(len(temp1) > len(temp2)): 334 | continue 335 | else: 336 | break 337 | return entity_final 338 | 339 | def compare_count(word1, word2): 340 | sql = "select count(*) from `pkubase` where `prop`='%s' or `prop`='<%s>' or `prop`='\"%s\"'" % (word1, word1, word1) 341 | cur.execute(sql) 342 | data = cur.fetchall() 343 | count1 = data[0][0] 344 | sql = "select count(*) from `pkubase` where `prop`='%s' or `prop`='<%s>' or `prop`='\"%s\"'" % (word2, word2, word2) 345 | cur.execute(sql) 346 | data = cur.fetchall() 347 | count2 = data[0][0] 348 | if(count1 > count2): 349 | return 1 350 | elif(count1 == count2): 351 | return 0 352 | else: 353 | return -1 354 | 355 | def jieba_count(word): 356 | cut_word = jieba.lcut(word) 357 | if(len(cut_word) == 1): 358 | return 1 359 | return 0 360 | 361 | def del_repetition(relation): 362 | entity = [] 363 | for item in relation: 364 | if(item not in entity): 365 | entity.append(item) 366 | return entity 367 | 368 | # 去掉开头和结尾的'的' 369 | def del_de_extra(word): 370 | # print('word1', word) 371 | cut_word = jieba.lcut(word) 372 | while(len(cut_word) > 0): 373 | if(cut_word[0] == '的'): 374 | cut_word[0] = '\n' 375 | elif(cut_word[-1] == '的'): 376 | cut_word[-1] = '\n' 377 | else: 378 | break 379 | # print('cut_word', cut_word) 380 | new_word = '\n'.join(cut_word) 381 | # print('new_word', new_word) 382 | # word = new_word.replace('的\n', '\n') 383 | # word = word.strip('的') 384 | word = ''.join(cut_word) 385 | # word = word.replace('\n', '') 386 | # print('word2', word) 387 | return word 388 | 389 | # 去除开始和结尾处的标点 390 | def delete_punc_extra(line): 391 | punc = ['《', '》', '\"', '\'', '<', '>', '?', '?', ',', ',', ':'] 392 | # punc = [] 393 | for item in punc: 394 | line = line.strip(item) 395 | return line 396 | 397 | # 先预过一遍,得到初步结果,然后去除‘的’ 398 | def pre_cut(line_with_punc): 399 | entity1 = search_entity_pkubase1(line_with_punc) 400 | entity2 = search_entity_pkubase2(line_with_punc) 401 | relation = 
combine_forward_backward(entity1, entity2) 402 | i = 0 403 | punc = ['《', '》', '\"', '\'', '<', '>'] 404 | while i < len(relation): 405 | if(i + 1 < len(relation)): 406 | if(relation[i + 1] not in punc): 407 | relation[i] = del_de_extra(relation[i]) 408 | else: 409 | relation[i] = del_de_extra(relation[i]) 410 | i += 1 411 | return ''.join(relation) 412 | 413 | with open("../data/test.json",'r')as f: 414 | all_test_data = json.load(f) 415 | questions = all_test_data[1] 416 | mentions = all_test_data[6] 417 | 418 | # 测试集切分keywords 419 | i = 0 420 | entity_all = [] 421 | while (i < len(questions)): 422 | relation = [] 423 | 424 | line_with_punc = questions[i] 425 | #line_with_punc = pre_cut(line_with_punc) 426 | line = line_with_punc 427 | print('line:', i, line) 428 | entity2 = search_entity_pkubase_forward(line)#正向最大匹配分词 429 | #entity_all.append(entity2) 430 | entity1 = search_entity_pkubase_backward(line)#逆向最大匹配分词 431 | #entity_all.append(entity1) 432 | # print('entity1', entity1) 433 | # print('entity2', entity2) 434 | relation = combine_forward_backward(entity1, entity2)#将正向和逆向结合,得到相对更合适的分词结果 435 | entity = [] 436 | for k in range(len(relation)): 437 | temp = delete_punc_extra(relation[k]) 438 | if(temp not in entity and len(temp) > 1): 439 | entity.append(temp) 440 | print('keywords', entity2) 441 | entity_all.append(entity) 442 | i += 1 443 | 444 | 445 | fn_out = 'data/questions_ws.txt' 446 | fp_out = open(fn_out, 'w', encoding='utf-8') 447 | lines = questions 448 | i = 0 449 | while (i < len(lines)): 450 | line_with_punc = lines[i] 451 | fp_out.write(line_with_punc + '\n') 452 | fp_out.write('\t'.join(entity_all[i]) + '\n') 453 | i += 1 454 | fp_out.close() -------------------------------------------------------------------------------- /NER/ws_old.py: -------------------------------------------------------------------------------- 1 | import pymysql 2 | import pandas as pd 3 | import gc 4 | import jieba 5 | import json 6 | import re 7 | 8 | read_con = pymysql.connect(host="192.168.126.143",port = 3337, user='root', password='pjzhang', database='ccks_2019',charset='utf8') 9 | from sqlalchemy import create_engine 10 | save_con = create_engine('mysql+pymysql://root:pjzhang@192.168.126.143:3337/ccks_2019?charset=utf8') 11 | cur = read_con.cursor() 12 | 13 | END = '\n' 14 | 15 | # 根据三元组表pkubase查找实体,逆向最大匹配算法 16 | def search_entity_pkubase_backward(line, max_string = True): 17 | end = len(line) 18 | entity = [] 19 | new_line = '' 20 | while end > 0: 21 | begin = 0 22 | # print('end', end) 23 | while (end > begin): 24 | # print('begin', begin, end = ' ') 25 | word = line[begin: end] 26 | # print('w', word) 27 | sql = "select count(*) from `pkubase` where `entry`='%s' or `entry`='<%s>' or `entry`='\"%s\"'" % (word, word, word) 28 | cur.execute(sql) 29 | data = cur.fetchall() 30 | begin += 1 31 | if (data[0][0] > 0): 32 | # print(word) 33 | entity.append(word) 34 | # print('data', data[0:3]) 35 | if (max_string): # 是否最大匹配 36 | end = begin 37 | break 38 | sql = "select count(*) from `pkubase` where `value`='%s' or `value`='<%s>' or `value`='\"%s\"'" % (word, word, word) 39 | cur.execute(sql) 40 | data = cur.fetchall() 41 | if (data[0][0] > 0): 42 | # print(word) 43 | entity.append(word) 44 | # print('data', data[0:3]) 45 | if (max_string): # 是否最大匹配 46 | end = begin 47 | break 48 | sql = "select count(*) from `pkubase` where `prop`='%s' or `prop`='<%s>' or `prop`='\"%s\"'" % (word, word, word) 49 | cur.execute(sql) 50 | data = cur.fetchall() 51 | if(data[0][0] > 0): 52 | entity.append(word) 53 
| if (max_string): # 是否最大匹配 54 | end = begin 55 | break 56 | sql = "select count(*) from `pkuorder` where `entry`='%s' or `entry`='<%s>' or `entry`='\"%s\"'" % (word, word, word) 57 | cur.execute(sql) 58 | data = cur.fetchall() 59 | if (data[0][0] > 0): 60 | entity.append(word) 61 | if (max_string): # 是否最大匹配 62 | end = begin 63 | break 64 | end -= 1 65 | entity.reverse() 66 | # print('entity') 67 | # entity = delete_stopwords(entity) 68 | return entity 69 | 70 | # 根据三元组表pkubase查找实体,正向最大匹配算法 71 | def search_entity_pkubase_forward(line, max_string = True): 72 | begin = 0 73 | entity = [] 74 | while begin < len(line): 75 | end = len(line) 76 | # print('begin', begin) 77 | while(end > begin): 78 | word = line[begin: end] 79 | # print('end', end, end = ' ') 80 | sql = "select count(*) from `pkubase` where `entry`='%s' or `entry`='<%s>' or `entry`='\"%s\"'" % (word, word, word) 81 | cur.execute(sql) 82 | data = cur.fetchall() 83 | end -= 1 84 | if(data[0][0] > 0): 85 | # print('word', word) 86 | entity.append(word) 87 | if (max_string): # 是否最大匹配 88 | begin = end 89 | break 90 | sql = "select count(*) from `pkubase` where `value`='%s' or `value`='<%s>' or `value`='\"%s\"'" % (word, word, word) 91 | cur.execute(sql) 92 | data = cur.fetchall() 93 | if (data[0][0] > 0): 94 | # print('word', word) 95 | entity.append(word) 96 | if (max_string): # 是否最大匹配 97 | begin = end 98 | break 99 | sql = "select count(*) from `pkubase` where `prop`='%s' or `prop`='<%s>' or `prop`='\"%s\"'" % (word, word, word) 100 | cur.execute(sql) 101 | data = cur.fetchall() 102 | if (data[0][0] > 0): 103 | entity.append(word) 104 | if (max_string): # 是否最大匹配 105 | begin = end 106 | break 107 | sql = "select count(*) from `pkuorder` where `entry`='%s' or `entry`='<%s>' or `entry`='\"%s\"'" % ( 108 | word, word, word) 109 | cur.execute(sql) 110 | data = cur.fetchall() 111 | if (data[0][0] > 0): 112 | entity.append(word) 113 | if (max_string): # 是否最大匹配 114 | begin = end 115 | break 116 | # print('end', end) 117 | begin += 1 118 | # entity = delete_stopwords(entity) 119 | return entity 120 | 121 | # 根据三元组表pkubase查找实体,逆向最大匹配算法 122 | def search_entity_pkubase_all(line, max_string = False): 123 | end = len(line) 124 | entity = [] 125 | new_line = '' 126 | while end > 0: 127 | begin = 0 128 | # print('end', end) 129 | while (end > begin): 130 | # print('begin', begin, end = ' ') 131 | word = line[begin: end] 132 | # print('w', word) 133 | sql = "select count(*) from `pkubase` where `entry`='%s' or `entry`='<%s>' or `entry`='\"%s\"'" % (word, word, word) 134 | cur.execute(sql) 135 | data = cur.fetchall() 136 | begin += 1 137 | if (data[0][0] > 0): 138 | # print(word) 139 | entity.append(word) 140 | # print('data', data[0:3]) 141 | if (max_string): # 是否最大匹配 142 | end = begin 143 | break 144 | else: 145 | continue 146 | sql = "select count(*) from `pkubase` where `value`='%s' or `value`='<%s>' or `value`='\"%s\"'" % (word, word, word) 147 | cur.execute(sql) 148 | data = cur.fetchall() 149 | if (data[0][0] > 0): 150 | # print(word) 151 | entity.append(word) 152 | # print('data', data[0:3]) 153 | if (max_string): # 是否最大匹配 154 | end = begin 155 | break 156 | else: 157 | continue 158 | sql = "select count(*) from `pkubase` where `prop`='%s' or `prop`='<%s>' or `prop`='\"%s\"'" % (word, word, word) 159 | cur.execute(sql) 160 | data = cur.fetchall() 161 | if(data[0][0] > 10): 162 | entity.append(word) 163 | if (max_string): # 是否最大匹配 164 | end = begin 165 | break 166 | else: 167 | continue 168 | sql = "select count(*) from `pkuorder` where 
`entry`='%s' or `entry`='<%s>' or `entry`='\"%s\"'" % (word, word, word) 169 | cur.execute(sql) 170 | data = cur.fetchall() 171 | if (data[0][0] > 0): 172 | entity.append(word) 173 | if (max_string): # 是否最大匹配 174 | end = begin 175 | break 176 | else: 177 | continue 178 | end -= 1 179 | entity.reverse() 180 | return entity 181 | 182 | # 根据jieba分词后的个数比较 183 | def select_cut_word_by_jieba_num(entity1, entity2): 184 | num1 = 0 185 | num2 = 0 186 | max1 = 0 187 | max2 = 0 188 | for word in entity1: 189 | num1 += len(jieba.lcut(word)) 190 | if(len(word) > max1): 191 | max1 = len(word) 192 | for word in entity2: 193 | num2 += len(jieba.lcut(word)) 194 | if (len(word) > max2): 195 | max2 = len(word) 196 | if(num1 > num2): 197 | return entity2 198 | elif num1 < num2: 199 | return entity1 200 | elif(max1 >= max2): 201 | return entity1 202 | else: 203 | return entity2 204 | 205 | 206 | #将正向最大匹配算法和逆向最大匹配算法结合的方法,即相互补充的思想 207 | def combine_forward_backward(entity1, entity2): 208 | i = 0 209 | j = 0 210 | temp1 = '' 211 | temp1_list = [] 212 | temp2 = '' 213 | temp2_list = [] 214 | entity_final = [] 215 | while(i < len(entity1)): 216 | # print('i', i) 217 | temp1 += entity1[i] 218 | temp1_list.append(entity1[i]) 219 | i += 1 220 | if(len(temp1) == len(temp2)): 221 | if ('\t'.join(temp1_list) == '\t'.join(temp2_list)): 222 | entity_final.extend(temp1_list) 223 | temp1 = '' 224 | temp2 = '' 225 | temp1_list = [] 226 | temp2_list = [] 227 | else: 228 | # print('选择频率高的') 229 | # entity_final.extend(select_cut_word(temp1_list, temp2_list)) 230 | entity_final.extend(select_cut_word_by_jieba_num(temp1_list, temp2_list)) 231 | # entity_final.extend(select_cut_word_by_jieba(temp1_list, temp2_list)) 232 | temp1 ='' 233 | temp2 ='' 234 | temp1_list = [] 235 | temp2_list = [] 236 | elif(len(temp1) < len(temp2)): 237 | continue 238 | else: 239 | while(j < len(entity2)): 240 | # print('j', j) 241 | temp2 += entity2[j] 242 | temp2_list.append(entity2[j]) 243 | j += 1 244 | if(len(temp1) == len(temp2)): 245 | if('\t'.join(temp1_list) == '\t'.join(temp2_list)): 246 | entity_final.extend(temp1_list) 247 | temp1 = '' 248 | temp2 = '' 249 | temp1_list = [] 250 | temp2_list = [] 251 | else: 252 | # print('选择频率高的') 253 | # entity_final.extend(select_cut_word(temp1_list, temp2_list)) 254 | entity_final.extend(select_cut_word_by_jieba_num(temp1_list, temp2_list)) 255 | # entity_final.extend(select_cut_word_by_jieba(temp1_list, temp2_list)) 256 | temp1 ='' 257 | temp2 = '' 258 | temp1_list = [] 259 | temp2_list = [] 260 | elif(len(temp1) > len(temp2)): 261 | continue 262 | else: 263 | break 264 | return entity_final 265 | 266 | # 去掉开头和结尾的'的' 267 | def del_de_extra(word): 268 | # print('word1', word) 269 | cut_word = jieba.lcut(word) 270 | while(len(cut_word) > 0): 271 | if(cut_word[0] == '的'): 272 | cut_word[0] = '\n' 273 | elif(cut_word[-1] == '的'): 274 | cut_word[-1] = '\n' 275 | else: 276 | break 277 | # print('cut_word', cut_word) 278 | new_word = '\n'.join(cut_word) 279 | # print('new_word', new_word) 280 | # word = new_word.replace('的\n', '\n') 281 | # word = word.strip('的') 282 | word = ''.join(cut_word) 283 | # word = word.replace('\n', '') 284 | # print('word2', word) 285 | return word 286 | 287 | # 去除开始和结尾处的标点 288 | def delete_punc_extra(line): 289 | punc = ['《', '》', '\"', '\'', '<', '>', '?', '?', ',', ',', ':'] 290 | # punc = [] 291 | for item in punc: 292 | line = line.strip(item) 293 | return line 294 | # # 去除开始和结尾处的标点 295 | # def delete_punc_outside(line): 296 | # punc = ['《', '》', '\"', '\'', '<', '>', '?', '?', 
',', ',', ':'] 297 | # # punc = [] 298 | # for item in punc: 299 | # line = line.strip(item) 300 | # return line 301 | 302 | # 先预过一遍,得到初步结果,然后去除‘的’ 303 | def pre_cut(line_with_punc): 304 | entity1 = search_entity_pkubase1(line_with_punc) 305 | entity2 = search_entity_pkubase2(line_with_punc) 306 | relation = combine_forward_backward(entity1, entity2) 307 | # relation = entity1 308 | i = 0 309 | while i < len(relation): 310 | relation[i] = del_de_extra(relation[i]) 311 | # relation[i] = delete_punc_extra(relation[i]) 312 | i += 1 313 | return ''.join(relation) 314 | 315 | with open("../data/test.json",'r')as f: 316 | all_test_data = json.load(f) 317 | questions = all_test_data[1] 318 | mentions = all_test_data[6] 319 | 320 | # 测试集切分keywords 321 | i = 0 322 | entity_max_match = [] 323 | entity_all = [] 324 | while (i < len(questions)): 325 | relation = [] 326 | 327 | line_with_punc = questions[i] 328 | #line_with_punc = pre_cut(line_with_punc) 329 | line = line_with_punc 330 | print('line:', i, line) 331 | entity2 = search_entity_pkubase_forward(line)#正向最大匹配分词 332 | #entity_all.append(entity2) 333 | entity1 = search_entity_pkubase_backward(line)#逆向最大匹配分词 334 | #entity_all.append(entity1) 335 | # print('entity1', entity1) 336 | # print('entity2', entity2) 337 | relation = combine_forward_backward(entity1, entity2)#将正向和逆向结合,得到相对更合适的分词结果 338 | entity = [] 339 | for word in relation: 340 | if(word not in entity and len(word) > 1): 341 | entity.append(word) 342 | entity_max_match.append(entity) 343 | 344 | entitys = search_entity_pkubase_all(line_with_punc, False)#得到暴力搜索得到的分词结果 345 | entity = [] 346 | for word in entitys: 347 | if(word not in entity and len(word) > 1): 348 | entity.append(word) 349 | entity_all.append(entity) 350 | 351 | i += 1 352 | 353 | 354 | fn_out = 'data/questions_ws.txt' 355 | fp_out = open(fn_out, 'w', encoding='utf-8') 356 | lines = questions 357 | i = 0 358 | while (i < len(lines)): 359 | line_with_punc = lines[i] 360 | fp_out.write(line_with_punc + '\n') 361 | fp_out.write('\t'.join(entity_max_match[i]) + '\n') 362 | fp_out.write('\t'.join(entity_all[i]) + '\n') 363 | i += 1 364 | fp_out.close() -------------------------------------------------------------------------------- /PathRanking/model/args.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | def get_args(mode='train'): 5 | parser = ArgumentParser() 6 | 7 | ## Required parameters 8 | parser.add_argument("--bert_path", 9 | default='../../data/', 10 | type=str) 11 | parser.add_argument("--bert_model", 12 | default='../../data/bert-base-chinese.tar.gz', 13 | type=str, 14 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 15 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") 16 | parser.add_argument("--bert_vocab", 17 | default='../../data/bert-base-chinese-vocab.txt', 18 | type=str, 19 | help="Bert Vocabulary") 20 | parser.add_argument("--output_dir", 21 | default='saved', 22 | type=str, 23 | help="The output directory where the model predictions and checkpoints will be written.") 24 | parser.add_argument("--train_file", 25 | default='../data/train.json', 26 | type=str, 27 | help="The train data. Should contain the .json files (or other data files) for the task.") 28 | parser.add_argument("--valid_file", 29 | default='../data/valid.json', 30 | type=str, 31 | help="The valid data. 
Should contain the .json files (or other data files) for the task.") 32 | 33 | ## Other parameters 34 | parser.add_argument("--max_seq_length", 35 | default=128, 36 | type=int, 37 | help="The maximum total input sequence length after WordPiece tokenization. \n" 38 | "Sequences longer than this will be truncated, and sequences shorter \n" 39 | "than this will be padded.") 40 | parser.add_argument("--do_lower_case", 41 | default=False, 42 | action='store_true', 43 | help="Set this flag if you are using an uncased model.") 44 | parser.add_argument("--batch_size", 45 | default=16, 46 | type=int, 47 | help="Total batch size for training.") 48 | parser.add_argument("--eval_batch_size", 49 | default=8, 50 | type=int, 51 | help="Total batch size for eval.") 52 | parser.add_argument("--neg_size", 53 | default=5, 54 | type=int, 55 | help="Size of negative sample.") 56 | parser.add_argument("--neg_fix", 57 | default=False, 58 | action='store_true', 59 | help="Whether not to fix neg sample.") 60 | parser.add_argument("--learning_rate", 61 | default=1e-5, 62 | type=float, 63 | help="The initial learning rate for Adam.") 64 | parser.add_argument("--margin", 65 | default=0.1, 66 | type=float, 67 | help="Margin for margin ranking loss.") 68 | parser.add_argument("--num_train_epochs", 69 | default=100, 70 | type=int, 71 | help="Total number of training epochs to perform.") 72 | parser.add_argument("--patience", 73 | default=10, 74 | type=int, 75 | help="Stop training when nums of epochs not improving.") 76 | parser.add_argument("--warmup_proportion", 77 | default=0.1, 78 | type=float, 79 | help="Proportion of training to perform linear learning rate warmup for. " 80 | "E.g., 0.1 = 10%% of training.") 81 | parser.add_argument('--gradient_accumulation_steps', 82 | type=int, 83 | default=1, 84 | help="Number of updates steps to accumulate before performing a backward/update pass.") 85 | parser.add_argument("--no_cuda", 86 | default=False, 87 | action='store_true', 88 | help="Whether not to use CUDA when available") 89 | parser.add_argument("--gpu", 90 | type=str, 91 | default='3', 92 | help="use which gpu") 93 | parser.add_argument('--seed', 94 | type=int, 95 | default=42, 96 | help="random seed for initialization") 97 | parser.add_argument("--optimizer", 98 | type=str, 99 | default='Adam', 100 | help="choose optimizer") 101 | parser.add_argument("--model", 102 | default='bert_comparing', 103 | type=str, 104 | choices=['bert_comparing','bert_sharecomparing']) 105 | # model params 106 | parser.add_argument("--requires_grad", 107 | action='store_true', 108 | help="Whether not to fine tune Bert.") 109 | parser.add_argument("--maxpooling", 110 | action='store_true', 111 | help="Whether not to use maxpooling") 112 | parser.add_argument("--avepooling", 113 | action='store_true', 114 | help="Whether not to use avepooling") 115 | parser.add_argument("--bert_embedding_size", 116 | default=768, 117 | type=int) 118 | parser.add_argument("--hidden_dim", 119 | default=300, 120 | type=int) 121 | parser.add_argument("--syntax_dim", 122 | default=800, 123 | type=int, 124 | help="dim for hidden syntax embedding") 125 | parser.add_argument("--lstm_layer", 126 | default=1, 127 | type=int) 128 | parser.add_argument("--bilstm", 129 | action='store_false', 130 | help='whether to use bilstm') 131 | parser.add_argument("--len_syntax_dict", 132 | default=30, 133 | type=int, 134 | help="Num of syntax labels.") 135 | parser.add_argument("--dropout", 136 | default=0.5, 137 | type=float, 138 | help='dropout rate for drop out layer.') 139 
| if mode == 'predict': 140 | parser.add_argument("--model_path", 141 | default='saved2/pytorch_model.bin', 142 | type=str, 143 | help="the path of trained model!") 144 | parser.add_argument("--input_file", 145 | default='../cls_all_path/BERT_LSTM_maxpooling_embed/one_hop_cand_paths_ws_ent.json', 146 | type=str, 147 | help="the path of predict file!") 148 | parser.add_argument("--output_file", 149 | default='', 150 | type=str, 151 | help="the path of predict file!") 152 | parser.add_argument("--test_batch_size", 153 | default=8, 154 | type=int, 155 | help="batch size for test.") 156 | parser.add_argument("--topk", 157 | default=1, 158 | type=int, 159 | help="topk paths while inferring.") 160 | args = parser.parse_args() 161 | return args 162 | -------------------------------------------------------------------------------- /PathRanking/model/bert_function.py: -------------------------------------------------------------------------------- 1 | import unicodedata 2 | 3 | def _is_whitespace(char): 4 | """Checks whether `chars` is a whitespace character.""" 5 | # \t, \n, and \r are technically contorl characters but we treat them 6 | # as whitespace since they are generally considered as such. 7 | if char == " " or char == "\t" or char == "\n" or char == "\r": 8 | return True 9 | cat = unicodedata.category(char) 10 | if cat == "Zs": 11 | return True 12 | return False 13 | 14 | def _is_control(char): 15 | """Checks whether `chars` is a control character.""" 16 | # These are technically control characters but we count them as whitespace 17 | # characters. 18 | if char == "\t" or char == "\n" or char == "\r": 19 | return False 20 | cat = unicodedata.category(char) 21 | if cat.startswith("C"): 22 | return True 23 | return False 24 | 25 | def _is_punctuation(char): 26 | """Checks whether `chars` is a punctuation character.""" 27 | cp = ord(char) 28 | # We treat all non-letter/number ASCII as punctuation. 29 | # Characters such as "^", "$", and "`" are not in the Unicode 30 | # Punctuation class but we treat them as punctuation anyways, for 31 | # consistency. 
32 | if ( 33 | (cp >= 33 and cp <= 47) 34 | or (cp >= 58 and cp <= 64) 35 | or (cp >= 91 and cp <= 96) 36 | or (cp >= 123 and cp <= 126) 37 | ): 38 | return True 39 | cat = unicodedata.category(char) 40 | if cat.startswith("P"): 41 | return True 42 | return False 43 | 44 | def _clean_text(text): 45 | output = [] 46 | for char in text: 47 | cp = ord(char) 48 | if cp == 0 or cp == 0xFFFD or _is_control(char): 49 | continue 50 | if _is_whitespace(char): 51 | output.append(" ") 52 | else: 53 | output.append(char) 54 | return "".join(output) 55 | 56 | def judge_ignore(word): 57 | if len(_clean_text(word)) == 0: 58 | return True 59 | for char in word: 60 | cp = ord(char) 61 | if cp == 0 or cp == 0xFFFD or _is_control(char): 62 | return True 63 | return False 64 | 65 | def flatten(list_of_lists): 66 | for list in list_of_lists: 67 | for item in list: 68 | yield item 69 | 70 | def warmup_linear(x, warmup=0.002): 71 | if x < warmup: 72 | return x/warmup 73 | return 1.0 - x 74 | 75 | class Vocab(object): 76 | def __init__(self, bert_vocab_path): 77 | self.tokenizer = BertTokenizer.from_pretrained( 78 | bert_vocab_path, do_lower_case=False 79 | ) 80 | 81 | def convert_tokens_to_ids(self, tokens): 82 | token_ids = self.tokenizer.convert_tokens_to_ids(tokens) 83 | ids = torch.tensor(token_ids, dtype=torch.long) 84 | mask = torch.ones(len(ids), dtype=torch.long) 85 | return ids, mask 86 | 87 | def subword_tokenize(self, tokens): 88 | subwords = list(map(self.tokenizer.tokenize, tokens)) 89 | subword_lengths = [1] + list(map(len, subwords)) + [1] 90 | subwords = ["[CLS]"] + list(flatten(subwords)) + ["[SEP]"] 91 | token_start_idxs = torch.cumsum(torch.tensor([0] + subword_lengths[:-1]), dim=0) 92 | return subwords, token_start_idxs 93 | 94 | def subword_tokenize_to_ids(self, tokens): 95 | tokens = ["[PAD]" if judge_ignore(t) else t for t in tokens] 96 | subwords, token_start_idxs = self.subword_tokenize(tokens) 97 | subword_ids, mask = self.convert_tokens_to_ids(subwords) 98 | token_starts = torch.zeros(len(subword_ids), dtype=torch.uint8) 99 | token_starts[token_start_idxs] = 1 100 | return subword_ids, mask, token_starts 101 | 102 | def tokenize(self, tokens): 103 | subwords = list(map(self.tokenizer.tokenize, tokens)) 104 | subword_lengths = [1] + list(map(len, subwords)) + [1] 105 | subwords = ["[CLS]"] + list(flatten(subwords)) + ["[SEP]"] 106 | return subwords 107 | -------------------------------------------------------------------------------- /PathRanking/model/data.py: -------------------------------------------------------------------------------- 1 | # -*- encoding:utf-8 -*- 2 | 3 | import json 4 | import torch 5 | 6 | class Data(): 7 | def __init__(self, args): 8 | self.args = args 9 | 10 | def load(self, mode): 11 | if mode == 'train': 12 | with open(self.args.train_file, 'r') as f: 13 | my_data = json.load(f) 14 | out = (my_data['questions'], my_data['golds'],my_data['negs']) 15 | elif mode == 'valid': 16 | with open(self.args.valid_file, 'r') as f: 17 | my_data = json.load(f) 18 | out = (my_data['questions'], my_data['golds'],my_data['negs']) 19 | elif mode == 'test': 20 | with open(self.args.valid_file, 'r') as f: 21 | my_data = json.load(f) 22 | out = (my_data['questions'],my_data['cands']) 23 | return out 24 | 25 | def numericalize(self, field, seqs): 26 | out = field.numericalize(seqs) 27 | return out 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /PathRanking/model/field.py: 
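# ---- Illustrative sketch (not part of bert_function.py) ----
# The Vocab helper above aligns word-level tokens with BERT subwords: it wraps
# the flattened subword pieces in [CLS]/[SEP] and takes a cumulative sum of the
# per-token subword lengths to record where each original token starts. A
# self-contained rehearsal of that bookkeeping, with toy_tokenize standing in
# for BertTokenizer.tokenize:
import torch

def toy_tokenize(token):
    # pretend every two characters form one subword piece
    return [token[i:i + 2] for i in range(0, len(token), 2)]

tokens = ["playing", "chess"]
pieces = [toy_tokenize(t) for t in tokens]
subword_lengths = [1] + [len(p) for p in pieces] + [1]            # [CLS] ... [SEP]
subwords = ["[CLS]"] + [sw for p in pieces for sw in p] + ["[SEP]"]
token_start_idxs = torch.cumsum(torch.tensor([0] + subword_lengths[:-1]), dim=0)
print(subwords)                   # ['[CLS]', 'pl', 'ay', 'in', 'g', 'ch', 'es', 's', '[SEP]']
print(token_start_idxs.tolist())  # [0, 1, 5, 8] -> tokens start at positions 1 and 5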
-------------------------------------------------------------------------------- 1 | import torch 2 | class Field(object): 3 | 4 | def __init__(self, name, pad=0, unk=None, bos='[CLS]', sep='[SEP]', 5 | lower=False, use_vocab=True, tokenizer=None, fn=None): 6 | self.name = name 7 | self.pad = pad 8 | self.unk = unk 9 | self.bos = bos 10 | self.sep = sep 11 | self.lower = lower 12 | self.use_vocab = use_vocab 13 | self.tokenizer = tokenizer 14 | self.fn = fn 15 | self.specials = [token for token in [pad, unk, bos, sep] 16 | if token is not None] 17 | 18 | def __repr__(self): 19 | s, params = f"({self.name}): {self.__class__.__name__}(", [] 20 | if self.pad is not None: 21 | params.append(f"pad={self.pad}") 22 | if self.unk is not None: 23 | params.append(f"unk={self.unk}") 24 | if self.bos is not None: 25 | params.append(f"bos={self.bos}") 26 | if self.sep is not None: 27 | params.append(f"sep={self.sep}") 28 | if self.lower: 29 | params.append(f"lower={self.lower}") 30 | if not self.use_vocab: 31 | params.append(f"use_vocab={self.use_vocab}") 32 | s += f", ".join(params) 33 | s += f")" 34 | 35 | return s 36 | 37 | @property 38 | def pad_index(self): 39 | return self.specials.index(self.pad) if self.pad is not None else 0 40 | 41 | @property 42 | def unk_index(self): 43 | return self.specials.index(self.unk) if self.unk is not None else 0 44 | 45 | @property 46 | def bos_index(self): 47 | return self.specials.index(self.bos) 48 | 49 | @property 50 | def eos_index(self): 51 | return self.specials.index(self.sep) 52 | 53 | def transform(self, sequence): 54 | if self.tokenizer is not None: 55 | sequence = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(sequence)) 56 | if self.lower: 57 | sequence = [str.lower(token) for token in sequence] 58 | if self.fn is not None: 59 | sequence = [self.fn(token) for token in sequence] 60 | return sequence 61 | 62 | def numericalize(self, sequences): 63 | sequences = [self.transform(sequence) for sequence in sequences] 64 | # if self.use_vocab: 65 | # sequences = [self.vocab.token2id(sequence) 66 | # for sequence in sequences] 67 | # if self.bos: 68 | # sequences = [[self.bos_index] + sequence for sequence in sequences] 69 | # if self.sep: 70 | # sequences = [sequence + [self.eos_index] for sequence in sequences] 71 | sequences = [torch.tensor(sequence) for sequence in sequences] 72 | return sequences 73 | 74 | class BertField(Field): 75 | 76 | def simplize(self, seq): 77 | chars = ['<','>'] 78 | new_seq = [] 79 | for word in seq: 80 | if word not in chars: 81 | new_seq.append(word) 82 | return new_seq 83 | 84 | def numericalize(self, sequences): 85 | subwords, lens = [], [] 86 | sequences = [([self.bos] if self.bos else []) + list(sequence) + 87 | ([self.sep] if self.sep else []) 88 | for sequence in sequences] 89 | 90 | origin_len = len(sequences) 91 | 92 | for one_sequence in sequences: 93 | sequence = one_sequence 94 | sequence = [self.transform(token) for token in sequence] 95 | if [] in sequence: 96 | sequence.remove([]) 97 | sequence = [piece if piece else self.transform(self.pad) 98 | for piece in sequence] 99 | subwords.append(sum(sequence, [])) 100 | lens.append(torch.tensor([len(piece) for piece in sequence])) 101 | subwords = [torch.tensor(pieces) for pieces in subwords] 102 | mask = [torch.ones(len(pieces)).ge(0) for pieces in subwords] 103 | 104 | assert origin_len == len(lens) 105 | return (subwords, lens, mask) 106 | 107 | class BertCharField(Field): 108 | 109 | def numericalize(self, sequences): 110 | tmp = sequences 111 | 112 | 
sequences = [ [self.bos] + self.tokenizer.tokenize(sequence) for sequence in sequences] 113 | sequences = [self.tokenizer.convert_tokens_to_ids(sequence) for sequence in sequences] 114 | sequences = [torch.tensor(sequence) for sequence in sequences] 115 | mask = [torch.ones(len(sequence)) for sequence in sequences] 116 | return (sequences, mask) 117 | 118 | class SyntaxField(object): 119 | def __init__(self, pad=0, unk='', bos='', sep=''): 120 | self.pad = pad 121 | self.unk = unk 122 | self.bos = bos 123 | self.sep = sep 124 | 125 | label_list=['','','','','','ADV','AMOD','APP','AUX','BNF','CJT','CJTN','CJTN0','CJTN1','CJTN2','CJTN3','CJTN4','CJTN5','CJTN6','CJTN7','CJTN8','CJTN9','CND','COMP','DIR','DMOD','EXT','FOC','IO','LGS','LOC','MNR','NMOD','OBJ','OTHER','PRD','PRP','PRT','RELC','ROOT','SBJ','TMP','TPC','UNK','VOC','cCJTN'] 126 | self.syntax_dict={} 127 | for item in label_list: 128 | self.syntax_dict[item]=len(self.syntax_dict) 129 | self.len_syntax_dict = len(self.syntax_dict) 130 | 131 | def numericalize(self, sequences): 132 | seqids, lens = [], [] 133 | sequences = [([self.bos] if self.bos else []) + list(sequence) + 134 | ([self.sep] if self.sep else []) 135 | for sequence in sequences] 136 | for seq in sequences: 137 | seq = [self.syntax_dict.get(label,self.syntax_dict.get(self.unk)) for label in seq] 138 | seqids.append(seq) 139 | lens.append(len(seq)) 140 | return seqids -------------------------------------------------------------------------------- /PathRanking/model/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import argparse 3 | import csv 4 | import logging 5 | import os 6 | import random 7 | import sys 8 | from datetime import datetime 9 | import numpy as np 10 | import torch 11 | from torch.utils.data.distributed import DistributedSampler 12 | from tqdm import tqdm, trange 13 | from torch.nn import CrossEntropyLoss, MSELoss, MarginRankingLoss 14 | from argparse import ArgumentParser 15 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME 16 | from pytorch_pretrained_bert.tokenization import BertTokenizer 17 | from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule 18 | logger = logging.getLogger(__name__) 19 | 20 | # personal package 21 | from field import * 22 | from bert_function import * 23 | from model import * 24 | from args import get_args 25 | from data import Data 26 | 27 | # neg sample function 28 | def train_neg_sample(negs, neg_size): 29 | new_negs = [] 30 | for (subwords, mask) in negs: 31 | l = len(mask) 32 | try: 33 | index = random.sample([i for i in range(l)], neg_size) 34 | except: 35 | from pdb import set_trace 36 | set_trace() 37 | new_negs.append(([subwords[i] for i in index], [mask[i] for i in index])) 38 | return new_negs 39 | 40 | # train data batchlize 41 | def train_batchlize(train_dataset_question, train_dataset_gold, train_dataset_negs, batch_size, neg_size): 42 | ''' 43 | 44 | ''' 45 | batches_train_question = data_batchlize(batch_size, train_dataset_question) 46 | batches_train_gold = data_batchlize(batch_size, train_dataset_gold) 47 | 48 | all_neg0, all_neg1 = [], [] 49 | # 拼起来 50 | for one_neg in train_dataset_negs: # case_num,((neg_size,L),(neg_size, L)) 51 | all_neg0.extend(one_neg[0]) 52 | all_neg1.extend(one_neg[1]) 53 | new_train_dataset_negs = (all_neg0, all_neg1) 54 | 55 | batches_train_negs = data_batchlize(batch_size*neg_size, 
new_train_dataset_negs) 56 | 57 | for i, batch_neg in enumerate(batches_train_negs): 58 | new_batch_neg = [] 59 | for one_data in batch_neg: 60 | case_nums, max_seq_len = one_data.shape 61 | one_data = one_data.reshape(neg_size, -1, max_seq_len) 62 | new_batch_neg.append(one_data) 63 | batches_train_negs[i] = tuple(new_batch_neg) 64 | return (batches_train_question, batches_train_gold, batches_train_negs) 65 | 66 | def data_batchlize(batch_size, data_tuple): 67 | ''' 68 | give a tuple, return batches of data 69 | ''' 70 | (subwords, mask) = data_tuple 71 | 72 | batches_subwords, batches_mask = [], [] 73 | 74 | indexs = [i for i in range(len(subwords))] 75 | start = 0 76 | start_indexs = [] 77 | while start <= len(indexs)-1: 78 | start_indexs.append(start) 79 | start += batch_size 80 | 81 | start = 0 82 | for start in start_indexs: 83 | cur_indexs = indexs[start:start + batch_size] 84 | cur_subwords = [subwords[i] for i in cur_indexs] 85 | cur_mask = [mask[i] for i in cur_indexs] 86 | 87 | maxlen_i, maxlen_j = 0, 0 88 | for i, j in zip(cur_subwords, cur_mask): 89 | maxlen_i, maxlen_j = max(maxlen_i, len(i)), max(maxlen_j, len(j)) 90 | batch_a, batch_b = [], [] 91 | for a, b in zip(cur_subwords, cur_mask): 92 | batch_a.append([i for i in a]+[0]*(maxlen_i-len(a))) 93 | batch_b.append([i for i in b]+[0]*(maxlen_j-len(b))) 94 | 95 | batches_subwords.append(torch.LongTensor(batch_a)) 96 | batches_mask.append(torch.LongTensor(batch_b)) 97 | 98 | return [item for item in zip(batches_subwords, batches_mask)] 99 | 100 | def train(args, bert_field, model): 101 | 102 | Dataset = Data(args) 103 | 104 | # datasets 105 | train_rawdata = Dataset.load('train') 106 | valid_rawdata = Dataset.load('valid') 107 | 108 | (train_rawdata_questions, train_rawdata_gold, train_rawdata_neg) = train_rawdata 109 | (valid_rawdata_questions, valid_rawdata_gold, valid_rawdata_neg) = valid_rawdata 110 | train_dataset_question = Dataset.numericalize(bert_field, train_rawdata_questions) 111 | train_dataset_gold = Dataset.numericalize(bert_field, train_rawdata_gold) 112 | train_dataset_negs = [] 113 | for one_neg in train_rawdata_neg: 114 | train_dataset_neg = Dataset.numericalize(bert_field, one_neg) # train_dataset_neg is a tuple(subwords, lens, mask) 115 | train_dataset_negs.append(train_dataset_neg) 116 | print('train data loaded!') 117 | 118 | if args.neg_fix: 119 | # batchlize 120 | # sample_train_dataset_negs = train_neg_sample(train_dataset_negs, args.neg_size) 121 | # train_data = train_batchlize(train_dataset_question, train_dataset_gold, sample_train_dataset_negs, args.batch_size, args.neg_size, syntax_embed=train_syntax_embed, hidden_embed=args.syntax_hidden_embed) 122 | 123 | # print("train data batchlized............") 124 | sample_train_dataset_negs = train_neg_sample(train_dataset_negs, args.neg_size) 125 | train_data = train_batchlize(train_dataset_question, train_dataset_gold, sample_train_dataset_negs, args.batch_size, args.neg_size) 126 | print("train data batchlized............") 127 | 128 | valid_dataset_question = Dataset.numericalize(bert_field, valid_rawdata_questions) 129 | valid_dataset_gold = Dataset.numericalize(bert_field, valid_rawdata_gold) 130 | valid_dataset_negs = [] 131 | for index, one_neg in enumerate(valid_rawdata_neg): 132 | if not one_neg: 133 | print('no neg paths', index) 134 | valid_dataset_neg = Dataset.numericalize(bert_field, one_neg) 135 | valid_dataset_negs.append(valid_dataset_neg) 136 | 137 | valid_dataset = (valid_dataset_question, valid_dataset_gold, valid_dataset_negs) 138 
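# ---- Illustrative sketch (not part of main.py; file contents are made up) ----
# Data.load in data.py expects the train/valid JSON to hold parallel lists
# under 'questions', 'golds' and 'negs' (one gold path string plus a list of
# negative path strings per question); the test file holds 'questions' and
# 'cands' instead. A minimal file in that shape:
import json

toy_train = {
    "questions": ["莫扎特的出生地是哪里?"],
    "golds": ["<莫扎特><出生地>"],
    "negs": [["<莫扎特><国籍>", "<莫扎特><逝世地>"]],
}
with open("toy_train.json", "w") as f:
    json.dump(toy_train, f, ensure_ascii=False)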
| print('valid data loaded!') 139 | 140 | # num of train steps 141 | print('train examples',len(train_rawdata_questions)) 142 | num_train_steps = int( 143 | len(train_rawdata_questions) / args.batch_size / args.gradient_accumulation_steps * args.num_train_epochs) 144 | 145 | # optimizer 146 | param_optimizer = list(model.named_parameters()) 147 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 148 | optimizer_grouped_parameters = [ 149 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 150 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 151 | ] 152 | 153 | optimizer = BertAdam(optimizer_grouped_parameters, 154 | lr=args.learning_rate, 155 | warmup=args.warmup_proportion, 156 | t_total=num_train_steps) 157 | 158 | # loss function 159 | criterion = MarginRankingLoss(margin=args.margin) 160 | 161 | # train params 162 | patience = args.patience 163 | num_train_epochs = args.num_train_epochs 164 | iters_left = patience 165 | best_precision = 0 166 | num_not_improved = 0 167 | global_step = 0 168 | 169 | logger.info('\nstart training:%s'%datetime.now().strftime('%Y-%m-%d %H:%M:%S')) 170 | print("start training!") 171 | 172 | # train and evaluate 173 | for epoch in range(args.num_train_epochs): 174 | 175 | # batchlize 176 | if not args.neg_fix: 177 | sample_train_dataset_negs = train_neg_sample(train_dataset_negs, args.neg_size) 178 | train_data = train_batchlize(train_dataset_question, train_dataset_gold, sample_train_dataset_negs, args.batch_size, args.neg_size) 179 | print("train data batchlized............") 180 | 181 | train_right = 0 182 | train_total = 0 183 | # 打印 184 | print('start time') 185 | start_time = datetime.now() 186 | logger.info('\nstart training:%s'%datetime.now().strftime('%Y-%m-%d %H:%M:%S')) 187 | print(start_time) 188 | 189 | model.train() 190 | optimizer.zero_grad() 191 | loss_epoch = 0 # 单次迭代的总loss 192 | (batches_train_question, batches_train_gold, batches_train_negs) = train_data 193 | for step,(batch_train_question, batch_train_gold, batch_train_negs) in enumerate(zip(batches_train_question,batches_train_gold, batches_train_negs)): 194 | batch_train_question = (t.cuda() for t in batch_train_question) 195 | batch_train_gold = (t.cuda() for t in batch_train_gold) 196 | batch_train_negs = (t.cuda() for t in batch_train_negs) 197 | scores = model(batch_train_question, batch_train_gold, batch_train_negs) 198 | (pos_score, neg_scores) = scores 199 | pos_score = pos_score.expand_as(neg_scores).reshape(-1) 200 | neg_scores = neg_scores.reshape(-1) 201 | assert len(pos_score) == len(neg_scores) 202 | ones = torch.ones(pos_score.shape) 203 | if args.no_cuda == False: 204 | ones = ones.cuda() 205 | loss = criterion(pos_score, neg_scores, ones) 206 | 207 | # evaluate train 208 | result = (torch.sum(pos_score.reshape(-1, args.neg_size) > neg_scores.reshape(-1, args.neg_size),-1) == args.neg_size).cpu() 209 | 210 | train_right += torch.sum(result).item() 211 | train_total += len(result) 212 | 213 | if args.gradient_accumulation_steps > 1: 214 | loss = loss / args.gradient_accumulation_steps 215 | loss.backward() 216 | loss_epoch += loss 217 | if (step + 1) % args.gradient_accumulation_steps == 0: 218 | # modify learning rate with special warm up BERT uses 219 | lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_steps, args.warmup_proportion) 220 | for param_group in optimizer.param_groups: 221 | param_group['lr'] = lr_this_step 222 | 
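# warmup_linear (defined in bert_function.py) returns x/warmup while
# x = global_step/num_train_steps is still below the warmup fraction and
# 1.0 - x afterwards, so lr_this_step ramps the learning rate linearly up to
# args.learning_rate over the first warmup_proportion of training and then
# decays it linearly towards zero.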
optimizer.step() 223 | optimizer.zero_grad() 224 | global_step += 1 225 | 226 | # 打印 227 | end_time = datetime.now() 228 | logger.info('\ntrain epoch %d time span:%s'%(epoch, end_time-start_time)) 229 | print('train loss', loss_epoch.item()) 230 | logger.info('train loss:%f'%loss_epoch.item()) 231 | print('train result', train_right, train_total, 1.0*train_right/train_total) 232 | logger.info(('train result', train_right, train_total, 1.0*train_right/train_total)) 233 | 234 | # 评估 235 | right, total, precision = evaluate(args, model, valid_dataset, valid_rawdata, epoch) 236 | # right, total, precision = 0, 0, 0.0 237 | 238 | # 打印 239 | print('valid result', right, total, precision) 240 | print('epoch time') 241 | print(datetime.now()) 242 | print('*'*20) 243 | logger.info("epoch:%d\t"%epoch+"dev_Accuracy-----------------------%d/%d=%f\n"%(right, total, precision)) 244 | end_time = datetime.now() 245 | logger.info('dev epoch %d time span:%s'%(epoch,end_time-start_time)) 246 | 247 | if precision > best_precision: 248 | best_precision = precision 249 | iters_left = patience 250 | print("epoch %d saved\n"%epoch) 251 | logger.info("epoch %d saved\n"%epoch) 252 | # Save a trained model 253 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 254 | output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") 255 | torch.save(model_to_save.state_dict(), output_model_file) 256 | else: 257 | iters_left -= 1 258 | if iters_left == 0: 259 | break 260 | logger.info('finish training!') 261 | print('finish training!') 262 | 263 | def evaluate(args, model, valid_dataset, valid_rawdata,epoch): 264 | model.eval() 265 | #f = open('tmp_predict_epoch%d.txt'%epoch,'w') 266 | f = open('tmp_predict_valid.txt','w') 267 | (valid_dataset_question, valid_dataset_gold, valid_dataset_negs) = valid_dataset 268 | (q_valid1, q_valid2) = valid_dataset_question 269 | (gold_valid1, gold_valid2) = valid_dataset_gold 270 | # (negs_valid1, negs_valid2,negs_valid3) = valid_dataset_negs 271 | right = 0 272 | total = 0 273 | for index, (q1, q2, gold1, gold2, negs) in enumerate(zip(q_valid1, q_valid2, gold_valid1, gold_valid2, valid_dataset_negs)): 274 | q = (q1, q2) 275 | gold = (gold1, gold2) 276 | # negs = (neg1, neg2, neg3) 277 | batches_negs = data_batchlize(args.eval_batch_size, negs) 278 | pos_score, all_scores = model.cal_score(q, batches_negs, pos=gold) 279 | 280 | f.write(''.join(valid_rawdata[0][index])) 281 | total += 1 282 | if len(all_scores) == torch.sum(pos_score > all_scores): 283 | right += 1 284 | f.write(''.join(valid_rawdata[1][index])) 285 | else: 286 | f.write(''.join(valid_rawdata[2][index][torch.argmax(all_scores)]) if len(valid_rawdata[2][index]) > 0 else '') 287 | f.write('\n') 288 | f.close() 289 | return right, total, 1.0*right/total 290 | 291 | if __name__ == "__main__": 292 | 293 | # params 294 | args = get_args() 295 | 296 | # 297 | # if not os.path.exists(args.output_dir): 298 | # os.makedirs(args.output_dir) 299 | # print('maked dir %s' % args.output_dir) 300 | 301 | # random seed 302 | random.seed(args.seed) 303 | np.random.seed(args.seed) 304 | torch.manual_seed(args.seed) 305 | 306 | # tokenize 307 | tokenizer = BertTokenizer.from_pretrained(args.bert_vocab, do_lower_case=args.do_lower_case) 308 | bert_field = BertCharField('BERT', tokenizer=tokenizer) 309 | print("loaded tokenizer") 310 | 311 | # gpu 312 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 313 | print('使用%s号GPU' % args.gpu) 314 | # model 315 | model_dict = 
{'bert_comparing':Bert_Comparing,'bert_sharecomparing':Bert_ShareComparing} 316 | print(args.model) 317 | model_name = model_dict[args.model] 318 | model = model_name(args) 319 | if args.no_cuda == False: 320 | model.cuda() 321 | 322 | print(model) 323 | print(args) 324 | train(args, bert_field, model) 325 | -------------------------------------------------------------------------------- /PathRanking/model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pytorch_pretrained_bert.modeling import BertModel 4 | from pytorch_pretrained_bert.tokenization import BertTokenizer 5 | from torch.nn.utils.rnn import pad_sequence 6 | import unicodedata 7 | from torch.nn.utils.rnn import pack_padded_sequence as pack 8 | from torch.nn.utils.rnn import pad_packed_sequence as pad 9 | import torch.nn.functional as F 10 | import random 11 | from torch.nn import CosineSimilarity 12 | from collections import Counter 13 | from field import * 14 | 15 | # def data_batchlize(batch_size, data_tuple, syntaxs=None, hidden_embed=False): 16 | # ''' 17 | # give a tuple, return batches of data 18 | # ''' 19 | # (subwords, lens, mask) = data_tuple 20 | # if syntaxs: 21 | # batches_subwords, batches_lens, batches_mask, batches_syntaxs = [], [], [], [] 22 | # else: 23 | # batches_subwords, batches_lens, batches_mask = [], [], [] 24 | # indexs = [i for i in range(len(subwords))] 25 | # start = 0 26 | # start_indexs = [] 27 | # while start <= len(indexs)-1: 28 | # start_indexs.append(start) 29 | # start += batch_size 30 | # start = 0 31 | # for start in start_indexs: 32 | # cur_indexs = indexs[start:start + batch_size] 33 | # cur_subwords = [subwords[i] for i in cur_indexs] 34 | # cur_lens = [lens[i] for i in cur_indexs] 35 | # cur_mask = [mask[i] for i in cur_indexs] 36 | # if syntaxs: 37 | # cur_syntaxs = [syntaxs[i] for i in cur_indexs] 38 | # maxlen_i, maxlen_j, maxlen_k = 0, 0, 0 39 | # for i, j, k in zip(cur_subwords, cur_lens, cur_mask): 40 | # maxlen_i, maxlen_j, maxlen_k = max(maxlen_i, len(i)), max(maxlen_j, len(j)), max(maxlen_k, len(k)) 41 | # batch_a, batch_b, batch_c = [], [], [] 42 | # for a, b, c in zip(cur_subwords, cur_lens, cur_mask): 43 | # batch_a.append([i for i in a]+[0]*(maxlen_i-len(a))) 44 | # batch_b.append([i for i in b]+[0]*(maxlen_j-len(b))) 45 | # batch_c.append([i for i in c]+[0]*(maxlen_k-len(c))) 46 | # if syntaxs: 47 | # batch_syntax = [] 48 | # if not hidden_embed: 49 | # for d in cur_syntaxs: 50 | # batch_syntax.append([i for i in d]+[0]*(maxlen_j-len(d))) 51 | # else: 52 | # hidden_size = cur_syntaxs[0].shape[1] 53 | # batch_syntax = torch.zeros(min(batch_size,len(cur_syntaxs)), maxlen_j, hidden_size) 54 | # for i, matrix in enumerate(cur_syntaxs): 55 | # one_len, _ = matrix.shape 56 | # # from pdb import set_trace 57 | # # set_trace() 58 | # batch_syntax[i, :one_len] = matrix[: , :] 59 | # batches_subwords.append(torch.LongTensor(batch_a)) 60 | # batches_lens.append(torch.LongTensor(batch_b)) 61 | # batches_mask.append(torch.LongTensor(batch_c)) 62 | # if syntaxs: 63 | # if hidden_embed: 64 | # batches_syntaxs.append(batch_syntax) 65 | # else: 66 | # batches_syntaxs.append(torch.LongTensor(batch_syntax)) 67 | # if syntaxs: 68 | # return [item for item in zip(batches_subwords, batches_lens, batches_mask, batches_syntaxs)] 69 | # else: 70 | # return [item for item in zip(batches_subwords, batches_lens, batches_mask)] 71 | 72 | class BertEmbedding(nn.Module): 73 | 74 | def __init__(self, model, 
requires_grad=True): 75 | super(BertEmbedding, self).__init__() 76 | 77 | self.bert = BertModel.from_pretrained(model) 78 | #self.bert = self.bert.requires_grad_(requires_grad) 79 | self.requires_grad = requires_grad 80 | self.hidden_size = self.bert.config.hidden_size 81 | self.n_layers = 1 82 | 83 | def __repr__(self): 84 | s = self.__class__.__name__ + '(' 85 | if hasattr(self, 'n_layers') and hasattr(self, 'n_out'): 86 | s += f"n_layers={self.n_layers}, n_out={self.n_out}" 87 | if self.requires_grad: 88 | s += f", requires_grad={self.requires_grad}" 89 | s += ')' 90 | return s 91 | 92 | def forward(self, subwords, bert_lens, bert_mask): 93 | batch_size, seq_len = bert_lens.shape 94 | mask = bert_lens.gt(0) 95 | bert_mask = bert_mask.gt(0) 96 | if not self.requires_grad: 97 | self.bert.eval() 98 | bert, _ = self.bert(subwords, attention_mask=bert_mask, output_all_encoded_layers=False) 99 | bert = bert[bert_mask].split(bert_lens[mask].tolist()) 100 | bert = torch.stack([i.mean(0) for i in bert]) 101 | bert_embed = bert.new_zeros(batch_size, seq_len, self.hidden_size) 102 | bert_embed = bert_embed.masked_scatter_(mask.unsqueeze(-1), bert) 103 | return bert_embed 104 | 105 | class BertCharEmbedding(nn.Module): 106 | def __init__(self, path, requires_grad=True): 107 | super(BertCharEmbedding, self).__init__() 108 | self.bert = BertModel.from_pretrained(path) 109 | self.requires_grad = requires_grad 110 | 111 | def forward(self, subwords, bert_mask): 112 | bert, _ = self.bert(subwords, attention_mask=bert_mask, output_all_encoded_layers=False) 113 | return bert 114 | 115 | class Bert_Comparing(nn.Module): 116 | def __init__(self, data): 117 | super(Bert_Comparing, self).__init__() 118 | 119 | self.question_bert_embedding = BertCharEmbedding(data.bert_path, data.requires_grad) 120 | self.path_bert_embedding = BertCharEmbedding(data.bert_path, data.requires_grad) 121 | self.args = data 122 | self.similarity = CosineSimilarity(dim=1) 123 | 124 | def question_encoder(self, input_idxs, bert_mask): 125 | bert_outs = self.question_bert_embedding(input_idxs, bert_mask) 126 | return bert_outs[:, 0] 127 | 128 | def path_encoder(self, input_idxs, bert_mask): 129 | bert_outs = self.path_bert_embedding(input_idxs, bert_mask) 130 | return bert_outs[:, 0] 131 | 132 | def forward(self, questions, pos, negs): 133 | ''' 134 | questions: batch_size, max_seq_len 135 | 136 | pos_input_idxs: batch_size, max_seq_len 137 | pos_bert_lens: batch_size, max_seq_len 138 | pos_bert_mask: batch_size, max_seq_len 139 | 140 | neg_input_idxs: neg_size, batch_size, max_seq_len 141 | neg_bert_lens: neg_size, batch_size, max_seq_len 142 | neg_bert_mask: neg_size, batch_size, max_seq_len 143 | ''' 144 | 145 | (q_input_idxs, q_bert_mask) = questions 146 | 147 | (pos_input_idxs, pos_bert_mask) = pos 148 | (neg_input_idxs, neg_bert_mask) = negs 149 | neg_size, batch_size, _ = neg_input_idxs.shape 150 | 151 | q_encoding = self.question_encoder(q_input_idxs, q_bert_mask) # (batch_size, hidden_dim) 152 | 153 | pos_encoding = self.path_encoder(pos_input_idxs, pos_bert_mask) 154 | 155 | neg_input_idxs = neg_input_idxs.reshape(neg_size*batch_size, -1) # (neg_size*batch_size, max_seq_len) 156 | neg_bert_mask = neg_bert_mask.reshape(neg_size*batch_size, -1) # (neg_size*batch_size, max_seq_len) 157 | 158 | neg_encoding = self.path_encoder(neg_input_idxs, neg_bert_mask) # (neg_size*batch_size, hidden_dim) 159 | # p_encoding = p_encoding.reshape(neg_size, batch_size, -1) # (neg_size, batch_size, hidden_dim) 160 | 161 | q_encoding_expand = 
q_encoding.unsqueeze(0).expand(neg_size, batch_size, q_encoding.shape[-1]).reshape(neg_size*batch_size, -1) # (neg_size*batch_size, hidden_dim) 162 | 163 | pos_score = self.similarity(q_encoding, pos_encoding) 164 | pos_score = pos_score.unsqueeze(1) # (batch_size, 1) 165 | neg_score = self.similarity(q_encoding_expand, neg_encoding) 166 | neg_score = neg_score.reshape(neg_size,-1).transpose(0,1) # (batch_size, neg_size) 167 | 168 | return (pos_score, neg_score) 169 | 170 | @torch.no_grad() 171 | def cal_score(self, question, cands, pos=None): 172 | ''' 173 | one question, several candidate paths 174 | question: (max_seq_len), (max_seq_len), (max_seq_len) 175 | cands: (batch_size, max_seq_len), (batch_size, max_seq_len), (batch_size, max_seq_len) 176 | ''' 177 | question = (t.unsqueeze(0) for t in question) 178 | 179 | if self.args.no_cuda == False: 180 | question = (t.cuda() for t in question) 181 | 182 | (q_input_idxs, q_bert_mask) = question 183 | 184 | q_encoding = self.question_encoder(q_input_idxs, q_bert_mask) # (batch_size=1, hidden_dim) 185 | 186 | if pos: 187 | pos = (t.unsqueeze(0) for t in pos) 188 | if self.args.no_cuda == False: 189 | pos = (t.cuda() for t in pos) 190 | 191 | (pos_input_idxs, pos_bert_mask) = pos 192 | pos_encoding = self.path_encoder(pos_input_idxs, pos_bert_mask) # (batch_size=1, hidden_dim) 193 | pos_score = self.similarity(q_encoding, pos_encoding) # (batch_size=1) 194 | 195 | all_scores = [] 196 | 197 | for (batch_input_idxs, batch_bert_mask) in cands: 198 | if self.args.no_cuda ==False: 199 | batch_input_idxs, batch_bert_mask = batch_input_idxs.cuda(), batch_bert_mask.cuda() 200 | path_encoding = self.path_encoder(batch_input_idxs, batch_bert_mask) #(batch_size, hidden_dim) 201 | q_encoding_expand = q_encoding.expand_as(path_encoding) 202 | scores = self.similarity(q_encoding_expand, path_encoding) # (batch_size) 203 | for score in scores: 204 | all_scores.append(score) 205 | all_scores = torch.Tensor(all_scores) 206 | 207 | if pos: 208 | return pos_score.cpu(), all_scores.cpu() 209 | else: 210 | return all_scores.cpu() 211 | 212 | class Bert_ShareComparing(Bert_Comparing): 213 | def __init__(self, data): 214 | super(Bert_ShareComparing, self).__init__(data) 215 | self.question_bert_embedding = BertCharEmbedding(data.bert_path, data.requires_grad) 216 | self.path_bert_embedding = self.question_bert_embedding 217 | self.args = data 218 | self.similarity = CosineSimilarity(dim=1) 219 | -------------------------------------------------------------------------------- /PathRanking/model/process_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | # pip install requests first 4 | import requests 5 | 6 | def segment(seqs): 7 | ''' 8 | 利用接口分词 9 | ''' 10 | if not seqs: 11 | return [] 12 | url = "http://192.168.126.171:5001/api" 13 | headers = {"Content-Type": "application/json; charset=UTF-8"} 14 | input_json = {"input_string":seqs, "ws": True, "pos": False, "dep": False} 15 | response = requests.post(url, data=json.dumps(input_json), headers=headers) 16 | outs = response.json() 17 | if 'words' in outs.keys(): 18 | return outs['words'] 19 | else: 20 | return [] 21 | 22 | if __name__ == '__main__': 23 | dir = '../cls_all_path/BERT_LSTM_maxpooling_embed' 24 | fns = ['one_hop_cand_paths_ent.json', 'two_hop_cand_paths_ent.json', 'multi_constraint_cand_paths_ent.json', 'left_cand_paths_ent.json'] 25 | for fn in fns: 26 | path = os.path.join(dir, fn) 27 | output_path = os.path.join(dir, 
fn.replace('paths','paths_ws')) 28 | with open(path, 'r') as f, open(output_path, 'w')as fout: 29 | data_ws = [] 30 | data_input = json.load(f) 31 | for line in data_input: 32 | q = line['q'] 33 | seqs = segment([q.replace(' ','')]) 34 | q_ws = seqs[0] 35 | paths = line['paths'] 36 | seqs = [''.join(item).replace(' ','') for item in paths] 37 | paths_ws = segment(seqs) 38 | assert len(paths) == len(paths_ws) 39 | new_line = {'q':q, 'q_ws':q_ws, 'paths':paths, 'paths_ws':paths_ws} 40 | data_ws.append(new_line) 41 | json.dump(data_ws, fout, ensure_ascii=False) 42 | print('File', fn, 'finish') 43 | 44 | 45 | -------------------------------------------------------------------------------- /PathRanking/model/train.sh: -------------------------------------------------------------------------------- 1 | nohup python -u main.py --gpu 4 --model 'bert_sharecomparing' --neg_fix --batch_size 16 --train_file '../data/train.json' --valid_file '../data/valid.json' --output_dir 'saved_sharebert_negfix' --requires_grad >saved_sharebert_negfix/log.txt &2>1 & 2 | -------------------------------------------------------------------------------- /PathRanking/predict.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import argparse 3 | import csv 4 | import logging 5 | import os 6 | import random 7 | import sys 8 | import json 9 | from datetime import datetime 10 | import numpy as np 11 | import torch 12 | from torch.utils.data.distributed import DistributedSampler 13 | from tqdm import tqdm, trange 14 | from torch.nn import CrossEntropyLoss, MSELoss, MarginRankingLoss 15 | from argparse import ArgumentParser 16 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME 17 | from pytorch_pretrained_bert.tokenization import BertTokenizer 18 | from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule 19 | logger = logging.getLogger(__name__) 20 | 21 | # personal package 22 | from model.field import * 23 | from model.bert_function import * 24 | from model.model import * 25 | from model.args import get_args 26 | from model.data import Data 27 | 28 | alpha = 0.1 # 字符层面得分占比 not 0.5 29 | 30 | def occupied(seqa,seqb_list): 31 | # seqa:question 32 | # seqb:path 33 | # return value:[-1,1] 34 | scores = [] 35 | for seqb in seqb_list: 36 | s = jaccard(seqa,seqb) 37 | scores.append(s) 38 | return scores 39 | 40 | def jaccard(seqa,seqb): 41 | 42 | """ 43 | 返回两个句子的 jaccard 相似度 并没有 计算 字出现的次数 44 | """ 45 | seqa = set(list(seqa.upper())) 46 | seqb = set(list(seqb.upper())) 47 | aa = seqa.intersection(seqb) 48 | bb = seqa.union(seqb) 49 | #return (len(aa)-1)/len(bb) 50 | return len(aa)/len(bb) 51 | 52 | def hint(seqa,seqb): 53 | seqa = set(list(seqa.upper())) 54 | seqb = set(list(seqb.upper())) 55 | aa = seqa.intersection(seqb) 56 | return len(aa) 57 | 58 | def data_batchlize(batch_size, data_tuple): 59 | ''' 60 | give a tuple, return batches of data 61 | ''' 62 | (subwords, mask) = data_tuple 63 | 64 | batches_subwords, batches_mask = [], [] 65 | 66 | indexs = [i for i in range(len(subwords))] 67 | start = 0 68 | start_indexs = [] 69 | while start <= len(indexs)-1: 70 | start_indexs.append(start) 71 | start += batch_size 72 | 73 | start = 0 74 | for start in start_indexs: 75 | cur_indexs = indexs[start:start + batch_size] 76 | cur_subwords = [subwords[i] for i in cur_indexs] 77 | cur_mask = [mask[i] for i in cur_indexs] 78 | 79 | maxlen_i, maxlen_j = 0, 0 
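# The loop below pads every sequence in the current batch with zeros up to the
# longest sequence of that batch (dynamic per-batch padding) before stacking
# the id and mask lists into LongTensors, so different batches may end up with
# different widths.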
80 | for i, j in zip(cur_subwords, cur_mask): 81 | maxlen_i, maxlen_j = max(maxlen_i, len(i)), max(maxlen_j, len(j)) 82 | batch_a, batch_b = [], [] 83 | for a, b in zip(cur_subwords, cur_mask): 84 | batch_a.append([i for i in a]+[0]*(maxlen_i-len(a))) 85 | batch_b.append([i for i in b]+[0]*(maxlen_j-len(b))) 86 | 87 | batches_subwords.append(torch.LongTensor(batch_a)) 88 | batches_mask.append(torch.LongTensor(batch_b)) 89 | 90 | return [item for item in zip(batches_subwords, batches_mask)] 91 | 92 | # 删去实体中的带括号的描述信息 93 | def del_des(string): 94 | stack=[] 95 | # if '_(' not in string and ')' not in string and '_(' not in string and ')' not in string: 96 | if '_' not in string: 97 | return string 98 | mystring=string[1:-1] 99 | if mystring[-1]!=')' and mystring[-1]!=')': 100 | return string 101 | for i in range(len(mystring)-1,-1,-1): 102 | char=mystring[i] 103 | if char==')': 104 | stack.append(')') 105 | elif char == ')': 106 | stack.append(')') 107 | elif char=='(': 108 | if stack[-1]==')': 109 | stack=stack[:-1] 110 | if not stack: 111 | break 112 | elif char=='(': 113 | if stack[-1]==')': 114 | stack=stack[:-1] 115 | if not stack: 116 | break 117 | if mystring[i-1]=='_': 118 | i-=1 119 | else: 120 | return string 121 | return '<'+mystring[:i]+'>' 122 | 123 | def predict(args, model, field): 124 | 125 | model.eval() 126 | Dataset = Data(args) 127 | 128 | fn_in = args.input_file 129 | # if 'cand_paths' in fn_in: 130 | # fn_out = fn_in.replace('cand_paths','best_path') 131 | # else: 132 | # fn_out = fn_in.replace('paths','predict_path') 133 | if not args.output_file: 134 | fn_out = fn_in.replace('cand_paths','best_path') 135 | else: 136 | fn_out = args.output_file 137 | 138 | with open(fn_in, 'r')as f: 139 | raw_data = json.load(f) 140 | 141 | output_data = {} 142 | 143 | topk = args.topk 144 | 145 | for line in raw_data: 146 | if 'q_ws' in line.keys(): 147 | q, q_ws, paths, paths_ws = line['q'], line['q_ws'], line['paths'], line['paths_ws'] 148 | else: 149 | q, paths= line['q'], line['paths'] 150 | 151 | one_question = Dataset.numericalize(field, [q]) # 内部元素都是二维的 152 | one_question = [t[0] for t in one_question] # 内部是一维的 153 | 154 | one_question = (t for t in one_question) 155 | 156 | paths_input = [''.join([del_des(item) for item in path]) for path in paths] 157 | one_cands = Dataset.numericalize(field, paths_input) 158 | batches_cands = data_batchlize(args.test_batch_size, one_cands) 159 | 160 | # 字符层面得分 161 | char_scores = occupied(q,[''.join([del_des(i) for i in p]) for p in paths]) 162 | char_scores = torch.Tensor(char_scores) 163 | # 模型层面得分 164 | model_scores = model.cal_score(one_question, batches_cands) 165 | all_scores = alpha*char_scores + (1-alpha)*model_scores 166 | 167 | if len(all_scores)>0 and topk == 1: 168 | index = torch.argmax(all_scores) 169 | output_data[q] = paths[index] 170 | elif len(all_scores)>0 and topk > 1: 171 | sorted_scores, index = torch.sort(all_scores, descending=True) 172 | output_data[q] = [paths[i] for i in index[:topk]] 173 | else: 174 | print(q, 'no path') 175 | 176 | with open(fn_out, 'w')as f: 177 | json.dump(output_data, f, ensure_ascii=False) 178 | 179 | if __name__ == "__main__": 180 | 181 | args = get_args(mode='predict') 182 | 183 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 184 | 185 | # tokenize 186 | tokenizer = BertTokenizer.from_pretrained(args.bert_vocab, do_lower_case=args.do_lower_case) 187 | bert_field = BertField('BERT', tokenizer=tokenizer) 188 | print("loaded tokenizer") 189 | 190 | # gpu 191 | 
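# ---- Illustrative sketch (not part of predict.py; scores are made up) ----
# At inference, predict() above blends two signals for every candidate path:
# a character-overlap score (jaccard of the question and the path, with entity
# descriptions stripped by del_des) and the model's cosine score, weighted by
# alpha = 0.1. A minimal rehearsal of that blend:
import torch

def jaccard(seqa, seqb):
    seqa, seqb = set(seqa.upper()), set(seqb.upper())
    return len(seqa & seqb) / len(seqa | seqb)

alpha = 0.1
question = "莫扎特的出生地是哪里?"
paths = ["<莫扎特><出生地>", "<莫扎特><国籍>"]   # hypothetical candidates
char_scores = torch.Tensor([jaccard(question, p) for p in paths])
model_scores = torch.rand(len(paths))            # stand-in for model.cal_score(...)
all_scores = alpha * char_scores + (1 - alpha) * model_scores
best_path = paths[torch.argmax(all_scores)]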
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 192 | print('使用%s号GPU' % args.gpu) 193 | 194 | # model 195 | # model = Bert_Comparing(args) 196 | # if args.no_cuda == False: 197 | # model.cuda() 198 | 199 | # Load a trained model that you have fine-tuned 200 | model_dict = {'bert_comparing':Bert_Comparing,'bert_sharecomparing':Bert_ShareComparing} 201 | model_name = model_dict[args.model] 202 | 203 | model_state_dict = torch.load(args.model_path) 204 | model = model_name(args) 205 | model.load_state_dict(model_state_dict) 206 | model.eval() 207 | 208 | if args.no_cuda == False: 209 | model.cuda() 210 | print('loaded model!') 211 | 212 | # tokenizer 213 | tokenizer = BertTokenizer.from_pretrained(args.bert_vocab, do_lower_case=args.do_lower_case) 214 | bert_field = BertCharField('BERT', tokenizer=tokenizer) 215 | print("loaded tokenizer!") 216 | 217 | fn = args.input_file 218 | with open(fn, 'r')as f: 219 | test_raw_data = json.load(f) 220 | predict(args, model, bert_field) 221 | -------------------------------------------------------------------------------- /PathRanking/predict_stage1.sh: -------------------------------------------------------------------------------- 1 | DATADIR='../PreScreen/data/' 2 | GPU=0 3 | TOPK=5 4 | 5 | MODEL='bert_sharecomparing' 6 | MODELDIR='saved_sharebert/' 7 | DATADIR2='merge/' 8 | 9 | # predict 10 | nohup python -u predict.py --gpu $GPU --learning_rate 1e-5 --margin 0.1 --model $MODEL --model_path $MODELDIR'pytorch_model.bin' --input_file $DATADIR'one_hop_paths.json' --output_file $DATADIR$DATADIR2'one_hop_predict_path.json' --topk $TOPK >$DATADIR$DATADIR2'log_'$TOPK'.txt' & 11 | echo 'Finish prescreen' -------------------------------------------------------------------------------- /PathRanking/predict_stage2.sh: -------------------------------------------------------------------------------- 1 | DATADIR='../PreScreen/data/' 2 | GPU=0 3 | TOPK=5 4 | 5 | MODEL='bert_sharecomparing' 6 | MODELDIR='saved_sharebert/' 7 | DATADIR2='merge/' 8 | 9 | nohup python -u predict.py --gpu $GPU --learning_rate 1e-5 --margin 0.1 --model $MODEL --model_path $MODELDIR'pytorch_model.bin' --input_file $DATADIR$DATADIR2'paths_all_merge.json' --output_file $DATADIR$DATADIR2'mix_predict_path.json' --topk 1 >'log_merge'$TOPK'.txt' & 10 | echo 'Finish predcit' -------------------------------------------------------------------------------- /PathRanking/search_path_stage2.sh: -------------------------------------------------------------------------------- 1 | DATADIR='../PreScreen/data/' 2 | TOPK=5 3 | DATADIR2='merge/' 4 | 5 | nohup python -u ../../PreScreen/data/mix_paths.py --fn_in $A$DATADIR2"one_hop_predict_path.json" --fn_out $A$DATADIR2"mix_paths.json" >'log'$TOPK'.txt' & 6 | nohup python -u ../../PreScreen/data/merge_path.py --fn_in $A$DATADIR2"mix_paths.json" --fn_multi "multi_paths.json" --fn_out $DATADIR2"mix_paths_all.json" >'log'$TOPK'.txt' & 7 | echo 'Finish search path' -------------------------------------------------------------------------------- /PathRanking/utils/ans_tools.py: -------------------------------------------------------------------------------- 1 | # 搜索数据库找答案 2 | import sys 3 | import json 4 | from pdb import set_trace 5 | 6 | sys.path.append('../../') 7 | from utils import * 8 | read_con = pymysql.connect(host="192.168.126.143",port = 3337,user='root', password='pjzhang', database='ccks_2019',charset='utf8') 9 | cur = read_con.cursor() 10 | 11 | def multi_path2ans(fns): 12 | answer={} 13 | path_questions={} 14 | for fn in fns: 15 | 
out_merge=json.load(open(fn,'r')) 16 | fn_out=fn.replace('path','ans') 17 | for key,value in out_merge.items(): 18 | out=[] 19 | path_questions[key]=value 20 | out1=[] 21 | out1.extend([line[3] for line in search_ans(value[0],value[1],cur)]) 22 | out1.extend([line[1] for line in search_ans(value[0],value[1],cur,reverse=True)]) 23 | out2=[] 24 | out2.extend([line[3] for line in search_ans(value[2],value[3],cur)]) 25 | out2.extend([line[1] for line in search_ans(value[2],value[3],cur,reverse=True)]) 26 | 27 | if len(value)==6: 28 | out3=[] 29 | out3.extend([line[3] for line in search_ans(value[4],value[5],cur)]) 30 | out3.extend([line[1] for line in search_ans(value[4],value[5],cur,reverse=True)]) 31 | for item in out1: 32 | if item in out2 and len(value)==6 and item in out3: 33 | out.append(item) 34 | elif item in out2 and len(value)==4: 35 | out.append(item) 36 | answer[key]=list(set(out)) 37 | json.dump(answer,open(fn_out,'w'),ensure_ascii=False) 38 | return answer 39 | 40 | def one_hop_path2ans(fn): 41 | 42 | fn_out=fn.replace('path','ans') 43 | data_input=json.load(open(fn,'r')) 44 | answer={} 45 | path_questions = {} 46 | x1 = 0 47 | x2 = 0 48 | x3 = 0 49 | PAD='' 50 | question_multi_path={} 51 | for key,value in data_input.items(): 52 | path_questions[key]=value 53 | origin_tri=value 54 | if len(origin_tri)==0: 55 | print(key,value) 56 | continue 57 | 58 | if PAD==origin_tri[0]: 59 | x1+=1 60 | if origin_tri[2][0]=='"': 61 | e=origin_tri[2] 62 | rel=rel=origin_tri[1] 63 | ans=search_ans(e,rel,cur,reverse=True) 64 | answer[key]=[j[1] for j in ans] 65 | # if len(ans)==1 and ans[0][1][1:-1] in key: 66 | # answer[key]=[ans[0][3]] 67 | continue 68 | # ent=get_entry(origin_tri[2][1:-1],cur) 69 | # # 排序 70 | # tmp1=origin_tri[2][1:-1] 71 | # # 加入mention对应的实体 72 | # ents=[(tmp1,0)] 73 | # ents.extend([(i[2],int(i[-1])) for i in ent]) 74 | 75 | # ents=list(set(ents)) 76 | # ents=sorted(ents,key=lambda x:x[1]) 77 | ents=[(origin_tri[2][1:-1],0)] 78 | rel=origin_tri[1] 79 | for e in ents: 80 | if e[0][0]!='<': 81 | e='<'+e[0]+'>' 82 | else: 83 | e=e[0] 84 | ans=search_ans(e,rel,cur,reverse=True) 85 | if len(ans)==0: 86 | pass 87 | else: 88 | answer[key]=[j[1] for j in ans] 89 | # if len(ans)==1 and ans[0][1][1:-1] in key: 90 | # answer[key]=[ans[0][3]] 91 | break 92 | if key not in answer.keys(): 93 | print(key,value,ents) 94 | 95 | elif PAD==origin_tri[2] and origin_tri[0]!='': 96 | x2+=1 97 | # ent=get_entry(origin_tri[0][1:-1],cur) 98 | # tmp1=origin_tri[0][1:-1] 99 | # # 加入mention对应的实体 100 | # ents=[(tmp1,0)] 101 | # ents.extend([(i[2],int(i[-1])) for i in ent]) 102 | 103 | # ents=list(set(ents)) 104 | # ents=sorted(ents,key=lambda x:x[1]) 105 | 106 | ents=[(origin_tri[0][1:-1],0)] 107 | 108 | rel=origin_tri[1] 109 | 110 | #按知名度进行答案搜索 111 | for e in ents: 112 | if e[0][0]!='<': 113 | e='<'+e[0]+'>' 114 | else: 115 | e=e[0] 116 | ans=search_ans(e,rel,cur,reverse=False) 117 | 118 | if len(ans)==0: 119 | pass 120 | else: 121 | answer[key]=[j[3] for j in ans] 122 | break 123 | else: 124 | print("ERROR question:",key,value) 125 | if not ans: 126 | print("ERROR question:",key,value) 127 | print("正向:",x1,"反向",x2) 128 | print('总共:',len(answer.keys())) 129 | fff = open(fn_out,"w") 130 | json.dump(answer,fff,ensure_ascii=False) 131 | return answer 132 | 133 | def two_hop_path2ans(fns): 134 | for fn in fns: 135 | print(fn+':\n') 136 | data_input=json.load(open(fn,'r')) 137 | fn_out=fn.replace('path','ans') 138 | answer={} 139 | path_questions = {} 140 | mask='' 141 | question_multi_path={} 142 | for 
key,value in data_input.items(): 143 | path_questions[key]=value 144 | items=value 145 | if len(items)==0: 146 | print(key,value) 147 | continue 148 | 149 | if len(items)==3: 150 | if items[0]==mask: 151 | one_ans=search_ans(items[2],items[1],cur,reverse=True) 152 | answer[key]=list(set([it[1] for it in one_ans])) 153 | else: 154 | one_ans=search_ans(items[0],items[1],cur,reverse=False) 155 | answer[key]=list(set([it[3] for it in one_ans])) 156 | continue 157 | 158 | if '' not in items and len(items)==4: 159 | out=[] 160 | out1=[] 161 | out1.extend([line[3] for line in search_ans(value[0],value[1],cur)]) 162 | out1.extend([line[1] for line in search_ans(value[0],value[1],cur,reverse=True)]) 163 | # if key=='"光武中兴"说的是哪位皇帝?': 164 | # set_trace() 165 | out2=[] 166 | out2.extend([line[3] for line in search_ans(value[2],value[3],cur)]) 167 | out2.extend([line[1] for line in search_ans(value[2],value[3],cur,reverse=True)]) 168 | 169 | for item in out1: 170 | if item in out2: 171 | out.append(item) 172 | answer[key]=list(set(out)) 173 | continue 174 | 175 | if items[0][0]=='<': 176 | mention=items[0][1:-1] 177 | ent=get_entry(mention,cur) 178 | # 排序 179 | # 加入mention对应的实体 180 | ents=[(mention,0)] 181 | ents.extend([(i[2],int(i[-1])) for i in ent]) 182 | ents=list(set(ents)) 183 | ents=sorted(ents,key=lambda x:x[1]) 184 | else: 185 | ents=[items[0]] 186 | 187 | #ents=[(items[0],0)] 188 | final_ans=[] 189 | for e in ents: 190 | if e[0]=='"': 191 | one_ans=search_ans(e,items[1],cur,reverse=True) 192 | for ans in [k[1] for k in one_ans]: 193 | two_ans=search_ans(ans,items[2],cur,reverse=False) 194 | final_ans.extend([i[3] for i in two_ans]) 195 | else: 196 | e='<'+e[0]+'>' 197 | one_ans=search_ans(e,items[1],cur,reverse=False) 198 | one_out=[k[3] for k in one_ans] 199 | one_ans=search_ans(e,items[1],cur,reverse=True) 200 | one_out.extend([k[1] for k in one_ans]) 201 | for ans in one_out: 202 | two_ans=search_ans(ans,items[2],cur,reverse=False) 203 | if not two_ans: 204 | two_ans=search_ans(ans,items[2],cur,reverse=True) 205 | final_ans.extend([i[3] for i in two_ans]) 206 | if not final_ans: 207 | continue 208 | answer[key]=list(set(final_ans)) 209 | break 210 | if not final_ans: 211 | print(key,value,final_ans) 212 | fff = open(fn_out,"w") 213 | json.dump(answer,fff,ensure_ascii=False) 214 | return answer 215 | 216 | # 单个问句的prf值 217 | def one_value(pred,gold): 218 | pred=set(pred) 219 | gold=set(gold) 220 | inter=pred.intersection(gold) 221 | if len(inter)==0: 222 | p,r,f=0.0,0.0,0.0 223 | else: 224 | p=float(len(inter)/len(pred)) 225 | r=len(inter)/len(gold) 226 | f=2*p*r/(p+r) 227 | return p,r,f -------------------------------------------------------------------------------- /PathRanking/utils/search_ans.py: -------------------------------------------------------------------------------- 1 | from ans_tools import * 2 | import argparse 3 | import os 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser(description='search answer') 7 | parser.add_argument("--data_dir", default='../../PathRanking/model/saved_rich', type=str) 8 | args = parser.parse_args() 9 | 10 | data_dir = args.data_dir 11 | # 原始数据:问题答案所有信息等 12 | fn1 = '../../data/test_format.json' 13 | with open(fn1,'r')as f1: 14 | test_all_data=json.load(f1) 15 | questions = test_all_data['questions'] 16 | answers = test_all_data['answers'] 17 | qa_data = [] 18 | for q,a in zip(questions,answers): 19 | qa_data.append((q,a)) 20 | print('all qa data number:',len(qa_data)) 21 | fn_out = os.path.join(data_dir, 
'one_hop_best_path_ws_ent.json') 22 | answer=one_hop_path2ans(fn_out) 23 | all_p,all_r,all_f = 0,0,0 24 | num_one_hop=0 25 | for line in qa_data: 26 | (q,a)= line 27 | if q in answer.keys(): 28 | predict=answer[q] 29 | one_p,one_r,one_f=one_value(predict,a) 30 | if one_f <0.9: 31 | pass 32 | #print(q,a,predict) 33 | num_one_hop+=1 34 | all_f+=one_f 35 | print("one_hop F1 value %f/%d=%f"%(all_f,num_one_hop,all_f/num_one_hop)) 36 | 37 | fns = ['left_best_path_ws_ent.json','two_hop_best_path_ws_ent.json'] 38 | for i, fn in enumerate(fns): 39 | fns[i] = os.path.join(data_dir,fn) 40 | answer=two_hop_path2ans([fns[0]]) 41 | 42 | all_p,all_r,all_f = 0,0,0 43 | num_one_hop=0 44 | for line in qa_data: 45 | (q,a)= line 46 | if q in answer.keys(): 47 | predict=answer[q] 48 | one_p,one_r,one_f=one_value(predict,a) 49 | if one_f <0.9: 50 | pass 51 | #print(q,a,predict) 52 | num_one_hop+=1 53 | all_f+=one_f 54 | print("two_hop F1 value %f/%d=%f"%(all_f,num_one_hop,all_f/num_one_hop)) 55 | 56 | answer=two_hop_path2ans([fns[1]]) 57 | 58 | all_p,all_r,all_f = 0,0,0 59 | num_one_hop=0 60 | for line in qa_data: 61 | (q,a)= line 62 | if q in answer.keys(): 63 | predict=answer[q] 64 | one_p,one_r,one_f=one_value(predict,a) 65 | if one_f <0.9: 66 | pass 67 | #print(q,a,predict) 68 | num_one_hop+=1 69 | all_f+=one_f 70 | print("two_hop F1 value %f/%d=%f"%(all_f,num_one_hop,all_f/num_one_hop)) 71 | 72 | fn_out = os.path.join(data_dir,'multi_constraint_best_path_ws_ent.json') 73 | answer=multi_path2ans([fn_out]) 74 | all_p,all_r,all_f = 0,0,0 75 | num_one_hop=0 76 | for line in qa_data: 77 | (q,a)= line 78 | if q in answer.keys(): 79 | predict=answer[q] 80 | one_p,one_r,one_f=one_value(predict,a) 81 | if one_f <0.9: 82 | pass 83 | #print(q,a,predict) 84 | num_one_hop+=1 85 | all_f+=one_f 86 | print("multi constraint F1 value %f/%d=%f"%(all_f,num_one_hop,all_f/num_one_hop)) -------------------------------------------------------------------------------- /PreScreen/data/ans_tools.py: -------------------------------------------------------------------------------- 1 | # 搜索数据库找答案 2 | import sys 3 | import json 4 | import pymysql 5 | from pdb import set_trace 6 | 7 | sys.path.append('../../') 8 | from utils import * 9 | 10 | 11 | read_con = pymysql.connect(host="192.168.126.143",port = 3337,user='root', password='pjzhang', database='ccks_2019',charset='utf8') 12 | cur = read_con.cursor() 13 | 14 | def multi_path2ans_one(value): 15 | out1=[] 16 | out1.extend([line[3] for line in search_ans(value[0],value[1],cur)]) 17 | out1.extend([line[1] for line in search_ans(value[0],value[1],cur,reverse=True)]) 18 | out2=[] 19 | out2.extend([line[3] for line in search_ans(value[2],value[3],cur)]) 20 | out2.extend([line[1] for line in search_ans(value[2],value[3],cur,reverse=True)]) 21 | 22 | out = [] # 答案输出 23 | if len(value)==6: 24 | out3=[] 25 | out3.extend([line[3] for line in search_ans(value[4],value[5],cur)]) 26 | out3.extend([line[1] for line in search_ans(value[4],value[5],cur,reverse=True)]) 27 | for item in out1: 28 | if item in out2 and len(value)==6 and item in out3: 29 | out.append(item) 30 | elif item in out2 and len(value)==4: 31 | out.append(item) 32 | return out 33 | 34 | def one_hop_path2ans(origin_tri, PAD=''): 35 | 36 | if PAD==origin_tri[0]: 37 | if origin_tri[2][0]=='"': 38 | e=origin_tri[2] 39 | rel=rel=origin_tri[1] 40 | ans=search_ans(e,rel,cur,reverse=True) 41 | answer=[j[1] for j in ans] 42 | # if len(ans)==1 and ans[0][1][1:-1] in key: 43 | # answer[key]=[ans[0][3]] 44 | return answer 45 | # 
ent=get_entry(origin_tri[2][1:-1],cur) 46 | # # 排序 47 | # tmp1=origin_tri[2][1:-1] 48 | # # 加入mention对应的实体 49 | # ents=[(tmp1,0)] 50 | # ents.extend([(i[2],int(i[-1])) for i in ent]) 51 | 52 | # ents=list(set(ents)) 53 | # ents=sorted(ents,key=lambda x:x[1]) 54 | ents=[(origin_tri[2][1:-1],0)] 55 | rel=origin_tri[1] 56 | for e in ents: 57 | if e[0][0]!='<': 58 | e='<'+e[0]+'>' 59 | else: 60 | e=e[0] 61 | ans=search_ans(e,rel,cur,reverse=True) 62 | if len(ans)==0: 63 | pass 64 | else: 65 | answer=[j[1] for j in ans] 66 | # if len(ans)==1 and ans[0][1][1:-1] in key: 67 | # answer[key]=[ans[0][3]] 68 | return answer 69 | 70 | elif PAD==origin_tri[2] and origin_tri[0]!='': 71 | # ent=get_entry(origin_tri[0][1:-1],cur) 72 | # tmp1=origin_tri[0][1:-1] 73 | # # 加入mention对应的实体 74 | # ents=[(tmp1,0)] 75 | # ents.extend([(i[2],int(i[-1])) for i in ent]) 76 | 77 | # ents=list(set(ents)) 78 | # ents=sorted(ents,key=lambda x:x[1]) 79 | 80 | ents=[(origin_tri[0][1:-1],0)] 81 | rel=origin_tri[1] 82 | #按知名度进行答案搜索 83 | for e in ents: 84 | if e[0][0]!='<': 85 | e='<'+e[0]+'>' 86 | else: 87 | e=e[0] 88 | ans=search_ans(e,rel,cur,reverse=False) 89 | 90 | if len(ans)==0: 91 | pass 92 | else: 93 | answer=[j[3] for j in ans] 94 | return answer 95 | 96 | def two_hop_path2ans(path, mask=''): 97 | items=path 98 | final_ans=[] 99 | e = items[0] 100 | if e[0]=='"': 101 | one_ans=search_ans(e,items[1],cur,reverse=True) 102 | for ans in [k[1] for k in one_ans]: 103 | two_ans=search_ans(ans,items[2],cur,reverse=False) 104 | final_ans.extend([i[3] for i in two_ans]) 105 | else: 106 | one_ans=search_ans(e,items[1],cur,reverse=False) 107 | one_out=[k[3] for k in one_ans] 108 | one_ans=search_ans(e,items[1],cur,reverse=True) 109 | one_out.extend([k[1] for k in one_ans]) 110 | for ans in one_out: 111 | two_ans=search_ans(ans,items[2],cur,reverse=False) 112 | if not two_ans: 113 | two_ans=search_ans(ans,items[2],cur,reverse=True) 114 | final_ans.extend([i[3] for i in two_ans]) 115 | answer=list(set(final_ans)) 116 | return answer 117 | 118 | # 单个问句的prf值 119 | def one_value(pred,gold): 120 | pred=set(pred) 121 | gold=set(gold) 122 | inter=pred.intersection(gold) 123 | if len(inter)==0: 124 | p,r,f=0.0,0.0,0.0 125 | else: 126 | p=float(len(inter)/len(pred)) 127 | r=len(inter)/len(gold) 128 | f=2*p*r/(p+r) 129 | return p,r,f -------------------------------------------------------------------------------- /PreScreen/data/count.py: -------------------------------------------------------------------------------- 1 | import json 2 | folds=['BERT_2','BERT_5','BERT_share','BERT_15','BERT_20','BERT_TandA'] 3 | for fold in folds: 4 | with open(fold+"/mix_paths_all.json",'r')as f: 5 | data = json.load(f) 6 | all = 0 7 | num = 0 8 | maxn = 0 9 | for line in data: 10 | v = line['paths'] 11 | v = [tuple(i) for i in v] 12 | 13 | # v = list(set(v)) 14 | all += len(v) 15 | if len(v) > 0: 16 | num += 1 17 | if len(v) > maxn: 18 | maxn = len(v) 19 | print(fold) 20 | print('average number of path:') 21 | print(all/766) 22 | print('num',num) 23 | print('max',maxn) -------------------------------------------------------------------------------- /PreScreen/data/merge_path.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from argparse import ArgumentParser 4 | 5 | if __name__ == "__main__": 6 | parser = ArgumentParser(description='Next Hop For KBQA') 7 | parser.add_argument("--fn_in", default='', type=str) 8 | parser.add_argument("--fn_multi", default="", type=str) 9 | 
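# --fn_in is a list of {'q', 'paths'} records (for example the mix_paths.json
# written by mix_paths.py in the stage-2 pipeline), --fn_multi is a dict that
# maps each question to its multi-constraint path candidates, and --fn_out
# receives the merged records in which every question's path list is extended
# with those candidates (see the loop below).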
parser.add_argument("--fn_out", default="", type=str) 10 | 11 | args = parser.parse_args() 12 | with open(args.fn_in,'r')as f, open(args.fn_multi,'r')as f2, open(args.fn_out,'w')as fout: 13 | data = json.load(f) 14 | multi_data = json.load(f2) 15 | new_data = [] 16 | for line in data: 17 | q = line['q'] 18 | paths = line['paths'] 19 | if q in multi_data.keys(): 20 | m = multi_data[q] 21 | paths.extend(m) 22 | new_line = {} 23 | new_line['q'] = q 24 | new_line['paths'] = paths 25 | new_data.append(new_line) 26 | json.dump(new_data, fout, ensure_ascii=False) 27 | 28 | 29 | -------------------------------------------------------------------------------- /PreScreen/data/mix_paths.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import argparse 4 | from argparse import ArgumentParser 5 | sys.path.append('../../') 6 | from utils import * 7 | import pymysql 8 | read_con = pymysql.connect(host="192.168.126.143",port = 3337,user='root', password='pjzhang', database='ccks_2019',charset='utf8') 9 | cur = read_con.cursor() 10 | 11 | def select_twohop_path(seq, onehop_path, cur, mask=''): 12 | cand_paths = [] 13 | for path in onehop_path: 14 | if mask == path[0]: # 反向查询 15 | focus, rel = path[-1], path[1] 16 | record = search_ans(focus, rel, cur, reverse=True) 17 | ents = [line[1] for line in record] 18 | for e in ents: 19 | try: 20 | record = from_entry(e,cur) 21 | except: 22 | continue 23 | 24 | path2 = list(set([(i[2], i[3]) for i in record])) 25 | 26 | cand_paths.extend([(focus, rel, p2[0], mask) for p2 in path2 if p2[1] != focus and rel != p2[0]]) # 不找回自己 27 | 28 | cand_paths.extend([(focus,rel,p2[1],p2[0]) for p2 in path2 if p2[1] != focus and rel != p2[0] and del_des(p2[1])[1:-1] in seq]) # 答案结点在句子内出现 29 | 30 | 31 | try: 32 | record = from_value(e, cur) 33 | except: 34 | continue 35 | path2 = list(set([(i[2], i[1]) for i in record])) 36 | # 该情况下不反向查找 37 | # cand_paths.extend([(focus, rel, p2[0], mask) for p2 in path2 if p2[1] != focus and rel != p2[0]]) # 不找回自己 38 | cand_paths.extend([(focus, rel, p2[1], p2[0]) for p2 in path2 if p2[1] != focus and rel != p2[0] and del_des(p2[1])[1:-1] in seq]) # 答案结点在句子内出现 39 | 40 | elif mask == path[2]: # 正向查询 41 | focus, rel = path[0], path[1] 42 | record = search_ans(path[0], path[1], cur, reverse=False) 43 | ents = [line[3] for line in record] 44 | 45 | for e in ents: 46 | if e[0] != '"': 47 | try: 48 | record = from_entry(e,cur) 49 | except: 50 | continue 51 | 52 | path2 = list(set([(i[2], i[3]) for i in record])) 53 | 54 | cand_paths.extend([(focus, rel, p2[0], mask) for p2 in path2 if p2[1] != focus and rel != p2[0]]) # 不找回自己 55 | 56 | cand_paths.extend([(focus, rel, p2[1], p2[0]) for p2 in path2 if p2[1] != focus and rel != p2[0] and del_des(p2[1])[1:-1] in seq]) # 答案结点在句子内出现 57 | 58 | try: 59 | record = from_value(e, cur) 60 | except: 61 | continue 62 | 63 | path2 = list(set([(i[2], i[1]) for i in record])) 64 | 65 | cand_paths.extend([(focus, rel, p2[0], mask) for p2 in path2 if p2[1] != focus and rel != p2[0]]) # 不找回自己 66 | 67 | cand_paths.extend([(focus, rel, p2[1], p2[0]) for p2 in path2 if p2[1] != focus and rel != p2[0] and del_des(p2[1])[1:-1] in seq]) # 答案结点在句子内出现 68 | return list(set(cand_paths)) 69 | 70 | def intersect_items(str1,str2,cur): 71 | # str1,str带引号和尖括号标志 72 | items1,items2=[],[] 73 | items1.extend([i[1] for i in from_value(str1,cur)]) 74 | items2.extend([i[1] for i in from_value(str2,cur)]) 75 | if str1[0]=='<': 76 | items1.extend([i[3] for i in 
from_entry(str1,cur)]) 77 | if str2[0]=='<': 78 | items2.extend([i[3] for i in from_entry(str2,cur)]) 79 | # 取交集 80 | intersection=[] 81 | for e in items1: 82 | if e in items2: 83 | intersection.append(e) 84 | return intersection 85 | 86 | def link_items(str1,cur): 87 | items1=[] 88 | items1.extend([i[1] for i in from_value(str1,cur)]) 89 | if str1[0]=='<': 90 | items1.extend([normalize(i[3]) for i in from_entry(str1,cur)]) 91 | return items1 92 | 93 | #mention聚类 94 | def clustering(question,mentions,gender=False): 95 | clusters=[] 96 | mentions.sort(key=lambda x:-len(x)) 97 | for mention in mentions: 98 | # 一层判断,应对同义词替换后找不到mention位置的情况 99 | if mention not in question: 100 | continue 101 | flag=False # 标记有没有找到相应的分类 102 | start=question.index(mention) 103 | end=start+len(mention)-1 104 | # 每一个分类 105 | for i,c in enumerate(clusters): 106 | for m in c: 107 | start_m=question.index(m) 108 | end_m=start_m+len(m)-1 109 | if (start>=start_m and start<=end_m) or (end>=start_m and end<=end_m) or (startend_m): 110 | #set_trace() 111 | clusters[i].append(mention) 112 | flag=True 113 | break 114 | if flag: 115 | break 116 | if not flag: 117 | clusters.append([mention]) 118 | new_clusters=[] 119 | for clu in clusters: 120 | if len(clu[0])==1: 121 | if gender: 122 | if clu[0] in ['男','女']: 123 | new_clusters.append(clu) 124 | else: 125 | continue 126 | else: 127 | new_clusters.append(clu) 128 | # mention_longest=[cluster[0] for cluster in new_clusters] 129 | 130 | # this_question=question 131 | # for word in mention_longest: 132 | # this_question=this_question.replace(word,' '+word+'/ ') 133 | return new_clusters 134 | 135 | def clustering_with_syntax(question,mentions,syntax_out,gender=False): 136 | clusters=[] 137 | 138 | my_syntax_label='att' 139 | pattern="%d_%d_att" 140 | 141 | mentions.sort(key=lambda x:-len(x)) 142 | ws=syntax_out.split('\t')[0].split(' ') 143 | syntax=['_'.join(string.split('_')[:-1]) for string in syntax_out.split('\t')[-1].split(' ')] 144 | for mention in mentions: 145 | # 一层判断,应对同义词替换后找不到mention位置的情况 146 | if mention not in question: 147 | continue 148 | flag=False # 标记有没有找到相应的分类 149 | start=question.index(mention) 150 | end=start+len(mention)-1 151 | # 每一个分类 152 | for i,c in enumerate(clusters): 153 | for m in c: 154 | start_m=question.index(m) 155 | end_m=start_m+len(m)-1 156 | if (start>=start_m and start<=end_m) or (end>=start_m and end<=end_m) or (startend_m): 157 | #set_trace() 158 | clusters[i].append(mention) 159 | flag=True 160 | break 161 | else: 162 | if m in ws and mention in ws: 163 | i1=ws.index(m) 164 | i2=ws.index(mention) 165 | if pattern%(i1+1,i2+1) in syntax or pattern%(i2+1,i1+1) in syntax: 166 | clusters[i].append(mention) 167 | flag=True 168 | break 169 | if flag: 170 | break 171 | if not flag: 172 | clusters.append([mention]) 173 | new_clusters=[] 174 | for clu in clusters: 175 | if len(clu[0])==1: 176 | if gender: 177 | if clu[0] in ['男','女']: 178 | new_clusters.append(clu) 179 | else: 180 | continue 181 | else: 182 | new_clusters.append(clu) 183 | return new_clusters 184 | 185 | 186 | def intersect_2path(str1_list,str2_list,cur): 187 | paths=[] 188 | for str1 in str1_list: 189 | for str2 in str2_list: 190 | 191 | # 作为尾实体寻找 192 | items1,items2=[],[] 193 | items1.extend([(str1,i[2],i[1]) for i in from_value(str1,cur)]) 194 | items2.extend([(str2,i[2],i[1]) for i in from_value(str2,cur)]) 195 | 196 | # 作为头实体寻找 197 | if str1[0]=='<': 198 | items1.extend([(str1,i[2],i[3]) for i in from_entry(str1,cur)]) 199 | if str2[0]=='<': 200 | 
items2.extend([(str2,i[2],i[3]) for i in from_entry(str2,cur)]) 201 | 202 | for a in items1: 203 | item1=a[2] 204 | for b in items2: 205 | item2=b[2] 206 | if item1==item2: 207 | paths.append((a[0],a[1],b[0],b[1])) 208 | return list(set(paths)) 209 | 210 | def intersect_3path(str1_list,str2_list,str3_list,cur): 211 | paths=[] 212 | for str1 in str1_list: 213 | for str2 in str2_list: 214 | for str3 in str3_list: 215 | # 作为尾实体寻找 216 | items1,items2,items3=[],[],[] 217 | items1.extend([(str1,i[2],i[1]) for i in from_value(str1,cur)]) 218 | items2.extend([(str2,i[2],i[1]) for i in from_value(str2,cur)]) 219 | items3.extend([(str3,i[2],i[1]) for i in from_value(str3,cur)]) 220 | 221 | # 作为头实体寻找 222 | if str1[0]=='<': 223 | items1.extend([(str1,i[2],i[3]) for i in from_entry(str1,cur)]) 224 | if str2[0]=='<': 225 | items2.extend([(str2,i[2],i[3]) for i in from_entry(str2,cur)]) 226 | if str3[0]=='<': 227 | items3.extend([(str3,i[2],i[3]) for i in from_entry(str3,cur)]) 228 | 229 | for a in items1: 230 | item1=a[2] 231 | for b in items2: 232 | item2=b[2] 233 | if item1!=item2: 234 | continue 235 | for c in items3: 236 | item3=c[2] 237 | if item1==item3: 238 | #set_trace() 239 | paths.append((a[0],a[1],b[0],b[1],c[0],c[1])) 240 | return list(set(paths)) 241 | 242 | 243 | def get_paths_ent_multi(que, mentions, ents, cur, mask=''): 244 | this_paths = [] 245 | clusters=clustering(que,mentions) # 聚类 246 | topk=100 247 | if len(clusters)==2: 248 | ents=[] 249 | print(que,mentions) 250 | for ment in clusters: # 对一个类别cluster 251 | this_ent=[] 252 | for m in ment: # 对单个cluster内的单个mention 253 | ent1=[(line[2],line[3]) for line in get_entry(m,cur)] 254 | ent1.sort(key=lambda x:x[1]) 255 | this_ent.extend(['<%s>'%one[0] for one in ent1[:topk]]) 256 | 257 | this_ent.append('"%s"'%m) 258 | this_ent.append('<%s>'%m) 259 | this_ent=sorted(this_ent,key=lambda x:len(x[0])) 260 | ents.append(this_ent) 261 | #print(key,ents) 262 | # 取交集 263 | paths=intersect_2path(ents[0],ents[1],cur) 264 | this_paths=paths 265 | elif len(clusters)==3: 266 | ents=[] 267 | for ment in clusters: # 对一个类别cluster 268 | this_ent=[] 269 | for m in ment: # 对单个cluster内的单个mention 270 | ent1=[(line[2],line[3]) for line in get_entry(m,cur)] 271 | ent1.sort(key=lambda x:x[1]) 272 | this_ent.extend(['<%s>'%one[0] for one in ent1[:topk]]) 273 | this_ent.append('"%s"'%m) 274 | this_ent=sorted(this_ent,key=lambda x:len(x[0])) 275 | ents.append(this_ent) 276 | # 取交集 277 | paths=intersect_3path(ents[0],ents[1],ents[2],cur) 278 | this_paths=paths 279 | 280 | # if this_paths: 281 | # return this_paths 282 | 283 | # entities = mentions 284 | # for v in ents: 285 | # entities.append(v) 286 | # entities = list(set(entities)) 287 | # for one_e in entities: 288 | # e = "<%s>" % one_e 289 | # value = '"%s"' % one_e 290 | # # 从实体开始搜索 291 | # # 头找尾 292 | # record = [] 293 | # try: 294 | # record = list(from_entry(e,cur)) 295 | # except: 296 | # continue 297 | # path1 = list(set([(i[1],i[2],i[3]) for i in record if i[1][0] == '<' and i[3][0]=='<'])) 298 | 299 | # # 头尾尾 300 | # for p1 in path1: 301 | # tail = p1[2] 302 | # record = [] 303 | # try: 304 | # record = from_entry(tail,cur) 305 | # except: 306 | # continue 307 | # path2 = list(set([(i[3],i[2]) for i in record])) 308 | # for p2 in path2: 309 | # if p2[0] != e and p1[0][1:-1] != p2[0][1:-1]: 310 | # if p2[0][1:-1] in que or del_des(p2[0])[1:-1] in que: 311 | # this_paths.append((p1[0],p1[1],p2[0],p2[1])) 312 | # else: 313 | # pass 314 | 315 | # # 头尾头 316 | # for p1 in path1: 317 | # tail = p1[2] 318 | # 
record = [] 319 | # try: 320 | # record = from_value(tail, cur) 321 | # except: 322 | # continue 323 | # path2 = list(set([(i[1], i[2]) for i in record])) 324 | # for p2 in path2: 325 | # if p2[0] != e and p1[0][1:-1] != p2[0][1:-1]: 326 | # if p2[0][1:-1] in que or del_des(p2[0])[1:-1] in que: 327 | # this_paths.append((p1[0],p1[1],p2[0],p2[1])) 328 | # else: 329 | # pass 330 | 331 | # # 从尾部(属性值/实体)开始搜索 332 | # record = [] 333 | # path1 = [] 334 | # try: 335 | # record=list(from_value(e,cur)) 336 | # path1.extend(list(set([(i[3],i[2],i[1]) for i in record]))) 337 | # except: 338 | # pass 339 | # try: 340 | # record = list(from_value(value,cur)) 341 | # path1.extend(list(set([(i[3],i[2],i[1]) for i in record]))) 342 | # except: 343 | # pass 344 | # for p1 in path1: 345 | # tail = p1[2] 346 | # record = [] 347 | # try: 348 | # record = from_entry(tail,cur) 349 | # except: 350 | # continue 351 | # path2 = list(set([(i[3],i[2]) for i in record])) 352 | # for p2 in path2: 353 | # if p2[0] != e and p1[0][1:-1] != p2[0][1:-1]: 354 | # if p2[0][1:-1] in que or del_des(p2[0])[1:-1] in que: 355 | # this_paths.append((p1[0],p1[1],p2[0],p2[1])) 356 | # else: 357 | # pass 358 | 359 | # record = [] 360 | # try: 361 | # record = from_value(tail,cur) 362 | # except: 363 | # continue 364 | # path2 = list(set([(i[1],i[2]) for i in record])) 365 | # for p2 in path2: 366 | # if p2[0] != e and p1[0][1:-1] != p2[0][1:-1]: 367 | # if p2[0][1:-1] in que or del_des(p2[0])[1:-1] in que: 368 | # this_paths.append((p1[0],p1[1],p2[0],p2[1])) 369 | # else: 370 | # pass 371 | return list(set(this_paths)) 372 | 373 | 374 | if __name__ == "__main__": 375 | parser = ArgumentParser(description='Next Hop For KBQA') 376 | parser.add_argument("--fn_test", default='../../data/test_format.json', type=str) 377 | parser.add_argument("--fn_el", default="../../NER/data/test_el_baike_top10.json", type=str) 378 | parser.add_argument("--fn_in", default="lstm_syntax/one_hop_predict_path.json", type=str) 379 | parser.add_argument("--fn_out", default="lstm_syntax/mix_paths.json", type=str) 380 | args = parser.parse_args() 381 | 382 | with open(args.fn_in, 'r')as f: 383 | onehop_paths = json.load(f) 384 | 385 | with open(args.fn_el, 'r')as f: 386 | EL_data = [] 387 | for line in f: 388 | line = line.strip() 389 | if line: 390 | EL_data.append(json.loads(line)) 391 | 392 | with open(args.fn_test, 'r')as f: 393 | test_data = json.load(f) 394 | 395 | with open("multi_paths.json",'r')as f: 396 | multi_data = json.load(f) 397 | 398 | seq_path_dict = onehop_paths 399 | # seq_path_dict = {} 400 | # for line in onehop_paths: 401 | # seq = line['q'] 402 | # onehop_path = line['paths'] 403 | # seq_path_dict[seq] = onehop_path 404 | 405 | mix_paths = [] 406 | 407 | multi_paths = {} 408 | for line_el, seq in zip(EL_data, test_data['questions']): 409 | # multi2one path only 410 | # mentions = line_el["mentions"] 411 | # ents = [] 412 | # for k,v in line_el["ents"].items(): 413 | # ents.extend(v) 414 | # multi_path = get_paths_ent_multi(seq, mentions, ents, cur) 415 | # print("q:%s\tnum of paths:%d"%(seq, len(multi_path))) 416 | # multi_paths[seq] = multi_path 417 | 418 | # with open(args.fn_out, 'w') as f: 419 | # json.dump(multi_paths, f, ensure_ascii=False) 420 | 421 | # all paths included 422 | if seq in seq_path_dict.keys(): 423 | onehop_path = seq_path_dict[seq] 424 | all_paths = onehop_path 425 | mentions = line_el["mentions"] 426 | ents = [] 427 | for k,v in line_el["ents"].items(): 428 | ents.extend(v) 429 | twohop_path = 
select_twohop_path(seq, onehop_path, cur) 430 | # multi_path = get_paths_ent_multi(seq, mentions, ents, cur) 431 | multi_path = multi_data[seq] 432 | all_paths.extend(twohop_path) 433 | all_paths.extend(multi_path) 434 | 435 | one_data = {} 436 | one_data['q'] = seq 437 | one_data['paths'] = all_paths 438 | mix_paths.append(one_data) 439 | 440 | print("q:%s\tnum of paths:%d"%(seq, len(all_paths))) 441 | 442 | with open(args.fn_out, 'w') as f: 443 | json.dump(mix_paths, f, ensure_ascii=False) 444 | 445 | 446 | 447 | 448 | -------------------------------------------------------------------------------- /PreScreen/data/onehop_path.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import argparse 4 | from argparse import ArgumentParser 5 | sys.path.append('../../') 6 | from utils import * 7 | read_con = pymysql.connect(host="192.168.126.143",port = 3337,user='root', password='pjzhang', database='ccks_2019',charset='utf8') 8 | cur = read_con.cursor() 9 | 10 | def get_paths_ent(mentions, ents): 11 | paths = [] 12 | mask = '' 13 | for m in mentions: 14 | this_m = '"'+m+'"' 15 | this_e = '<'+m+'>' 16 | 17 | try: 18 | record = from_value(this_m, cur) 19 | except: 20 | record = [] 21 | if len(record) > 0: 22 | rels = list(set([i[2] for i in record])) 23 | paths.extend([(mask, r, this_m) for r in rels]) 24 | 25 | try: 26 | record = from_value(this_e, cur) 27 | except: 28 | record = [] 29 | if len(record) > 0: 30 | rels = list(set([i[2] for i in record])) 31 | paths.extend([(mask, r, this_e) for r in rels]) 32 | 33 | try: 34 | record = from_entry(this_e, cur) 35 | except: 36 | record = [] 37 | if len(record) > 0: 38 | rels = list(set([i[2] for i in record])) 39 | paths.extend([(this_e, r, mask) for r in rels]) 40 | 41 | for e in ents: 42 | this_e = '<'+e+'>' 43 | this_m = '"%s"' % e 44 | try: 45 | record = from_entry(this_e, cur) 46 | except: 47 | record = [] 48 | if len(record) > 0: 49 | rels = list(set([i[2] for i in record])) 50 | paths.extend([(this_e, r, mask) for r in rels]) 51 | 52 | try: 53 | record = from_value(this_e, cur) 54 | except: 55 | record = [] 56 | if len(record) > 0: 57 | rels = list(set([i[2] for i in record])) 58 | paths.extend([(mask, r, this_e) for r in rels]) 59 | 60 | try: 61 | record = from_value(this_m, cur) 62 | except: 63 | record = [] 64 | if len(record) > 0: 65 | rels = list(set([i[2] for i in record])) 66 | paths.extend([(mask, r, this_m) for r in rels]) 67 | return paths 68 | 69 | if __name__ == '__main__': 70 | # 输入文件 71 | parser = ArgumentParser(description='All One Hop For KBQA') 72 | parser.add_argument("--fn1", default='../data/test_format.json', type=str) 73 | parser.add_argument("--fn2", default="../NER/data/test_el_baike_top10.json", type=str) 74 | parser.add_argument("--fn_out", default="one_hop_paths.json", type=str) 75 | args = parser.parse_args() 76 | 77 | fn1 = args.fn1 78 | fn2 = args.fn2 79 | fn_out = args.fn_out # 输出文件 80 | 81 | label_one = 1 82 | question_paths = [] 83 | with open(fn1, 'r')as f1, open(fn2, 'r')as f2: 84 | test_all_data = json.load(f1) 85 | test_el_data = [] 86 | for line in f2: 87 | line = line.strip() 88 | if line: 89 | piece = json.loads(line) 90 | test_el_data.append(piece) 91 | for one_data_1, one_data_2 in zip(test_all_data['questions'], test_el_data): 92 | one_question_paths = {} 93 | one_q_1 = one_data_1 94 | one_q_2, mentions, entities = one_data_2['question'], one_data_2['mentions'], one_data_2['ents'] 95 | assert one_q_2 == one_q_1 96 | print(one_q_1) 97 | 
ents = [] 98 | for k, v in entities.items(): 99 | ents.extend(v) 100 | cand_paths = get_paths_ent(mentions, ents) 101 | one_question_paths['q'] = one_q_1 102 | one_question_paths['paths'] = list(set(cand_paths)) 103 | question_paths.append(one_question_paths) 104 | json.dump(question_paths, open(fn_out, 'w'), ensure_ascii=False) 105 | print("问句数量:", len(question_paths)) -------------------------------------------------------------------------------- /PreScreen/data/search_ans.py: -------------------------------------------------------------------------------- 1 | from ans_tools import * 2 | import argparse 3 | import os 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser(description='search answer') 7 | parser.add_argument("--data_dir", default='', type=str) 8 | parser.add_argument("--fn_in", default='', type=str) 9 | parser.add_argument("--fn_out", default='', type=str) 10 | args = parser.parse_args() 11 | 12 | data_dir = args.data_dir 13 | # 原始数据:问题答案所有信息等 14 | fn_out = args.fn_out 15 | mask = '' 16 | fn_in = args.fn_in 17 | with open(fn_in, 'r')as f: 18 | mix_paths = json.load(f) 19 | all_answer = {} 20 | for k,v in mix_paths.items(): 21 | #v = v[0] 22 | if len(v) == 3: 23 | answer = one_hop_path2ans(v) 24 | elif len(v) == 4 and mask in v: 25 | answer = two_hop_path2ans(v) 26 | elif mask not in v: 27 | answer = multi_path2ans_one(v) 28 | else: 29 | answer = '' 30 | all_answer[k] = answer 31 | # from pdb import set_trace 32 | # set_trace() 33 | with open(fn_out, 'w')as f: 34 | json.dump(all_answer, f, ensure_ascii=False) 35 | 36 | -------------------------------------------------------------------------------- /PreScreen/data/search_ans.sh: -------------------------------------------------------------------------------- 1 | F='merge/' 2 | python search_ans.py --fn_in $F'mix_predict_path.json' --fn_out $F'mix_answer.json' 3 | -------------------------------------------------------------------------------- /PreScreen/modules/charlstm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.utils.rnn import pack_padded_sequence 6 | 7 | 8 | class CHAR_LSTM(nn.Module): 9 | 10 | def __init__(self, n_chars, n_embed, n_out, pad_index=0): 11 | super(CHAR_LSTM, self).__init__() 12 | 13 | self.n_chars = n_chars 14 | self.n_embed = n_embed 15 | self.n_out = n_out 16 | self.pad_index = pad_index 17 | 18 | # the embedding layer 19 | self.embed = nn.Embedding(num_embeddings=n_chars, 20 | embedding_dim=n_embed) 21 | # the lstm layer 22 | self.lstm = nn.LSTM(input_size=n_embed, 23 | hidden_size=n_out//2, 24 | batch_first=True, 25 | bidirectional=True) 26 | 27 | def __repr__(self): 28 | s = self.__class__.__name__ + '(' 29 | s += f"{self.n_chars}, {self.n_embed}, " 30 | s += f"n_out={self.n_out}, " 31 | s += f"pad_index={self.pad_index}" 32 | s += ')' 33 | 34 | return s 35 | 36 | def forward(self, x): 37 | mask = x.ne(self.pad_index) 38 | lens = mask.sum(dim=1) 39 | 40 | x = pack_padded_sequence(self.embed(x), lens, True, False) 41 | x, (hidden, _) = self.lstm(x) 42 | hidden = torch.cat(torch.unbind(hidden), dim=-1) 43 | 44 | return hidden 45 | -------------------------------------------------------------------------------- /PreScreen/preprocess/check.py: -------------------------------------------------------------------------------- 1 | # import json 2 | 3 | # with open("valid_full_v2.json",'r')as f: 4 | # data = json.load(f) 5 | # questions = data['questions'] 6 | 
# gold = data['golds'] 7 | # negs = data['negs'] 8 | # # # query 9 | 10 | # index = 0 11 | # q = questions[index] 12 | # g = gold[index] 13 | # n = negs[index] 14 | # print(q) 15 | # print(g) 16 | # print(len(n)) 17 | # for one in n: 18 | # if g[:-3] in one: 19 | # print(one) 20 | 21 | ##*************************************************************************** 22 | # modify the valid data 23 | # cnt = 0 24 | # new_negs = [] 25 | # for g, n in zip(gold, negs): 26 | # if g in n: 27 | # n.remove(g) 28 | # cnt = cnt + 1 29 | # print(cnt) 30 | # new_negs.append(n) 31 | # cnt = 0 32 | # for g, n in zip(gold, new_negs): 33 | # if g in n: 34 | # cnt += 1 35 | # print(cnt) 36 | # assert len(negs) == len(new_negs) 37 | 38 | # new_data ={} 39 | # new_data['questions'] = questions 40 | # new_data['golds'] = gold 41 | # new_data['negs'] = new_negs 42 | # with open("valid_full_v2.json",'w')as f: 43 | # json.dump(new_data, f, ensure_ascii = False) 44 | import torch 45 | import jieba 46 | import json 47 | 48 | with open('../data/one_hop_paths.json','r')as f1: 49 | one_hop_data = json.load(f1) 50 | embed_dict = {} 51 | embed_data = torch.load('test.char.embed') 52 | for line, embedding in zip(one_hop_data, embed_data): 53 | q = line['q'] 54 | q_ws = [i for i in jieba.cut(q)] 55 | assert len(q_ws) == embedding.shape[0] 56 | embed_dict[q] = embedding 57 | torch.save(embed_dict, 'test.char.embed.dict') 58 | -------------------------------------------------------------------------------- /PreScreen/preprocess/func.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import sys 4 | sys.path.append('../../') 5 | from utils import * 6 | 7 | def get_one_hop_path(cur, mentions, PAD='', topk=None, addition=None, addition_bound=None,gold=None, des=False): 8 | ''' 9 | mentions: 搜索起点 10 | gold: gold path 11 | topk: 移除gold path,截取k个路径 12 | addition:额外的关系列表 13 | addition_bouund:扩充至最小数量 14 | ''' 15 | paths = [] 16 | for m in mentions: 17 | this_m='"'+m+'"' 18 | this_e = '<%s>' % m 19 | this_e_n = del_des(this_e) 20 | try: 21 | record=from_value(this_e,cur) 22 | except: 23 | record=[] 24 | 25 | if len(record) > 0: 26 | rels=list(set([i[2] for i in record])) 27 | paths.extend([(PAD, r, this_e if des else this_e_n) for r in rels]) 28 | 29 | try: 30 | record=from_value(this_m,cur) 31 | except: 32 | record=[] 33 | 34 | if len(record)>0: 35 | rels=list(set([i[2] for i in record])) 36 | paths.extend([(PAD, r, this_m) for r in rels]) 37 | 38 | try: 39 | record=from_entry(this_e,cur) 40 | except: 41 | record=[] 42 | 43 | if len(record)>0: 44 | rels=list(set([i[2] for i in record])) 45 | paths.extend([(this_e if des else this_e_n, r, PAD) for r in rels]) 46 | 47 | paths=list(set(paths)) 48 | if gold: 49 | if gold in paths: 50 | paths.remove(gold) 51 | if topk: 52 | random.shuffle(paths) 53 | paths = paths[:topk] 54 | if topk and addition and addition_bound and len(paths)',UNK=''): 8 | words_dict={ PAD:0, UNK:1 } 9 | vector_list=[] 10 | if word_vector_in: 11 | with open(word_vector_in,'r') as f: 12 | for line in f: 13 | 14 | items=line.strip().split() 15 | char=items[0] 16 | #set_trace() 17 | if len(char)==1: 18 | vector=[float(i) for i in items[1:]] 19 | words_dict[char]=len(words_dict) 20 | vector_list.append(vector) 21 | print("char from embed file:%d\n"%len(words_dict)) 22 | for one_q in seq_list: 23 | for word in one_q: 24 | if word not in words_dict: 25 | words_dict[word]=len(words_dict) 26 | word_vector=(torch.rand(len(words_dict),d_vector)-0.5)/2 # 
[-0.25,0.25]的均匀分布 27 | if word_vector_in: 28 | word_vector[2:2+len(vector_list)]=torch.Tensor(vector_list) 29 | return words_dict,word_vector 30 | 31 | class Corpus: 32 | def __init__(self, PAD='', UNK='', SOS='', EOS='', word_max_len=10): 33 | self.PAD = PAD 34 | self.UNK = UNK 35 | self.SOS = SOS 36 | self.EOS = EOS 37 | self.CHARS = [self.PAD, self.UNK, self.SOS, self.EOS] 38 | self.words_dict = {} 39 | self.word_max_len = word_max_len # 每个词中最多的字数 40 | for char in self.CHARS: 41 | self.words_dict[char] = len(self.words_dict) 42 | self.char_dict = {} 43 | for char in self.CHARS: 44 | self.char_dict[char] = len(self.char_dict) 45 | 46 | def load_embed(self, fn_embed): 47 | print("start load word embedding") 48 | vector_list=[] 49 | with open(fn_embed,'r') as f: 50 | index = -1 51 | for line in f: 52 | index += 1 53 | line = line.strip().split() 54 | if index == 0: 55 | items = line 56 | _ , dim = items[0], int(items[1]) 57 | else: 58 | items = line 59 | word = items[0] 60 | vector = [float(i) for i in items[1:]] 61 | if word not in self.words_dict: 62 | self.words_dict[word] = len(self.words_dict) 63 | vector_list.append(vector) 64 | word_vector=(torch.rand((len(self.words_dict), dim))-0.5)/2 # [-0.25,0.25]的均匀分布 65 | word_vector[len(self.CHARS):len(self.CHARS)+len(vector_list)]=torch.Tensor(vector_list) 66 | return self.words_dict, word_vector 67 | 68 | def load_data(self, fn_data, mode): 69 | # mode: train, valid, test 70 | with open(fn_data, 'r')as f: 71 | data = json.load(f) 72 | if mode == 'train' or mode == 'valid': 73 | questions, golds, negs = data['questions'], data['golds'], data['negs'] 74 | return (questions, golds, negs) 75 | elif mode == 'test': 76 | questions, cands = data['questions'], data['cands'] 77 | return (questions, cands) 78 | 79 | def len_char_dict(self): 80 | return len(self.char_dict) 81 | 82 | def dump_vocab(self, path, mode='word'): 83 | if mode == 'word': 84 | with open(path, 'w')as f: 85 | json.dump(self.words_dict, f, ensure_ascii=False) 86 | elif mode == 'char': 87 | with open(path, 'w')as f: 88 | json.dump(self.char_dict, f, ensure_ascii=False) 89 | else: 90 | print("Mode error! 
Please check the mode") 91 | 92 | def numericalize(self, sentences, mode, words_dict=None, char_dict=None, state='train'): 93 | # mode: word, char, word_char 94 | 95 | if not words_dict: 96 | words_dict = self.words_dict 97 | if not char_dict: 98 | char_dict = self.char_dict 99 | 100 | PID = words_dict.get(self.PAD) 101 | UID = words_dict.get(self.UNK) 102 | SID = words_dict.get(self.SOS) 103 | EID = words_dict.get(self.EOS) 104 | 105 | if mode == 'word': 106 | sents = [] 107 | for sentence in sentences: 108 | sentence = [word for word in jieba.cut(sentence)] 109 | sent = [words_dict.get(word, UID) for word in sentence] 110 | sents.append(sent) 111 | 112 | max_seq_len = 0 113 | for s in sents: 114 | max_seq_len = max(max_seq_len, len(s)) 115 | for i, s in enumerate(sents): 116 | sents[i] = s + [PID]*(max_seq_len - len(s)) 117 | return torch.LongTensor(sents) 118 | 119 | elif mode == 'char': 120 | sents = [] 121 | for sentence in sentences: 122 | sent = [words_dict.get(char, UID) for char in sentence] 123 | sents.append(sent) 124 | max_seq_len = 0 125 | for s in sents: 126 | max_seq_len = max(max_seq_len, len(s)) 127 | for i, s in enumerate(sents): 128 | sents[i] = s + [PID]*(max_seq_len - len(s)) 129 | return torch.LongTensor(sents) 130 | 131 | elif mode == 'word_char': 132 | sents = [] 133 | sents_char = [] 134 | for sentence in sentences: 135 | words = [word for word in jieba.cut(sentence)] 136 | sent = [words_dict.get(word, UID) for word in words] 137 | 138 | if state != 'test': 139 | for char in sentence: 140 | if char not in char_dict: 141 | char_dict[char] = len(char_dict) 142 | 143 | chars = [[char_dict.get(char, UID) for char in word[:self.word_max_len]] + [PID]*(self.word_max_len - len(word)) for word in words] 144 | sents.append(sent) 145 | sents_char.append(chars) 146 | 147 | # 对词进行pad 148 | max_seq_len = 0 149 | for s in sents: 150 | max_seq_len = max(max_seq_len, len(s)) 151 | for i, s in enumerate(sents): 152 | sents[i] = s + [PID]*(max_seq_len - len(s)) 153 | # 对字进行pad 154 | if sents_char: 155 | sents_char = pad_sequence([torch.LongTensor(line) for line in sents_char], True) 156 | return (torch.LongTensor(sents), torch.LongTensor(sents_char)) -------------------------------------------------------------------------------- /Question_classification/BERT_LSTM_char/bert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pytorch_pretrained_bert.modeling import BertModel 4 | from pytorch_pretrained_bert.tokenization import BertTokenizer 5 | from torch.nn.utils.rnn import pad_sequence 6 | import unicodedata 7 | from torch.nn.utils.rnn import pack_padded_sequence as pack 8 | from torch.nn.utils.rnn import pad_packed_sequence as pad 9 | 10 | 11 | def _is_whitespace(char): 12 | """Checks whether `chars` is a whitespace character.""" 13 | # \t, \n, and \r are technically contorl characters but we treat them 14 | # as whitespace since they are generally considered as such. 15 | if char == " " or char == "\t" or char == "\n" or char == "\r": 16 | return True 17 | cat = unicodedata.category(char) 18 | if cat == "Zs": 19 | return True 20 | return False 21 | 22 | 23 | def _is_control(char): 24 | """Checks whether `chars` is a control character.""" 25 | # These are technically control characters but we count them as whitespace 26 | # characters. 
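# Illustrative example (added note): unicodedata.category('\x07') == 'Cc', so a BEL character
# reaches the cat.startswith("C") branch below and is reported as a control character; '\t',
# by contrast, is short-circuited by the whitespace check and returns False.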
27 | if char == "\t" or char == "\n" or char == "\r": 28 | return False 29 | cat = unicodedata.category(char) 30 | if cat.startswith("C"): 31 | return True 32 | return False 33 | 34 | 35 | def _is_punctuation(char): 36 | """Checks whether `chars` is a punctuation character.""" 37 | cp = ord(char) 38 | # We treat all non-letter/number ASCII as punctuation. 39 | # Characters such as "^", "$", and "`" are not in the Unicode 40 | # Punctuation class but we treat them as punctuation anyways, for 41 | # consistency. 42 | if ( 43 | (cp >= 33 and cp <= 47) 44 | or (cp >= 58 and cp <= 64) 45 | or (cp >= 91 and cp <= 96) 46 | or (cp >= 123 and cp <= 126) 47 | ): 48 | return True 49 | cat = unicodedata.category(char) 50 | if cat.startswith("P"): 51 | return True 52 | return False 53 | 54 | 55 | def _clean_text(text): 56 | output = [] 57 | for char in text: 58 | cp = ord(char) 59 | if cp == 0 or cp == 0xFFFD or _is_control(char): 60 | continue 61 | if _is_whitespace(char): 62 | output.append(" ") 63 | else: 64 | output.append(char) 65 | return "".join(output) 66 | 67 | 68 | def judge_ignore(word): 69 | if len(_clean_text(word)) == 0: 70 | return True 71 | for char in word: 72 | cp = ord(char) 73 | if cp == 0 or cp == 0xFFFD or _is_control(char): 74 | return True 75 | return False 76 | 77 | def flatten(list_of_lists): 78 | for list in list_of_lists: 79 | for item in list: 80 | yield item 81 | 82 | class Vocab(object): 83 | def __init__(self, bert_vocab_path): 84 | self.tokenizer = BertTokenizer.from_pretrained( 85 | bert_vocab_path, do_lower_case=False 86 | ) 87 | 88 | def convert_tokens_to_ids(self, tokens): 89 | token_ids = self.tokenizer.convert_tokens_to_ids(tokens) 90 | ids = torch.tensor(token_ids, dtype=torch.long) 91 | mask = torch.ones(len(ids), dtype=torch.long) 92 | return ids, mask 93 | 94 | def subword_tokenize(self, tokens): 95 | subwords = list(map(self.tokenizer.tokenize, tokens)) 96 | subword_lengths = [1] + list(map(len, subwords)) + [1] 97 | subwords = ["[CLS]"] + list(flatten(subwords)) + ["[SEP]"] 98 | token_start_idxs = torch.cumsum(torch.tensor([0] + subword_lengths[:-1]), dim=0) 99 | return subwords, token_start_idxs 100 | 101 | def subword_tokenize_to_ids(self, tokens): 102 | tokens = ["[PAD]" if judge_ignore(t) else t for t in tokens] 103 | subwords, token_start_idxs = self.subword_tokenize(tokens) 104 | subword_ids, mask = self.convert_tokens_to_ids(subwords) 105 | token_starts = torch.zeros(len(subword_ids), dtype=torch.uint8) 106 | token_starts[token_start_idxs] = 1 107 | return subword_ids, mask, token_starts 108 | 109 | def tokenize(self, tokens): 110 | subwords = list(map(self.tokenizer.tokenize, tokens)) 111 | subword_lengths = [1] + list(map(len, subwords)) + [1] 112 | subwords = ["[CLS]"] + list(flatten(subwords)) + ["[SEP]"] 113 | return subwords 114 | 115 | class Bert_Classifier(nn.Module): 116 | def __init__(self, data): 117 | super(Bert_Classifier, self).__init__() 118 | self.bert = BertModel.from_pretrained(data.bert_path) 119 | self.args = data 120 | self.use_syntax = data.use_syntax 121 | if self.use_syntax: 122 | self.syntax_embed = nn.Embedding(data.len_syntax_dict, data.syntax_dim) 123 | self.lstm = nn.LSTM(data.bert_embedding_size+data.syntax_dim, data.hidden_dim, num_layers=data.lstm_layer, batch_first=True, bidirectional=data.bilstm) 124 | else: 125 | self.lstm = nn.LSTM(data.bert_embedding_size, data.hidden_dim, num_layers=data.lstm_layer, batch_first=True, bidirectional=data.bilstm) 126 | 127 | if data.bilstm: 128 | self.linear = 
nn.Linear(data.hidden_dim*2, data.num_labels) 129 | else: 130 | self.linear = nn.Linear(data.hidden_dim, data.num_labels) 131 | self.dropout = nn.Dropout(data.dropout) 132 | 133 | def forward(self, input_idxs, input_masks, syntax_ids=None): 134 | bert_outs, _ = self.bert( 135 | input_idxs, 136 | token_type_ids=None, 137 | attention_mask=input_masks, 138 | output_all_encoded_layers=False, 139 | ) 140 | 141 | lens = torch.sum(input_idxs.gt(0), dim=1) 142 | # bert_outs = torch.split(bert_outs[token_start], lens.tolist()) 143 | bert_outs = pad_sequence(bert_outs, batch_first=True) 144 | lstm_input = bert_outs 145 | if self.use_syntax: 146 | syntax_vec = self.syntax_embed(syntax_ids) 147 | lstm_input = torch.cat((lstm_input, syntax_vec),-1) 148 | 149 | max_len = lstm_input.size(1) 150 | lstm_input = lstm_input[:, :max_len, :] 151 | # mask = torch.arange(max_len).cuda() < lens.unsqueeze(-1) 152 | # add lstm after bert 153 | sorted_lens, sorted_idx = torch.sort(lens, dim=0, descending=True) 154 | reverse_idx = torch.sort(sorted_idx, dim=0)[1] 155 | lstm_input = lstm_input[sorted_idx] 156 | lstm_input = pack(lstm_input, sorted_lens, batch_first=True) 157 | lstm_output, (h, _) = self.lstm(lstm_input) 158 | 159 | hidden = torch.cat((h[-1, :, :], h[-2, :, :]), -1) 160 | 161 | hidden = hidden[reverse_idx] 162 | out = self.linear(torch.tanh(hidden)) 163 | return out 164 | 165 | def neg_log_likehood(self, subword_idxs, subword_masks, token_start, batch_label): 166 | bert_outs, _ = self.bert( 167 | subword_idxs, 168 | token_type_ids=None, 169 | attention_mask=subword_masks, 170 | output_all_encoded_layers=False, 171 | ) 172 | lens = token_start.sum(dim=1) 173 | 174 | #x = bert_outs[token_start] 175 | bert_outs = torch.split(bert_outs[token_start], lens.tolist()) 176 | bert_outs = pad_sequence(bert_outs, batch_first=True) 177 | max_len = bert_outs.size(1) 178 | mask = torch.arange(max_len).cuda() < lens.unsqueeze(-1) 179 | 180 | 181 | # add lstm after bert 182 | sorted_lens, sorted_idx = torch.sort(lens, dim=0, descending=True) 183 | reverse_idx = torch.sort(sorted_idx, dim=0)[1] 184 | bert_outs = bert_outs[sorted_idx] 185 | bert_outs = pack(bert_outs, sorted_lens, batch_first=True) 186 | bert_outs, hidden = self.lstm(bert_outs) 187 | bert_outs, _ = pad(bert_outs, batch_first=True) 188 | bert_outs = bert_outs[reverse_idx] 189 | 190 | 191 | out = self.linear(bert_outs) 192 | batch_size = out.size(0) 193 | seq_len = out.size(1) 194 | out = out.view(-1, out.size(2)) 195 | score = torch.nn.functional.log_softmax(out, 1) 196 | loss_function = nn.NLLLoss(ignore_index=0, reduction="sum") 197 | loss = loss_function(score, batch_label.view(-1)) 198 | _, seq = torch.max(score, 1) 199 | seq = seq.view(batch_size, seq_len) 200 | if self.args.average_loss: 201 | loss = loss / mask.float().sum() 202 | return loss,seq 203 | 204 | def extract_feature(self,subword_idxs,subword_masks,token_start,batch_label,layers): 205 | out_layers,outs = [],[] 206 | bert_outs, _ = self.bert( 207 | subword_idxs, 208 | token_type_ids=None, 209 | attention_mask=subword_masks, 210 | output_all_encoded_layers=True, 211 | ) 212 | lens = token_start.sum(dim=1) 213 | #x = bert_outs[token_start] 214 | #bert_outs = torch.split(bert_outs[token_start].cpu(), lens.tolist()) 215 | for layer in layers: 216 | out_layers.append(torch.split(bert_outs[layer][token_start].cpu(), lens.tolist())) 217 | batch_size = subword_idxs.size(0) 218 | for idx in range(batch_size): 219 | items = [] 220 | for idy,item in enumerate(out_layers): 221 | 
items.append(item[idx].unsqueeze(1))
 222 |             outs.append(torch.cat(items,dim=1))
 223 |         return outs
 224 | 
 225 | class Bert_Classifier_Pooling(nn.Module):
 226 |     def __init__(self, data):
 227 |         super(Bert_Classifier_Pooling, self).__init__()
 228 |         self.bert = BertModel.from_pretrained(data.bert_path)
 229 |         self.args = data
 230 |         self.use_syntax = data.use_syntax
 231 |         if self.use_syntax:
 232 |             self.syntax_embed = nn.Embedding(data.len_syntax_dict, data.syntax_dim)
 233 |             self.lstm = nn.LSTM(data.bert_embedding_size+data.syntax_dim, data.hidden_dim, num_layers=data.lstm_layer, batch_first=True, bidirectional=data.bilstm)
 234 |         else:
 235 |             self.lstm = nn.LSTM(data.bert_embedding_size, data.hidden_dim, num_layers=data.lstm_layer, batch_first=True, bidirectional=data.bilstm)
 236 | 
 237 |         if data.bilstm:
 238 |             self.linear = nn.Linear(data.hidden_dim*2, data.num_labels)
 239 |         else:
 240 |             self.linear = nn.Linear(data.hidden_dim, data.num_labels)
 241 |         self.dropout = nn.Dropout(data.dropout)
 242 | 
 243 |     def forward(self, input_idxs, input_masks, syntax_ids=None):
 244 |         bert_outs, _ = self.bert(
 245 |             input_idxs,
 246 |             token_type_ids=None,
 247 |             attention_mask=input_masks,
 248 |             output_all_encoded_layers=False,
 249 |         )
 250 |         lens = torch.sum(input_idxs.gt(0), dim=1)
 251 |         # bert_outs = torch.split(bert_outs[token_start], lens.tolist())
 252 |         bert_outs = pad_sequence(bert_outs, batch_first=True)
 253 |         lstm_input = bert_outs
 254 |         if self.use_syntax:
 255 |             syntax_vec = self.syntax_embed(syntax_ids)
 256 |             lstm_input = torch.cat((lstm_input, syntax_vec),-1)
 257 | 
 258 |         max_len = lstm_input.size(1)
 259 |         lstm_input = lstm_input[:, :max_len, :]
 260 |         # mask = torch.arange(max_len).cuda() < lens.unsqueeze(-1)
 261 |         # add lstm after bert
 262 |         sorted_lens, sorted_idx = torch.sort(lens, dim=0, descending=True)
 263 |         reverse_idx = torch.sort(sorted_idx, dim=0)[1]
 264 |         lstm_input = lstm_input[sorted_idx]
 265 |         lstm_input = pack(lstm_input, sorted_lens, batch_first=True)
 266 |         lstm_output, (h, _) = self.lstm(lstm_input)  # packed output of the LSTM
 267 |         output, _ = pad(lstm_output, batch_first=True)  # output: [batch, seq_len, hidden]
 268 |         output = output.permute(0, 2, 1)  # output: [batch, hidden, seq_len]
 269 | 
 270 |         output = nn.functional.max_pool1d(output, output.size(2))  # output: [batch, hidden, 1]
 271 |         output = output.squeeze(2)  # output: [batch, hidden]
 272 |         output = output[reverse_idx]
 273 |         output = torch.tanh(output)
 274 |         out = self.linear(output)
 275 |         return out
--------------------------------------------------------------------------------
/Question_classification/BERT_LSTM_word/args.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from argparse import ArgumentParser
 3 | 
 4 | def get_args():
 5 |     bert_path="../../data/"
 6 |     parser = ArgumentParser(description = 'For KBQA')
 7 |     parser.add_argument("--data_dir",default='train_data',type=str)
 8 |     parser.add_argument("--bert_path",default=bert_path,type=str)
 9 |     parser.add_argument("--bert_model", default=bert_path+'bert-base-chinese.tar.gz', type=str)
 10 |     parser.add_argument("--bert_vocab", default=bert_path+'bert-base-chinese-vocab.txt', type=str)
 11 |     parser.add_argument("--syntax_embed_path", default='../data/%s.char.embed', type=str, help='path pattern for syntax embedding')
 12 |     parser.add_argument("--task_name",default='mrpc',type=str,help="The name of the task to train.")
 13 |     parser.add_argument("--output_dir",default='saved_syntax_word',type=str)
 14 |     ## Other parameters
 15 | 
parser.add_argument("--cache_dir",default="",type=str,help="Where do you want to store the pre-trained models downloaded from s3") 16 | parser.add_argument("--max_seq_length",default=55,type=int) 17 | parser.add_argument("--do_train",default='true',help="Whether to run training.") 18 | parser.add_argument("--do_eval",default='true',help="Whether to run eval on the dev set.") 19 | parser.add_argument("--do_lower_case",action='store_false',help="Set this flag if you are using an uncased model.") 20 | parser.add_argument("--train_batch_size",default=32,type=int,help="Total batch size for training.") 21 | parser.add_argument("--batch_size",default=32,type=int,help="Total batch size.") 22 | parser.add_argument("--no_gpu",default=1,type=int,help="use no.th gpu") 23 | parser.add_argument("--eval_batch_size",default=32,type=int,help="Total batch size for eval.") 24 | parser.add_argument("--learning_rate",default=1e-5,type=float,help="The initial learning rate for Adam.") 25 | parser.add_argument("--num_train_epochs",default=100,type=float,help="Total number of training epochs to perform.") 26 | parser.add_argument("--warmup_proportion",default=0.1,type=float,) 27 | parser.add_argument("--no_cuda",action='store_true',help="Whether not to use CUDA when available") 28 | parser.add_argument("--local_rank",type=int,default=-1,help="local_rank for distributed training on gpus") 29 | parser.add_argument('--seed',type=int,default=42,help="random seed for initialization") 30 | parser.add_argument('--gradient_accumulation_steps',type=int,default=1,help="Number of updates steps to accumulate before performing a backward/update pass.") 31 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") 32 | parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") 33 | ## lstm parameters 34 | parser.add_argument("--requires_grad",action='store_true',help="Whether not to use syntax") 35 | parser.add_argument("--use_syntax",action='store_true',help="Whether not to use syntax") 36 | parser.add_argument("--syntax_hidden_embed",action='store_true',help="Whether not to use embedding of syntax") 37 | parser.add_argument("--maxpooling",action='store_true',help="Whether not to use maxpooling") 38 | parser.add_argument("--avepooling",action='store_true',help="Whether not to use avepooling") 39 | parser.add_argument("--bert_embedding_size",default=768,type=int) 40 | parser.add_argument("--hidden_dim",default=300,type=int) 41 | parser.add_argument("--lstm_layer",default=1,type=int) 42 | parser.add_argument("--bilstm",action='store_true') 43 | parser.add_argument("--len_syntax_dict",default=30,type=int) 44 | parser.add_argument("--syntax_dim",default=50,type=int) 45 | parser.add_argument("--num_labels",default=4,type=int) 46 | parser.add_argument("--dropout",default=0.5,type=float) 47 | args = parser.parse_args() 48 | return args -------------------------------------------------------------------------------- /Question_classification/BERT_LSTM_word/compare.py: -------------------------------------------------------------------------------- 1 | fn1 = 'saved_word/test_result.txt' 2 | fn2 = 'saved_syntax_word_embed_2/test_result.txt' 3 | #label = '0' 4 | 5 | for label in ['1', '2', '3','0']: 6 | num1 = 0 7 | num2 = 0 8 | with open(fn1,'r')as f1, open(fn2, 'r')as f2: 9 | for line1, line2 in zip(f1,f2): 10 | line1 = line1.strip() 11 | line2 = line2.strip() 12 | if line1 and line2: 13 | q = line1.split('\t')[0] 14 | for symbol in ['[',']',"'",',']: 
15 | q = q.replace(symbol,'') 16 | gold1 = line1.split('\t')[-2] 17 | gold2 = line2.split('\t')[-2] 18 | pred1 = line1.split('\t')[-1] 19 | pred2 = line2.split('\t')[-1] 20 | assert gold1 == gold2 21 | 22 | if gold1 == label and pred1 != pred2: 23 | if pred2 == label: 24 | print('file2:%s'%''.join(q)) 25 | num2 += 1 26 | elif pred1 == label: 27 | print('file1:%s'%''.join(q)) 28 | num1 += 1 29 | print('gold label:%s, file1 right %d, file2 right %d'%(label, num1, num2)) 30 | -------------------------------------------------------------------------------- /Question_classification/BERT_LSTM_word/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import argparse 3 | import csv 4 | import logging 5 | import os 6 | import random 7 | import sys 8 | import numpy as np 9 | import torch 10 | from args import * 11 | #from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,TensorDataset) 12 | from torch.utils.data.distributed import DistributedSampler 13 | from tqdm import tqdm, trange 14 | from torch.nn import CrossEntropyLoss, MSELoss 15 | #from scipy.stats import pearsonr, spearmanr 16 | #from sklearn.metrics import matthews_corrcoef, f1_score 17 | from argparse import ArgumentParser 18 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME 19 | from bert import * 20 | #from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig 21 | from pytorch_pretrained_bert.tokenization import BertTokenizer 22 | from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule 23 | logger = logging.getLogger(__name__) 24 | 25 | def test(model, examples, bert_field, syn_field, dataloader, args, label_map,device): 26 | ''' 27 | return: 28 | evaluate:(n_test_correct, n_test_data, 1.0*n_test_correct/n_test_data) 29 | result:[(gold label, predict label)] 30 | ''' 31 | model.eval() 32 | test_examples = examples 33 | all_input_ids, all_input_lens, all_input_mask = bert_field.numericalize([example.text_a for example in test_examples]) 34 | 35 | if args.use_syntax: 36 | all_syntax_ids = syn_field.numericalize([example.text_syntax for example in test_examples]) 37 | all_syntax_embed = torch.load(args.syntax_embed_path%'test') 38 | 39 | all_label_ids = torch.tensor([label_map[f.label] for f in test_examples], dtype=torch.long) 40 | 41 | if args.use_syntax: 42 | if args.syntax_hidden_embed: 43 | test_batches = dataloader.get_batches(all_input_ids,all_input_lens,all_input_mask, all_label_ids, syntaxs=all_syntax_embed, hidden_embed=args.syntax_hidden_embed) 44 | else: 45 | test_batches = dataloader.get_batches(all_input_ids,all_input_lens,all_input_mask, all_label_ids, syntaxs=all_syntax_ids, hidden_embed=args.syntax_hidden_embed) 46 | else: 47 | test_batches = dataloader.get_batches(all_input_ids,all_input_lens,all_input_mask, all_label_ids) 48 | 49 | test_batch_num = len(test_batches) 50 | 51 | n_test_correct = 0 52 | n_test_data = 0 53 | results = [] 54 | for batch in test_batches: 55 | batch = tuple(t.to(device) for t in batch) 56 | if args.use_syntax: 57 | input_ids, input_lens, input_mask, label_ids, syntax_ids = batch 58 | else: 59 | input_ids, input_lens, input_mask, label_ids = batch 60 | 61 | if args.use_syntax: 62 | logits = model(input_ids,input_lens,input_mask,syntax_ids) 63 | else: 64 | logits = model(input_ids,input_lens,input_mask) 65 | 66 | predict_batch = torch.max(logits, 1)[1].data 67 | 
n_test_correct += torch.sum((predict_batch == label_ids.data)) 68 | n_test_data += logits.size(0) 69 | for predict,label in zip(predict_batch,label_ids.data): 70 | results.append((label, predict)) 71 | return (n_test_correct, n_test_data, n_test_correct/(1.0*n_test_data)),results 72 | 73 | 74 | def main(): 75 | args = get_args() 76 | # 句法标签 77 | syn_field = SyntaxField() 78 | args.len_syntax_dict = syn_field.len_syntax_dict 79 | 80 | if args.server_ip and args.server_port: 81 | import ptvsd 82 | print("Waiting for debugger attach") 83 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 84 | ptvsd.wait_for_attach() 85 | 86 | no_gpu = args.no_gpu 87 | n_gpu = 1 88 | device = torch.device("cuda", no_gpu) 89 | print("使用GPU%d" % no_gpu) 90 | print(" ") 91 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 92 | 93 | random.seed(args.seed) 94 | np.random.seed(args.seed) 95 | torch.manual_seed(args.seed) 96 | if n_gpu > 0: 97 | torch.cuda.manual_seed_all(args.seed) 98 | 99 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 100 | pass 101 | #raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 102 | if not os.path.exists(args.output_dir): 103 | os.makedirs(args.output_dir) 104 | 105 | processor = MrpcProcessor() 106 | output_mode = "classification" 107 | 108 | label_list = processor.get_labels() 109 | label_map = {label : i for i, label in enumerate(label_list)} 110 | # label数量 111 | num_labels = len(label_list) 112 | 113 | # 分字器 114 | tokenizer = BertTokenizer.from_pretrained(args.bert_vocab, do_lower_case=args.do_lower_case) 115 | bert_field = BertField('BERT',tokenizer=tokenizer) 116 | 117 | print("loaded tokenizer") 118 | 119 | train_examples = None 120 | num_train_optimization_steps = None 121 | if args.do_train: 122 | train_examples = processor.get_train_examples(args.data_dir) 123 | num_train_optimization_steps = int( 124 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs 125 | # Prepare model 126 | cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE)) 127 | if args.maxpooling or args.avepooling: 128 | if args.syntax_hidden_embed: 129 | model = Bert_Classifier_Pooling_hidden(args) 130 | else: 131 | model = Bert_Classifier_Pooling(args) 132 | else: 133 | model = Bert_Classifier(args) 134 | print("loaded Bert model") 135 | model.to(device) 136 | 137 | # Prepare optimizer 138 | if args.do_train: 139 | param_optimizer = list(model.named_parameters()) 140 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 141 | optimizer_grouped_parameters = [ 142 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 143 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 144 | ] 145 | optimizer = BertAdam(optimizer_grouped_parameters, 146 | lr=args.learning_rate, 147 | warmup=args.warmup_proportion, 148 | t_total=num_train_optimization_steps) 149 | 150 | global_step = 0 151 | nb_tr_steps = 0 152 | tr_loss = 0 153 | 154 | # dataloader定义 155 | dataloader = DataLoader(args.batch_size) 156 | ### eval dataloader 157 | if args.do_eval: 158 | best_precision = 0 159 | patience = 10 160 | iters_left = patience 161 | eval_examples = processor.get_dev_examples(args.data_dir) 162 | all_input_ids, all_input_lens, all_input_mask = 
bert_field.numericalize([example.text_a for example in eval_examples]) 163 | 164 | if args.use_syntax: 165 | all_syntax_ids = syn_field.numericalize([example.text_syntax for example in eval_examples]) 166 | all_syntax_embed = torch.load(args.syntax_embed_path%'valid') 167 | 168 | all_label_ids = torch.tensor([label_map[f.label] for f in eval_examples], dtype=torch.long) 169 | 170 | # 输入控制 171 | if args.use_syntax: 172 | if args.syntax_hidden_embed: 173 | eval_batches = dataloader.get_batches(all_input_ids,all_input_lens,all_input_mask, all_label_ids, syntaxs=all_syntax_embed, hidden_embed=args.syntax_hidden_embed) 174 | else: 175 | eval_batches = dataloader.get_batches(all_input_ids,all_input_lens,all_input_mask, all_label_ids, syntaxs=all_syntax_ids, hidden_embed=args.syntax_hidden_embed) 176 | else: 177 | eval_batches = dataloader.get_batches(all_input_ids,all_input_lens,all_input_mask, all_label_ids) 178 | eval_batch_num = len(eval_batches) 179 | 180 | ### train dataloader 181 | if args.do_train: 182 | n_batch_correct = 0 183 | len_train_data = 0 184 | i_train_step = 0 185 | 186 | all_input_ids, all_input_lens, all_input_mask = bert_field.numericalize([example.text_a for example in train_examples]) 187 | if args.use_syntax: 188 | all_syntax_ids = syn_field.numericalize([example.text_syntax for example in train_examples]) 189 | all_syntax_embed = torch.load(args.syntax_embed_path%'train') 190 | all_label_ids = torch.tensor([label_map[f.label] for f in train_examples], dtype=torch.long) 191 | # syntax information 192 | if args.use_syntax: 193 | if args.syntax_hidden_embed: 194 | train_batches = dataloader.get_batches(all_input_ids,all_input_lens,all_input_mask, all_label_ids, syntaxs=all_syntax_embed,shuffle=True, hidden_embed=args.syntax_hidden_embed) 195 | else: 196 | train_batches = dataloader.get_batches(all_input_ids,all_input_lens,all_input_mask, all_label_ids, syntaxs=all_syntax_ids,shuffle=True, hidden_embed=args.syntax_hidden_embed) 197 | else: 198 | train_batches = dataloader.get_batches(all_input_ids,all_input_lens,all_input_mask, all_label_ids,shuffle=True) 199 | 200 | train_batch_num = len(train_batches) 201 | loss_fct = CrossEntropyLoss() 202 | 203 | for epoch in range(int(args.num_train_epochs)): 204 | model.train() 205 | tr_loss = 0 206 | len_train_data, n_batch_correct = 0, 0 207 | nb_tr_examples, nb_tr_steps = 0, 0 208 | for step, batch in enumerate(train_batches): 209 | i_train_step += 1 210 | 211 | batch = tuple(t.to(device) for t in batch) 212 | if args.use_syntax: 213 | input_ids, input_lens, input_mask, label_ids, syntax_ids = batch 214 | else: 215 | input_ids, input_lens, input_mask, label_ids = batch 216 | if args.use_syntax: 217 | logits = model(input_ids,input_lens,input_mask,syntax_ids) 218 | else: 219 | logits = model(input_ids,input_lens,input_mask) 220 | 221 | loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) 222 | if epoch == 100: 223 | print(torch.argmax(logits, dim=1)) 224 | print(label_ids) 225 | from pdb import set_trace 226 | set_trace() 227 | loss.backward() 228 | n_batch_correct += torch.sum((torch.max(logits, 1)[1].data == label_ids.data)) 229 | len_train_data += logits.size(0) 230 | if i_train_step % train_batch_num == 0: 231 | P_train = 1. 
* int(n_batch_correct)/len_train_data 232 | print(" ") 233 | print("-------------------------------------------------------------------") 234 | print("epoch:%d\ttrain_Accuracy-----------------------%d/%d=%f\n"%(epoch,n_batch_correct.data,len_train_data,P_train)) 235 | tr_loss += loss.item() 236 | nb_tr_examples += input_ids.size(0) 237 | nb_tr_steps += 1 238 | 239 | if (step + 1) % args.gradient_accumulation_steps == 0: 240 | optimizer.step() 241 | optimizer.zero_grad() 242 | global_step += 1 243 | 244 | print('loss',tr_loss) 245 | model.to(device) 246 | if args.do_eval: 247 | model.eval() 248 | eval_loss = 0 249 | nb_eval_steps = 0 250 | preds = [] 251 | n_dev_batch_correct = 0 252 | len_dev_data = 0 253 | i_dev_times = 0 254 | P_dev = 0 255 | for batch in eval_batches: 256 | i_dev_times += 1 257 | batch = tuple(t.to(device) for t in batch) 258 | if args.use_syntax: 259 | input_ids, input_lens, input_mask, label_ids, syntax_ids = batch 260 | else: 261 | input_ids, input_lens, input_mask, label_ids = batch 262 | 263 | with torch.no_grad(): 264 | if args.use_syntax: 265 | logits = model(input_ids,input_lens,input_mask,syntax_ids) 266 | else: 267 | logits = model(input_ids,input_lens,input_mask) 268 | 269 | if epoch == 100: 270 | print(torch.argmax(logits, dim=1)) 271 | print(label_ids) 272 | from pdb import set_trace 273 | set_trace() 274 | # create eval loss and other metric required by the task 275 | tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) 276 | 277 | n_dev_batch_correct += torch.sum((torch.max(logits, 1)[1].data == label_ids.data)) 278 | len_dev_data += logits.size(0) 279 | 280 | eval_loss += tmp_eval_loss.mean().item() 281 | nb_eval_steps += 1 282 | if len(preds) == 0: 283 | preds.append(logits.detach().cpu().numpy()) 284 | else: 285 | preds[0] = np.append( 286 | preds[0], logits.detach().cpu().numpy(), axis=0) 287 | 288 | # from pdb import set_trace 289 | # set_trace() 290 | print('loss',eval_loss) 291 | logger.info('loss',eval_loss) 292 | P_dev = 1. 
* int(n_dev_batch_correct)/len_dev_data 293 | print() 294 | print(" ") 295 | print("-------------------------------------------------------------------") 296 | print("epoch:%d\tdev_Accuracy-----------------------%d/%d=%f\n"%(epoch,n_dev_batch_correct.data,len_dev_data,P_dev)) 297 | logger.info("epoch:%d\tdev_Accuracy-----------------------%d/%d=%f\n"%(epoch,n_dev_batch_correct.data,len_dev_data,P_dev)) 298 | if P_dev > best_precision: 299 | best_precision = P_dev 300 | iters_left = patience 301 | if args.do_eval: 302 | print("epoch %d saved\n"%epoch) 303 | logger.info("epoch %d saved\n"%epoch) 304 | torch.save(model.state_dict(),args.output_dir+'/model_best.pkl') 305 | else: 306 | iters_left-=1 307 | if iters_left == 0: 308 | break 309 | eval_loss = eval_loss / nb_eval_steps 310 | preds = preds[0] 311 | if output_mode == "classification": 312 | preds = np.argmax(preds, axis=1) 313 | elif output_mode == "regression": 314 | preds = np.squeeze(preds) 315 | print('lr', args.learning_rate, 'best_epoch', epoch-patience, 'best precision', best_precision) 316 | 317 | # test 318 | examples = processor.get_test_examples(args.data_dir) 319 | target, results = test(model, examples, bert_field, syn_field, dataloader, args, label_map, device) 320 | print("test_Accuracy-----------------------%d/%d=%f\n"%target) 321 | logger.info("test_Accuracy-----------------------%d/%d=%f\n"%target) 322 | with open(os.path.join(args.output_dir,'test_result.txt'),'w') as f: 323 | questions = [example.text_a for example in examples] 324 | for q,result in zip(questions,results): 325 | f.write("%s\t%d\t%d\n"%(q,result[0],result[1])) 326 | 327 | if __name__ == "__main__": 328 | main() 329 | 330 | -------------------------------------------------------------------------------- /Question_classification/BERT_LSTM_word/run.sh: -------------------------------------------------------------------------------- 1 | # nohup python main.py --use_syntax --maxpooling --requires_grad --output_dir 'saved_syntax_word' >log.txt & 2 | # python main.py --maxpooling --requires_grad --output_dir 'saved_word' >log.txt & 3 | nohup python main.py --no_gpu 4 --maxpooling --requires_grad --output_dir 'saved_syntax_word_embed_2' >log_embed.txt & -------------------------------------------------------------------------------- /Question_classification/BERT_LSTM_word/test_acc.py: -------------------------------------------------------------------------------- 1 | multi_label_dict = {} 2 | with open('../../data/multi_label_result.txt','r')as f: 3 | for line in f: 4 | line = line.strip() 5 | if not line: 6 | continue 7 | else: 8 | q = line.split('\t')[0][:-1] 9 | multi_label_dict[q.replace(' ','')] = 1 10 | 11 | print(multi_label_dict) 12 | # 13 | fn = 'saved_syntax_word_embed_2/test_result.txt' 14 | holder = ['[', ']', "'", ',', ' '] 15 | with open(fn, 'r')as f: 16 | total = 0 17 | right = 0 18 | for line in f: 19 | line= line.strip() 20 | if not line: 21 | continue 22 | items = line.split('\t') 23 | q = items[0] 24 | for i in holder: 25 | q = q.replace(i,'') 26 | #print(q) 27 | gold, pred = int(items[-2]),int(items[-1]) 28 | gold = [gold] 29 | if q in multi_label_dict: 30 | gold.append(multi_label_dict[q]) 31 | total += 1 32 | if pred in gold: 33 | right += 1 34 | print('acc %d/%d=%f'%(right, total, 1.0*right/total)) 35 | -------------------------------------------------------------------------------- /Question_classification/data/convert_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | # pip 
install requests first
import requests
#from pytorch_pretrained_bert.tokenization import BertTokenizer
import torch

if __name__ == '__main__':
    # url = "http://192.168.126.171:5001/api"
    # headers = {"Content-Type": "application/json; charset=UTF-8"}
    # modes = ['train', 'valid', 'test']

    # for mode in modes:
    #     # with open("../../data/%s.json"%mode,'r')as f, open('../../data/%s_format.json'%mode,'w')as f_format:
    #     #     data = json.load(f)
    #     #     [qindex_list,q_list,s_list,a_list,e_list,rel_list,mention_list,type_list] = data
    #     #     new_data = {'qindex':qindex_list,'questions':q_list,'sqls':s_list,'answers':a_list,'ents':e_list,'triples':rel_list,'mentions':mention_list,'types':type_list}
    #     #     json.dump(new_data, f_format, ensure_ascii=False)

    #     with open('../../data/%s_format.json'%mode,'r')as f, open('%s.tsv'%mode,'w')as fout:
    #         data = json.load(f)
    #         questions, types = data['questions'], data['types']
    #         for q,t in zip(questions, types):
    #             t = int(t)
    #             if t in [1,2]:
    #                 label = 1
    #             elif t in [4,5,7]:
    #                 label = 2
    #             elif t in [6,8,9,10]:
    #                 label = 3
    #             else:
    #                 label = 0
    #             fout.write('%s\t%s\n'%(q,label))

    # for mode in modes:
    #     seqs = []
    #     labels = []
    #     fn = '%s.tsv' % mode
    #     fn_out = '%s_syntax_word.tsv' % mode
    #     with open(fn, 'r', encoding='utf-8')as f:
    #         for line in f:
    #             one_words, label = line.strip().split('\t')[0].replace(' ',''), line.strip().split('\t')[1]
    #             seqs.append(one_words)
    #             labels.append(label)
    #     #input_json = {"words": words, "ws": True, "pos": True, "dep": True}
    #     input_json = {"input_string": seqs, "ws": True, "pos": True, "dep": True}
    #     response = requests.post(url, data=json.dumps(input_json), headers=headers)
    #     outs = response.json()
    #     print(outs.keys())
    #     with open(fn_out, 'w', encoding='utf-8')as f_out:
    #         for seq, syns, label in zip(outs['words'], outs['rels'], labels):
    #             f_out.write("%s\t%s\t%s\n" % (' '.join(seq), ' '.join(syns), label))

    # Inspect the cached character-embedding files for each split.
    modes = ['train', 'valid', 'test']
    for mode in modes:
        data = torch.load('%s.char.embed'%mode)
        print(len(data))

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CCKS2019-CKBQA

A system for CCKS2019-CKBQA.

For details of the approach, see the paper 《中文知识库问答中的路径选择》 (Path Selection for Chinese Knowledge Base Question Answering): http://jcip.cipsc.org.cn/CN/abstract/abstract3196.shtml

Knowledge base download (BaiduNetDisk): https://pan.baidu.com/s/1XSH-kkzGZa49uE9oFY-GpQ (extraction code: 7e5z)


## Dependency

python 3
pytorch==3.5
pytorch-pretrained-bert==0.4


## Importing the knowledge base

1. Install MySQL

   MySQL installation guide: https://blog.csdn.net/tianpy5/article/details/79842888

   Enabling remote access: https://blog.csdn.net/h985161183/article/details/82218710

   pymysqlpool installation guide: https://www.cnblogs.com/z-x-y/p/9481908.html

   Install pymysqlpool: pip install PyMysqlPool

2. Import the knowledge base into the database

   Follow KB/kb_processing.ipynb to create the database.

   Useful MySQL commands:

   Show all databases: show databases;

   Create the database: create database ccks;

   Select the database to use: use ccks;

   Show the tables in the current database: show tables;

   Count the rows in a table: select count(*) from pkubase;

   Show the last 6 rows of a table: select * from pkubase order by id desc limit 0,6;

   Show the name of the database currently in use: select database();

   Show a table's schema: desc pkuprop;

   Note that num in varchar(num) counts characters, not bytes, when creating a table.

   Change the password: update mysql.user set authentication_string=password('yhjia') where user='root';

   A minimal Python connectivity check is sketched right after this list.
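As a quick sanity check after the import, you can query the tables directly from Python. The snippet below is an illustrative sketch added here, not code from the repository; it assumes a local MySQL server, the ccks database created above, and placeholder credentials (host, user, password) that you should replace with your own. It uses plain pymysql rather than the PyMysqlPool wrapper.

```python
# Minimal connectivity check for the imported KB (illustrative sketch only).
# host / user / password below are placeholders -- use your own settings.
import pymysql

conn = pymysql.connect(host='127.0.0.1', user='root', password='your_password',
                       database='ccks', charset='utf8mb4')
try:
    with conn.cursor() as cur:
        cur.execute('SELECT COUNT(*) FROM pkubase;')                     # total number of triples
        print('pkubase rows:', cur.fetchone()[0])
        cur.execute('SELECT * FROM pkubase ORDER BY id DESC LIMIT 6;')   # last 6 rows
        for row in cur.fetchall():
            print(row)
finally:
    conn.close()
```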
## Preprocessing

1. Dataset

   mkdir data

   You can download train/dev/test from https://github.com/pkumod/CKBQA and put them into data/.

2. Preprocess

   Preprocess.ipynb

   Preprocesses the raw datasets (train/dev/test) and generates NER/data/train_bert_ner_input.txt, valid_bert_ner_input.txt and test_bert_ner_input.txt, which are used to train the NER model in the next step.


## NER

1. Entity recognition

   cd NER

   mkdir snapshot

   sh ccks_run.sh

   For training, set the status field in ccks_bert.cfg to train; for prediction, set it to tag.
   The trained NER model is saved to snapshot/modelbest.pkl.

2. Word segmentation by matching against the knowledge base

   python ws.py

   The segmentation output is written to data/questions_ws.txt.

   For each question, the first line is the question itself, the second line is the result of forward maximum matching (using KB aliases as the vocabulary), and the third line is the result of entity matching (also using KB aliases as the vocabulary).

3. Refine the NER results with the knowledge base and perform entity linking

   Run 实体识别的优化与实体链接.ipynb.

   Generates data/test_er_out.json (purpose: to be documented)
   Generates data/test_er_out_baike.json (purpose: to be documented)
   Generates data/test_el_baike_top10.json (purpose: to be documented)

## Training the semantic similarity model

### Generating training data

cd PreScreen/preprocess/

Run data.ipynb.
Generates ../../data/train.json and valid.json: training and validation data for the semantic similarity (path ranking) model.


### Training

cd PathRanking/model/

mkdir saved_sharebert_negfix

sh train.sh
The trained semantic similarity model is saved to saved_sharebert_negfix/pytorch_model.bin.

## Training the question classification model

cd Question_classification/BERT_LSTM_word
sh run.sh


## Prediction

## Method 1: based on question classification

This method first classifies the question, then retrieves candidate paths for the predicted class, and finally ranks them with the semantic similarity matching model.
![Question classes](question_classes.png)

## Question classification

to do

## Method 2: based on beam search

This method assumes that answer paths are at most two hops long; at each hop the top-k best partial paths are kept (an illustrative beam-search sketch is appended at the end of this document).

### Prediction steps (to do: the folder layout is messy and needs cleanup)



### Prerequisite: the semantic similarity matching model has been trained

### Step 1: search one-hop paths

cd PreScreen/data/
python onehop_path.py
Generates ./one_hop_paths.json

### Step 2: predict top-k one-hop paths

mkdir PreScreen/data/merge

cd PathRanking/model/

sh predict_stage1.sh
Generates PreScreen/data/merge/one_hop_predict_path.json (purpose: to be documented)

### Step 3: search two-hop paths

cd PathRanking/model/
sh search_path_stage2.sh
Generates PreScreen/data/merge/mix_paths.json (purpose: to be documented)
Generates PreScreen/data/merge/mix_paths_all.json (purpose: to be documented)

### Step 4: predict the top-k among all mixed one-hop and two-hop paths

cd PathRanking/model/

sh predict_stage2.sh (note: change the input file paths_all_merge.json referenced here to the mix_paths_all_merge.json generated by search_path_stage2.sh in the previous step)
Generates PreScreen/data/merge/mix_predict_path.json

### Step 5: retrieve the final answers

cd PreScreen/data/
sh search_ans.sh
Generates PreScreen/data/merge/mix_answer.json (purpose: to be documented)

### Step 6: evaluate the predictions

Note: update the path of the answer file, then run evaluation_answer.ipynb

## Results

Average F1:

![avatar](results.png)
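The averaged F1 above is computed by evaluation_answer.ipynb. For reference, here is a minimal sketch of how an averaged F1 over predicted vs. gold answer sets is typically computed; the per-question precision/recall definition below is an assumption for illustration, not code taken from the notebook.

```python
# Sketch of averaged F1 over answer sets (assumed metric definition;
# see evaluation_answer.ipynb for the authoritative implementation).
def average_f1(pred_answers, gold_answers):
    """pred_answers, gold_answers: lists of answer sets, one set per question."""
    total_f1 = 0.0
    for pred, gold in zip(pred_answers, gold_answers):
        pred, gold = set(pred), set(gold)
        overlap = len(pred & gold)
        if overlap == 0:
            continue  # precision = recall = F1 = 0 for this question
        p = overlap / len(pred)
        r = overlap / len(gold)
        total_f1 += 2 * p * r / (p + r)
    return total_f1 / len(gold_answers)

# Example: one question answered perfectly, one half-right -> 0.75
print(average_f1([{'a'}, {'b', 'x'}], [{'a'}, {'b', 'c'}]))
```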
--------------------------------------------------------------------------------
/question_classes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThisIsSoMe/CCKS2019-CKBQA/142169a07b285d147cb70c2d118b2d7df3cd7836/question_classes.png
--------------------------------------------------------------------------------
/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThisIsSoMe/CCKS2019-CKBQA/142169a07b285d147cb70c2d118b2d7df3cd7836/results.png
--------------------------------------------------------------------------------
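Appendix (illustrative sketch, not part of the repository): the beam search described under Method 2 in README.md keeps the top-k partial paths at each hop and stops after two hops. The toy KB and the functions get_linked_entities, expand_one_hop and score_paths below are hypothetical stand-ins for the entity-linking output, the MySQL-backed KB lookup and the BERT semantic-similarity model; only the control flow is meant to be indicative.

```python
# Beam search over KB paths of at most two hops (illustrative sketch only).
# TOY_KB and the three helper functions are placeholders, not the real pipeline.

TOY_KB = {
    '姚明': [('妻子', '叶莉'), ('身高', '2.26米')],
    '叶莉': [('职业', '篮球运动员')],
}

def get_linked_entities(question):
    # Placeholder entity linking: any toy-KB entity mentioned in the question.
    return [e for e in TOY_KB if e in question]

def expand_one_hop(path):
    # Placeholder KB lookup: relations leaving the last node of the path.
    return TOY_KB.get(path[-1], [])

def score_paths(question, paths):
    # Placeholder scorer: character overlap with the question
    # (stands in for the BERT semantic-similarity model).
    return [sum(ch in question for ch in ''.join(p)) for p in paths]

def beam_search_paths(question, topk=2, max_hops=2):
    beams = [(e,) for e in get_linked_entities(question)]  # start from linked entities
    kept = []                                              # top-k paths from every hop
    for _ in range(max_hops):
        candidates = [path + (rel, node)
                      for path in beams
                      for rel, node in expand_one_hop(path)]
        if not candidates:
            break
        ranked = sorted(zip(candidates, score_paths(question, candidates)),
                        key=lambda x: x[1], reverse=True)
        beams = [p for p, _ in ranked[:topk]]              # keep top-k per hop
        kept.extend(beams)
    # Rank the mixed one-hop and two-hop candidates one final time.
    ranked = sorted(zip(kept, score_paths(question, kept)),
                    key=lambda x: x[1], reverse=True)
    return [p for p, _ in ranked[:topk]]

print(beam_search_paths('姚明的妻子的职业是什么？'))
```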