├── README.md
├── word_segment.py
└── syntax_parsing.py

/README.md:
--------------------------------------------------------------------------------
# BERT-based Unsupervised Word Segmentation and Syntax Parsing

A simple verification, on Chinese, of the method proposed in [Perturbed Masking: Parameter-free Probing for Analyzing and Interpreting BERT](https://arxiv.org/abs/2004.14786).

- Blog write-up: https://kexue.fm/archives/7476
- Original implementation: https://github.com/LividWo/Perturbed-Masking

# Demo

Unsupervised word segmentation results:
```python
[u'习近平', u'总书记', u'6月', u'8日', u'赴', u'宁夏', u'考察', u'调研', u'。', u'当天', u'下午', u',他先后', u'来到', u'吴忠', u'市', u'红寺堡镇', u'弘德', u'村', u'、黄河', u'吴忠', u'市城区段、', u'金星', u'镇金花园', u'社区', u',', u'了解', u'当地', u'推进', u'脱贫', u'攻坚', u'、', u'加强', u'黄河流域', u'生态', u'保护', u'、', u'促进', u'民族团结', u'等', u'情况', u'。']

[u'大肠杆菌', u'是', u'人和', u'许多', u'动物', u'肠道', u'中最', u'主要', u'且数量', u'最多', u'的', u'一种', u'细菌']

[u'苏剑林', u'是', u'科学', u'空间', u'的博主']

[u'九寨沟', u'国家级', u'自然', u'保护', u'区', u'位于', u'四川', u'省', u'阿坝藏族羌族', u'自治', u'州', u'南坪县境内', u',', u'距离', u'成都市400多公里', u',', u'是', u'一条', u'纵深', u'40余公里', u'的山沟谷', u'地']
```

Unsupervised syntax parsing: see the example output at the end of syntax_parsing.py.

# Discussion
QQ group: 67729435. For the WeChat group, add the bot account spaces_ac_cn on WeChat.
--------------------------------------------------------------------------------
/word_segment.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-
# Unsupervised word segmentation with BERT
# Write-up: https://kexue.fm/archives/7476

import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import uniout  # imported for its side effect: legible printing of unicode lists (Python 2)

# BERT configuration
config_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_L-12_H-768_A-12/vocab.txt'

tokenizer = Tokenizer(dict_path, do_lower_case=True)  # build the tokenizer
model = build_transformer_model(config_path, checkpoint_path)  # build the model and load the weights

# Encode the text
text = u'大肠杆菌是人和许多动物肠道中最主要且数量最多的一种细菌'
token_ids, segment_ids = tokenizer.encode(text)
length = len(token_ids) - 2  # number of characters, excluding [CLS] and [SEP]


def dist(x, y):
    """Distance function (Euclidean distance by default).
    Inner product or cosine distance can be used instead; the results are similar.
    """
    return np.sqrt(((x - y)**2).sum())


# Build 2 * length - 1 perturbed copies of the input:
# row 2*i masks character i only; row 2*i-1 masks the adjacent pair (i-1, i).
batch_token_ids = np.array([token_ids] * (2 * length - 1))
batch_segment_ids = np.zeros_like(batch_token_ids)

for i in range(length):
    if i > 0:
        batch_token_ids[2 * i - 1, i] = tokenizer._token_mask_id
        batch_token_ids[2 * i - 1, i + 1] = tokenizer._token_mask_id
    batch_token_ids[2 * i, i + 1] = tokenizer._token_mask_id

vectors = model.predict([batch_token_ids, batch_segment_ids])

# For each adjacent character pair, measure how much additionally masking the
# neighbour changes each character's masked representation; a large change
# means strong mutual influence, so the pair is kept in the same word.
threshold = 8
word_token_ids = [[token_ids[1]]]
for i in range(1, length):
    d1 = dist(vectors[2 * i, i + 1], vectors[2 * i - 1, i + 1])
    d2 = dist(vectors[2 * i - 2, i], vectors[2 * i - 1, i])
    d = (d1 + d2) / 2
    if d >= threshold:
        word_token_ids[-1].append(token_ids[i + 1])
    else:
        word_token_ids.append([token_ids[i + 1]])

words = [tokenizer.decode(ids) for ids in word_token_ids]
print(words)
# Result: [u'大肠杆菌', u'是', u'人和', u'许多', u'动物', u'肠道', u'中最', u'主要', u'且数量', u'最多', u'的', u'一种', u'细菌']
--------------------------------------------------------------------------------
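word_segment.py runs end-to-end for a single hard-coded sentence. To reproduce the other demo sentences from the README, or to tune the threshold, the same logic can be wrapped into a reusable function. The sketch below is not part of the repository; `segment` is a hypothetical helper that assumes a `tokenizer` and `model` built exactly as in word_segment.py.

```python
import numpy as np


def segment(text, tokenizer, model, threshold=8):
    """Perturbed-masking segmentation, mirroring word_segment.py (sketch)."""
    token_ids, _ = tokenizer.encode(text)
    length = len(token_ids) - 2  # characters only, excluding [CLS]/[SEP]

    # Row 2*i masks character i; row 2*i-1 masks the adjacent pair (i-1, i).
    batch_token_ids = np.array([token_ids] * (2 * length - 1))
    for i in range(length):
        if i > 0:
            batch_token_ids[2 * i - 1, i] = tokenizer._token_mask_id
            batch_token_ids[2 * i - 1, i + 1] = tokenizer._token_mask_id
        batch_token_ids[2 * i, i + 1] = tokenizer._token_mask_id
    vectors = model.predict([batch_token_ids, np.zeros_like(batch_token_ids)])

    def dist(x, y):
        return np.sqrt(((x - y) ** 2).sum())

    word_token_ids = [[token_ids[1]]]
    for i in range(1, length):
        # Average change of the two neighbouring representations when the
        # neighbour is additionally masked: a large change keeps them together.
        d = (dist(vectors[2 * i, i + 1], vectors[2 * i - 1, i + 1]) +
             dist(vectors[2 * i - 2, i], vectors[2 * i - 1, i])) / 2
        if d >= threshold:
            word_token_ids[-1].append(token_ids[i + 1])
        else:
            word_token_ids.append([token_ids[i + 1]])
    return [tokenizer.decode(ids) for ids in word_token_ids]


# e.g. print(segment(u'苏剑林是科学空间的博主', tokenizer, model))
```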
/syntax_parsing.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-
# Unsupervised extraction of syntactic structure with BERT
# Write-up: https://kexue.fm/archives/7476

import json
import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import uniout  # imported for its side effect: legible printing of unicode (Python 2)
import jieba

# BERT configuration
config_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_L-12_H-768_A-12/vocab.txt'

tokenizer = Tokenizer(dict_path, do_lower_case=True)  # build the tokenizer
model = build_transformer_model(config_path, checkpoint_path)  # build the model and load the weights

# Encode the text; jieba provides the word-level spans that will be masked.
text = u'计算机的鼠标有什么比较特殊的用途呢'
words = jieba.lcut(text)
spans = []
token_ids = [tokenizer._token_start_id]
for w in words:
    w_ids = tokenizer.encode(w)[0][1:-1]
    token_ids.extend(w_ids)
    spans.append((len(token_ids) - len(w_ids), len(token_ids)))

token_ids.append(tokenizer._token_end_id)
length = len(spans)


def dist(x, y):
    """Distance function (Euclidean distance by default).
    Inner product or cosine distance can be used instead; the results are similar.
    """
    return np.sqrt(((x - y)**2).sum())


# One perturbed copy per unordered word pair (i, j), i <= j: mask both spans.
batch_token_ids = np.array([token_ids] * (length * (length + 1) // 2))
batch_segment_ids = np.zeros_like(batch_token_ids)
k, mapping = 0, {}
for i in range(length):
    for j in range(i, length):
        mapping[i, j] = k
        batch_token_ids[k, spans[i][0]:spans[i][1]] = tokenizer._token_mask_id
        batch_token_ids[k, spans[j][0]:spans[j][1]] = tokenizer._token_mask_id
        k += 1

vectors = model.predict([batch_token_ids, batch_segment_ids])
distances = np.zeros((length, length))

# distances[i, j]: how much additionally masking word j changes the
# (mean-pooled) representation of word i when word i itself is masked.
for i in range(length):
    for j in range(i + 1, length):
        vi = vectors[mapping[i, i], spans[i][0]:spans[i][1]].mean(0)
        vij = vectors[mapping[i, j], spans[i][0]:spans[i][1]].mean(0)
        distances[i, j] = dist(vi, vij)
        vj = vectors[mapping[j, j], spans[j][0]:spans[j][1]].mean(0)
        vji = vectors[mapping[i, j], spans[j][0]:spans[j][1]].mean(0)
        distances[j, i] = dist(vj, vji)


def build_tree(words, distances):
    """Recursively parse the hierarchical structure of the sentence:
    split at the position that maximizes within-block influence minus
    cross-block influence.
    """
    if len(words) == 1:
        return [words[0]]
    elif len(words) == 2:
        return [words[0], [words[1]]]
    else:
        k = np.argmax([
            distances[:i, :i].mean() + distances[i:, i:].mean() -
            distances[:i, i:].mean() - distances[i:, :i].mean()
            for i in range(1, len(words) - 1)
        ]) + 1
        return [
            build_tree(words[:k], distances[:k, :k]),
            [words[k],
             build_tree(words[k + 1:], distances[k + 1:, k + 1:])]
        ]


# Simple visualization with json.dumps
print(json.dumps(build_tree(words, distances), indent=4, ensure_ascii=False))
"""Output:
[
    [
        [
            "计算机"
        ],
        [
            "的",
            [
                "鼠标"
            ]
        ]
    ],
    [
        "有",
        [
            [
                [
                    "什么"
                ],
                [
                    "比较",
                    [
                        "特殊"
                    ]
                ]
            ],
            [
                "的",
                [
                    "用途",
                    [
                        "呢"
                    ]
                ]
            ]
        ]
    ]
]
"""
--------------------------------------------------------------------------------
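For quick inspection, the nested lists returned by `build_tree` can also be flattened into a single bracketed line instead of the json.dumps view. This is a minimal sketch; `to_brackets` is a hypothetical helper, not part of the repository.

```python
def to_brackets(tree):
    """Render build_tree's nested-list output as one bracketed string (sketch)."""
    if isinstance(tree, list):
        return u'(' + u' '.join(to_brackets(t) for t in tree) + u')'
    return tree


# print(to_brackets(build_tree(words, distances)))
# With the output shown above, this gives:
# (((计算机) (的 (鼠标))) (有 (((什么) (比较 (特殊))) (的 (用途 (呢))))))
```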