├── .gitignore
├── README.md
├── doc
│   └── result.png
├── splitter
│   ├── __init__.py
│   ├── spark.py
│   └── utils.py
└── test
    ├── moyan.txt
    └── out.txt

/.gitignore:
--------------------------------------------------------------------------------
# Created by .ignore support plugin (hsz.mobi)

*.pyc
.idea/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# spark_splitter

Word-frequency statistics over a large corpus with Spark. The implementation follows the [wordmaker](https://github.com/jannson/wordmaker) project, which is worth a look: it uses quite a few tricky optimizations to squeeze out performance. If you only want to understand the algorithm, the wordmaker author also published a [simplified version](https://github.com/jannson/yaha/blob/master/extra/segword.cpp) of the code.

wordmaker produces good statistics and runs fast, but its drawbacks are obvious: it only handles GBK-encoded documents and cannot run distributed. Since I had just started working with Spark, I reimplemented its algorithm in Python so that it can handle larger corpora and accepts both GBK and UTF-8 input.

## Segmentation algorithm

wordmaker implements an algorithm for discovering the vocabulary of a large corpus. Unlike jieba, it relies neither on a pre-built dictionary nor on a hidden Markov model, yet still produces good results. The original author's documentation describes the approach as: several threads independently compute word statistics for their own text blocks, the blocks are merged segment by segment in word order, and for each segment the probability that a character sequence forms a word plus its left and right entropy are computed to produce the final word list. In detail, the steps are:

1. Read the text, strip line breaks, whitespace and punctuation, and break the corpus into sentences that contain only Chinese characters.
2. Cut every sentence from step 1 into candidate words of every length and count how often each candidate appears. The cutting simply enumerates every possible substring; for example, 天气真好 is cut into 天 天气 天气真 天气真好 气 气真 气真好 真 真好 好. To avoid pointless work, a maximum candidate length is enforced.
3. To weed out candidates that are really several words glued together, each candidate is checked against its own splits. For example, 月亮 and 星星 both appear with very high frequency among the step-2 candidates, so 月亮星星 is judged not to be a single word and is dropped.
4. To remove further bad splits, the notion of [information entropy](http://www.baike.com/wiki/%E4%BF%A1%E6%81%AF%E7%86%B5) is used. For example, in 邯郸 the character 邯 is only ever followed by 郸, so the right entropy of 邯 is 0 and splitting it off as a word would be wrong. The characters inside a genuine Chinese word are tightly bound, so tearing a word apart necessarily yields low entropy; a suitable threshold filters out these bad splits.

## Code walkthrough

The original C++ code is fairly long, but the Python rewrite is short. Steps 1-3 above map onto Spark very naturally; the code lives in the `split` method:

```python
def split(self):
    """Steps 1-3: count candidate phrases with Spark."""
    raw_rdd = self.sc.textFile(self.corpus_path)

    utf_rdd = raw_rdd.map(lambda line: str_decode(line))            # decode UTF-8 / GBK
    hanzi_rdd = utf_rdd.flatMap(lambda line: extract_hanzi(line))   # keep only hanzi runs

    # enumerate all substrings up to PHRASE_MAX_LENGTH with their counts
    raw_phrase_rdd = hanzi_rdd.flatMap(lambda sentence: cut_sentence(sentence))

    phrase_rdd = raw_phrase_rdd.reduceByKey(lambda x, y: x + y)
    phrase_dict_map = dict(phrase_rdd.collect())
    total_count = 0
    for _, freq in phrase_dict_map.iteritems():
        total_count += freq

    def _filter(pair):
        # step 3: drop candidates that look like two frequent words glued together
        phrase, frequency = pair
        max_ff = 0
        for i in xrange(1, len(phrase)):
            left = phrase[:i]
            right = phrase[i:]
            left_f = phrase_dict_map.get(left, 0)
            right_f = phrase_dict_map.get(right, 0)
            max_ff = max(left_f * right_f, max_ff)
        return total_count * frequency / max_ff > 100.0

    target_phrase_rdd = phrase_rdd.filter(lambda x: len(x[0]) >= 2 and x[1] >= 3)
    result_phrase_rdd = target_phrase_rdd.filter(lambda x: _filter(x))
    self.result_phrase_set = set(result_phrase_rdd.keys().collect())
    self.phrase_dict_map = {key: PhraseInfo(val) for key, val in phrase_dict_map.iteritems()}
```

The set that survives the step-3 filter is already fairly small, so it can simply be collected into driver memory and the entropy filter applied there. When `target_phrase_rdd.filter(lambda x: _filter(x))` runs inside `split`, `phrase_dict_map` could be turned into a Spark broadcast variable to make the distributed computation more efficient; since everything ran on a single machine, I did not bother.
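For reference, here is a minimal sketch of what that broadcast variant could look like. The helper name `build_broadcast_filter` and its wiring are illustrative assumptions, not code from this repository:

```python
# Hypothetical sketch: the same logic as _filter, but the phrase counts are read
# from a Spark broadcast variable instead of being captured in the task closure.
def build_broadcast_filter(sc, phrase_dict_map, total_count):
    bc_phrase_map = sc.broadcast(phrase_dict_map)   # shipped to each executor once
    bc_total = sc.broadcast(total_count)

    def _filter(pair):
        phrase, frequency = pair
        phrase_map = bc_phrase_map.value
        max_ff = 0
        for i in xrange(1, len(phrase)):
            left_f = phrase_map.get(phrase[:i], 0)
            right_f = phrase_map.get(phrase[i:], 0)
            max_ff = max(left_f * right_f, max_ff)
        return bc_total.value * frequency / max_ff > 100.0

    return _filter

# inside split():
# result_phrase_rdd = target_phrase_rdd.filter(
#     build_broadcast_filter(self.sc, phrase_dict_map, total_count))
```

A broadcast variable is cached on each executor once instead of being re-serialized into every task closure, which starts to matter once `phrase_dict_map` grows to millions of entries.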
## Results

Change into the spark_splitter/splitter directory and run `PYTHONPATH=. spark-submit spark.py` to process test/moyan.txt, which is the collected works of 莫言 (Mo Yan). The finished counts are written to test/out.txt; part of the output looks like this:

![result](doc/result.png)

(I have no idea why that particular word shows up so often.)

## Known issues

1. The algorithm cannot strip connective/particle characters, so the output contains many entries such as 轻轻地 and 等待着.
2. On a single machine Spark has no advantage for data of this size; it is much slower than the C++ code in wordmaker.
--------------------------------------------------------------------------------
/doc/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiuRoy/spark_splitter/4c88e0113ed7bf02af466c815ff889347ac41bf2/doc/result.png
--------------------------------------------------------------------------------
/splitter/__init__.py:
--------------------------------------------------------------------------------
# -*- coding=utf8 -*-
--------------------------------------------------------------------------------
/splitter/spark.py:
--------------------------------------------------------------------------------
# -*- coding=utf8 -*-
from __future__ import division
from math import log

from pyspark import SparkContext, SparkConf
from utils import str_decode, extract_hanzi, cut_sentence


def init_spark_context():
    """Initialize the Spark context."""
    conf = SparkConf().setAppName("spark-splitter")
    sc = SparkContext(conf=conf, pyFiles=['utils.py'])
    return sc


class PhraseInfo(object):
    """Frequency and neighbouring-character statistics for a phrase."""

    def __init__(self, frequency):
        self.frequency = frequency
        self.left_trim = {}
        self.right_trim = {}

    @staticmethod
    def calc_entropy(trim):
        """Entropy of a neighbouring-character distribution."""
        if not trim:
            return float('-inf')

        trim_sum = sum(trim.values())
        entropy = 0.0
        for _, value in trim.iteritems():
            p = value / trim_sum
            entropy -= p * log(p)
        return entropy

    def calc_is_keep_right(self):
        right_entropy = self.calc_entropy(self.right_trim)
        if right_entropy < 1.0:
            return False
        return True

    def calc_is_keep_left(self):
        left_entropy = self.calc_entropy(self.left_trim)
        if left_entropy < 1.0:
            return False
        return True

    def calc_is_keep(self):
        if self.calc_is_keep_left() and self.calc_is_keep_right():
            return True
        return False


class SplitterEngine(object):
    """Word-discovery engine."""

    def __init__(self, spark_context, corpus):
        self.sc = spark_context
        self.corpus_path = corpus

        self.result_phrase_set = None
        self.phrase_dict_map = None
        self.final_result = {}

    def split(self):
        """Steps 1-3: count candidate phrases with Spark."""
        raw_rdd = self.sc.textFile(self.corpus_path)

        utf_rdd = raw_rdd.map(lambda line: str_decode(line))            # decode UTF-8 / GBK
        hanzi_rdd = utf_rdd.flatMap(lambda line: extract_hanzi(line))   # keep only hanzi runs

        # enumerate all substrings up to PHRASE_MAX_LENGTH with their counts
        raw_phrase_rdd = hanzi_rdd.flatMap(lambda sentence: cut_sentence(sentence))

        phrase_rdd = raw_phrase_rdd.reduceByKey(lambda x, y: x + y)
        phrase_dict_map = dict(phrase_rdd.collect())
        total_count = 0
        for _, freq in phrase_dict_map.iteritems():
            total_count += freq

        def _filter(pair):
            # step 3: drop candidates that look like two frequent words glued together
            phrase, frequency = pair
            max_ff = 0
            for i in xrange(1, len(phrase)):
                left = phrase[:i]
                right = phrase[i:]
                left_f = phrase_dict_map.get(left, 0)
                right_f = phrase_dict_map.get(right, 0)
                max_ff = max(left_f * right_f, max_ff)
            return total_count * frequency / max_ff > 100.0

        target_phrase_rdd = phrase_rdd.filter(lambda x: len(x[0]) >= 2 and x[1] >= 3)
        result_phrase_rdd = target_phrase_rdd.filter(lambda x: _filter(x))
        self.result_phrase_set = set(result_phrase_rdd.keys().collect())
        self.phrase_dict_map = {key: PhraseInfo(val) for key, val in phrase_dict_map.iteritems()}

    def post_process(self):
        """Step 4: accumulate neighbour counts and filter by entropy."""
        for phrase, phrase_info in self.phrase_dict_map.iteritems():
            if len(phrase) < 3:
                continue
            freq = phrase_info.frequency

            # a longer candidate that extends a result phrase by one character
            # contributes that character to the phrase's neighbour statistics
            left_trim = phrase[:1]
            right_part = phrase[1:]
            if right_part in self.result_phrase_set \
                    and right_part in self.phrase_dict_map:
                p_info = self.phrase_dict_map[right_part]
                p_info.left_trim[left_trim] = p_info.left_trim.get(left_trim, 0) + freq

            right_trim = phrase[-1:]
            left_part = phrase[:-1]
            if left_part in self.result_phrase_set \
                    and left_part in self.phrase_dict_map:
                p_info = self.phrase_dict_map[left_part]
                p_info.right_trim[right_trim] = p_info.right_trim.get(right_trim, 0) + freq

        for words in self.result_phrase_set:
            if words not in self.phrase_dict_map:
                continue

            words_info = self.phrase_dict_map[words]
            if words_info.calc_is_keep():
                self.final_result[words] = words_info.frequency

    def save(self, out):
        """Write the result as phrase<TAB>frequency lines."""
        with open(out, 'w') as f:
            for phrase, frequency in self.final_result.iteritems():
                f.write('{}\t{}\n'.format(phrase.encode('utf8'), frequency))


if __name__ == '__main__':
    spark_context = init_spark_context()
    corpus_path = '../test/moyan.txt'
    out_path = '../test/out.txt'
    engine = SplitterEngine(spark_context, corpus_path)

    engine.split()
    engine.post_process()
    engine.save(out_path)
--------------------------------------------------------------------------------
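To make the entropy threshold in `PhraseInfo` concrete, here is a small standalone sketch (illustrative only, not a file in the repository) that applies the same computation to the 邯郸 example from the README, using made-up neighbour counts:

```python
# -*- coding=utf8 -*-
# Illustrative sketch: the same entropy computation as PhraseInfo.calc_entropy,
# applied to two invented neighbour distributions.
from __future__ import division
from math import log


def calc_entropy(trim):
    if not trim:
        return float('-inf')
    trim_sum = sum(trim.values())
    entropy = 0.0
    for value in trim.values():
        p = value / trim_sum
        entropy -= p * log(p)
    return entropy

# 邯 is only ever followed by 郸: right entropy 0.0, below the 1.0 threshold, so dropped
print calc_entropy({u'郸': 120})
# a phrase followed by many different characters: entropy about 1.1, so kept
print calc_entropy({u'人': 30, u'光': 25, u'色': 28})
```

In the actual filter, `calc_is_keep` only keeps a phrase when both its left and right entropy reach 1.0.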
/splitter/utils.py:
--------------------------------------------------------------------------------
# -*- coding=utf8 -*-

import re

hanzi_re = re.compile(u"[\u4E00-\u9FD5]+", re.U)
PHRASE_MAX_LENGTH = 6


def str_decode(sentence):
    """Decode a byte string to unicode, trying UTF-8 first and then GBK."""
    if not isinstance(sentence, unicode):
        try:
            sentence = sentence.decode('utf-8')
        except UnicodeDecodeError:
            sentence = sentence.decode('gbk', 'ignore')
    return sentence


def extract_hanzi(sentence):
    """Extract the maximal runs of Chinese characters from a line."""
    return hanzi_re.findall(sentence)


def cut_sentence(sentence):
    """Enumerate every substring of length 1..PHRASE_MAX_LENGTH with its count."""
    result = {}
    sentence_length = len(sentence)
    for i in xrange(sentence_length):
        # substrings starting at i, at most PHRASE_MAX_LENGTH characters long,
        # including those that reach the end of the sentence
        for j in xrange(1, min(sentence_length - i, PHRASE_MAX_LENGTH) + 1):
            tmp = sentence[i: j + i]
            result[tmp] = result.get(tmp, 0) + 1
    return result.items()
--------------------------------------------------------------------------------
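As a quick check of `cut_sentence`, the 天气真好 example from the README can be reproduced with a few lines (an illustrative snippet, not a file in the repository; it assumes utils.py is importable, e.g. when run from the splitter/ directory):

```python
# -*- coding=utf8 -*-
# Illustrative usage of splitter/utils.py with the README's example sentence.
from utils import extract_hanzi, cut_sentence

line = u"天气真好!"
for sentence in extract_hanzi(line):        # punctuation is stripped, leaving u'天气真好'
    for phrase, count in sorted(cut_sentence(sentence)):
        print phrase.encode('utf8'), count
# prints the ten candidates 天 天气 天气真 天气真好 好 气 气真 气真好 真 真好,
# each with a count of 1
```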