├── .gitignore
├── README.md
├── index
├── new.py
├── report.doc
└── stop

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
/data/
/venv/
data.zip
/.idea/
train_word_dict
*.doc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

## Naive Bayes classifier for spam classification

How the Naive Bayes classifier works: [Naive Bayes classifier - Wikipedia](https://en.wikipedia.org/wiki/Naive_Bayes_classifier)

Using Naive Bayes to filter spam: [Bayesian Inference and Its Internet Applications (Part 2): Filtering Spam - Ruan Yifeng's Blog](http://www.ruanyifeng.com/blog/2011/08/bayesian_inference_part_two.html)

Google Drive [data.zip](https://drive.google.com/open?id=15Yi14PBw9P1pb045_aIRa-C3cdP0PKT_) (check your network access first),

OneDrive [data.zip](https://1drv.ms/u/s!ApLylQlrHpBQaIR8C52V1dfX7FE?e=HbfJih)

Unzip the dataset into the repository root.

Project structure:

![image](https://user-images.githubusercontent.com/48375763/159438565-3c7741a1-6bb8-404f-abd1-b764813d21c9.png)

> The `train_word_dict` file only stores experiment output; it does not need to be created beforehand.

Run with: `python new.py`

## Environment

* Python 3.7
* Packages: jieba, pandas (`codecs`, `re`, and `math` come from the standard library)

## Experimental data and results

![image](https://user-images.githubusercontent.com/48375763/159437202-080e5cb7-84fc-4742-b885-e0f3d58c936e.png)

Part of the trained dictionary looks like this:

```
{'非': {0: 337, 1: 784},
 '财务': {0: 72, 1: 8821},
 '纠淼': {0: 0, 1: 214},
 '牟': {0: 1, 1: 149},
 '莆': {0: 1, 1: 183},
 '窆': {0: 7, 1: 130},
 '沙盘': {0: 0, 1: 117},
 '模拟': {0: 56, 1: 761},
 '运用': {0: 95, 1: 1267}, ...}
```

With 20% of the data as the test set, 80% as the training set, and Laplace smoothing parameter λ = 1:

```
Total test samples 12924
Correct predictions 11326
Incorrect predictions 1598
Accuracy: 0.8763540699473847
```

## Step-by-step walkthrough

### 1. Collect the data: spam/ham mail texts and a stop-word list

The repository contains two files, `index` and `stop`: the former lists each mail's path and label, the latter is the stop-word list. Mails fall into two classes: `spam` (junk mail) and `ham` (normal mail).

First, read the label-path list from `index` and the stop words from `stop`; this is done by `load_formatted_data()` and `load_stop_word()`.

Because the data lives in a `DataFrame`, doing the preprocessing with lambda expressions passed to `apply` keeps the code concise and readable.

In `main`:

```python
if __name__ == '__main__':
    index_list = load_formatted_data()
    stop_words = load_stop_word()
```

The corresponding functions:

```python
def load_formatted_data():
    """
    Load the formatted label-path list.
    The spam column is 1 for junk mail, 0 for normal mail.
    The path column is the path of the mail file.
    :return: (DataFrame) index
    """
    # Load the dataset index
    index = pd.read_csv('index', sep=' ', names=['spam', 'path'])
    index.spam = index.spam.apply(lambda x: 1 if x == 'spam' else 0)
    # Drop the leading '.' so the path is relative to the repository root
    index.path = index.path.apply(lambda x: x[1:])
    return index
```

```python
def load_stop_word():
    """
    Read the stop-word list.
    :return: (List) _stop_words
    """
    with codecs.open("stop", "r") as f:
        lines = f.readlines()
    _stop_words = [i.strip() for i in lines]
    return _stop_words
```
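For reference, a minimal standalone sketch of the same parsing logic run on two made-up index lines (the labels and paths below are hypothetical; the real `index` file ships with the dataset):

```python
import pandas as pd
from io import StringIO

# Two made-up lines in the same "label path" format as the dataset's index file.
sample = StringIO("spam ../data/000/000\nham ../data/000/001\n")

index = pd.read_csv(sample, sep=' ', names=['spam', 'path'])
index.spam = index.spam.apply(lambda x: 1 if x == 'spam' else 0)  # spam -> 1, ham -> 0
index.path = index.path.apply(lambda x: x[1:])                    # "../data/..." -> "./data/..."
print(index)
#    spam            path
# 0     1  ./data/000/000
# 1     0  ./data/000/001
```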
### 2. Load the data with pandas and preprocess it with lambda expressions

Using the label-path list from the previous step, read each mail into a single text string, then build a word dictionary for each mail of the form `{word: 1}`. This is a set-of-words model: for each mail we only record whether a word occurs in it, not how often.

In `main`:

```python
if __name__ == '__main__':
    ...
    index_list['content'] = index_list.path.apply(lambda x: get_mail_content(x))
    index_list['word_dict'] = index_list.content.apply(lambda x: create_word_dict(x, stop_words))
```

The corresponding functions:

```python
def get_mail_content(path):
    """
    Read one mail and return its text as a single string.
    :param path: path of the mail file
    :return: (Str) content
    """
    with codecs.open(path, "r", encoding="gbk", errors="ignore") as f:
        lines = f.readlines()

    for i in range(len(lines)):
        if lines[i] == '\n':
            # Discard everything before the first blank line, i.e. the mail headers
            lines = lines[i:]
            break
    content = ''.join(''.join(lines).strip().split())
    return content
```

```python
def create_word_dict(content, stop_words_list):
    """
    Record which words occur in the mail text, dropping any word found in the stop-word list.
    :param content: the mail text as a single string
    :param stop_words_list: stop-word list
    :return: (Dict) word_dict
    """
    word_list = []
    word_dict = {}
    # word_dict key: word, value: 1
    # Keep only Chinese characters
    content = re.findall(u"[\u4e00-\u9fa5]", content)
    content = ''.join(content)
    word_list_temp = jieba.cut(content)
    for word in word_list_temp:
        if word != '' and word not in stop_words_list:
            word_list.append(word)
    for word in word_list:
        word_dict[word] = 1
    return word_dict
```

Chinese word segmentation is done with the third-party library `jieba`.

### 3. Train the algorithm

First split the data into training and test sets; here the first 80% of the data is used for training and the remaining 20% for testing.

```python
if __name__ == '__main__':
    ...
    split = int(len(index_list) * 0.8)
    train_set = index_list.iloc[:split]
    test_set = index_list.iloc[split:]
```

Training means counting, for every word in the training set, how many ham mails and how many spam mails it occurs in. These counts provide the data for the prior and posterior probabilities.

```python
if __name__ == '__main__':
    ...
    train_word_dict, spam_count, ham_count = train_dataset(train_set)
```

```python
def train_dataset(dataset_to_train):
    """
    Train on the dataset: count, for each word, how many ham and spam mails it occurs in.
    :param dataset_to_train: the dataset to train on
    :return: Tuple(word-count dictionary _train_word_dict, number of spam mails spam_count, number of ham mails ham_count)
    """
    _train_word_dict = {}
    # _train_word_dict: for every word, its occurrence count in ham (key 0) and spam (key 1) mails
    for word_dict, spam in zip(dataset_to_train.word_dict, dataset_to_train.spam):
        # word_dict: the word set of one mail; spam: the label of that mail
        for word in word_dict:
            # For every word of every mail, increment the count for that mail's class
            _train_word_dict.setdefault(word, {0: 0, 1: 0})
            _train_word_dict[word][spam] += 1
    ham_count = dataset_to_train.spam.value_counts()[0]
    spam_count = dataset_to_train.spam.value_counts()[1]
    return _train_word_dict, spam_count, ham_count
```
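To make the structure of `train_word_dict` concrete, here is a toy run of the same counting logic on two made-up mails that have already been reduced to set-of-words dictionaries (the words and labels are invented for illustration):

```python
import pandas as pd

# Two made-up mails: the first is spam (1), the second is ham (0).
toy = pd.DataFrame({
    'spam': [1, 0],
    'word_dict': [{'财务': 1, '发票': 1}, {'财务': 1, '会议': 1}],
})

counts = {}
for word_dict, spam in zip(toy.word_dict, toy.spam):
    for word in word_dict:
        counts.setdefault(word, {0: 0, 1: 0})
        counts[word][spam] += 1

print(counts)
# {'财务': {0: 1, 1: 1}, '发票': {0: 0, 1: 1}, '会议': {0: 1, 1: 0}}
```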
### 4. Test the algorithm

Maximum-likelihood estimates of the priors:

P(spam) = number of spam mails / total number of mails

P(ham) = number of ham mails / total number of mails

To avoid losing precision when the numbers become very small, all of the following computations are done in log space.

Let W denote a word. We want P(S|W), the probability that a mail is spam (S) given that the word W occurs in it:

P(S|W) = P(W|S)P(S) / (P(W|S)P(S) + P(W|H)P(H))

P(H|W) has exactly the same denominator, so it is enough to compare the numerators; the larger one corresponds to the more likely class.

A single word is weak evidence, so following the same reasoning we combine all the words in a mail: for every word that occurs in it we add up the log posteriors (the naive independence assumption turns the joint probability into a product, i.e. a sum of logs), and the class with the larger total is the predicted one.

```python
def predict_dataset(_train_word_dict, _spam_count, _ham_count, data):
    """
    Test the algorithm on one mail.
    :param _train_word_dict: word-count dictionary
    :param _spam_count: number of spam mails in the training set
    :param _ham_count: number of ham mails in the training set
    :param data: one row (one mail) of the test set
    :return: 1 if the prediction matches the true label, 0 otherwise
    """
    total_count = _ham_count + _spam_count
    word_dict = data['word_dict']

    # Priors, already in log space
    ham_probability = math.log(float(_ham_count) / total_count)
    spam_probability = math.log(float(_spam_count) / total_count)

    for word in word_dict:
        word = word.strip()
        _train_word_dict.setdefault(word, {0: 0, 1: 0})

        # Joint probability: += log(...)
        # Laplace smoothing: +1 in the numerator, +2 in the denominator
        word_occurs_counts_ham = _train_word_dict[word][0]
        # number of ham mails containing this word / number of ham mails
        ham_probability += math.log((float(word_occurs_counts_ham) + 1) / (_ham_count + 2))

        word_occurs_counts_spam = _train_word_dict[word][1]
        # number of spam mails containing this word / number of spam mails
        spam_probability += math.log((float(word_occurs_counts_spam) + 1) / (_spam_count + 2))

    if spam_probability > ham_probability:
        is_spam = 1
    else:
        is_spam = 0

    # Return whether the prediction was correct
    if is_spam == data['spam']:
        return 1
    else:
        return 0
```

Laplace smoothing: a zero probability would wipe out the whole posterior, and a word that never appeared in the training set is not actually impossible, so every count gets a default offset (+1 in the numerator, +2 in the denominator). This is Laplace smoothing.
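As a quick sanity check of the smoothed log scores, here is a small numeric sketch for a single hypothetical word that occurs in 3 of 10 spam training mails and in 0 of 10 ham training mails (all counts invented for illustration):

```python
import math

spam_count, ham_count = 10, 10          # hypothetical training-set sizes
total_count = spam_count + ham_count

# Log priors
spam_score = math.log(spam_count / total_count)   # log(0.5)
ham_score = math.log(ham_count / total_count)     # log(0.5)

# One word: seen in 3 spam mails and 0 ham mails; Laplace smoothing (+1 / +2)
spam_score += math.log((3 + 1) / (spam_count + 2))   # log(4/12)
ham_score += math.log((0 + 1) / (ham_count + 2))     # log(1/12)

print(spam_score > ham_score)   # True -> this mail would be classified as spam
```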
### 5. Use the algorithm

Compare the predictions produced by the classifier against the true labels to obtain the accuracy.

```python
if __name__ == '__main__':
    ...
    test_mails_predict = test_set.apply(
        lambda x: predict_dataset(train_word_dict, spam_count, ham_count, x), axis=1)

    corr_count = 0
    false_count = 0
    for i in test_mails_predict.values.tolist():
        if i == 1:
            corr_count += 1
        if i == 0:
            false_count += 1

    print("Total test samples", (corr_count + false_count))
    print("Correct predictions", corr_count)
    print("Incorrect predictions", false_count)

    result = float(corr_count / (corr_count + false_count))
    print('Accuracy:', result)
```
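Since `predict_dataset` returns 1 for a correct prediction and 0 otherwise, the accuracy can equivalently be read off as the mean of the resulting Series; a small usage note, assuming the variables from the snippet above:

```python
# Equivalent to the counting loop above: the mean of a 0/1 Series is the accuracy.
print('Accuracy:', test_mails_predict.mean())
```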
--------------------------------------------------------------------------------
/new.py:
--------------------------------------------------------------------------------
import math
import re
import pandas as pd
import codecs
import jieba


def load_formatted_data():
    """
    Load the formatted label-path list.
    The spam column is 1 for junk mail, 0 for normal mail.
    The path column is the path of the mail file.
    :return: (DataFrame) index
    """
    # Load the dataset index
    index = pd.read_csv('index', sep=' ', names=['spam', 'path'])
    index.spam = index.spam.apply(lambda x: 1 if x == 'spam' else 0)
    # Drop the leading '.' so the path is relative to the repository root
    index.path = index.path.apply(lambda x: x[1:])
    return index


def load_stop_word():
    """
    Read the stop-word list.
    :return: (List) _stop_words
    """
    with codecs.open("stop", "r") as f:
        lines = f.readlines()
    _stop_words = [i.strip() for i in lines]
    return _stop_words


def get_mail_content(path):
    """
    Read one mail and return its text as a single string.
    :param path: path of the mail file
    :return: (Str) content
    """
    with codecs.open(path, "r", encoding="gbk", errors="ignore") as f:
        lines = f.readlines()

    for i in range(len(lines)):
        if lines[i] == '\n':
            # Discard everything before the first blank line, i.e. the mail headers
            lines = lines[i:]
            break
    content = ''.join(''.join(lines).strip().split())
    return content


def create_word_dict(content, stop_words_list):
    """
    Record which words occur in the mail text, dropping any word found in the stop-word list.
    :param content: the mail text as a single string
    :param stop_words_list: stop-word list
    :return: (Dict) word_dict
    """
    word_list = []
    word_dict = {}
    # word_dict key: word, value: 1
    # Keep only Chinese characters
    content = re.findall(u"[\u4e00-\u9fa5]", content)
    content = ''.join(content)
    word_list_temp = jieba.cut(content)
    for word in word_list_temp:
        if word != '' and word not in stop_words_list:
            word_list.append(word)
    for word in word_list:
        word_dict[word] = 1
    return word_dict


def train_dataset(dataset_to_train):
    """
    Train on the dataset: count, for each word, how many ham and spam mails it occurs in.
    :param dataset_to_train: the dataset to train on
    :return: Tuple(word-count dictionary _train_word_dict, number of spam mails spam_count, number of ham mails ham_count)
    """
    _train_word_dict = {}
    # _train_word_dict: for every word, its occurrence count in ham (key 0) and spam (key 1) mails
    for word_dict, spam in zip(dataset_to_train.word_dict, dataset_to_train.spam):
        # word_dict: the word set of one mail; spam: the label of that mail
        for word in word_dict:
            # For every word of every mail, increment the count for that mail's class
            _train_word_dict.setdefault(word, {0: 0, 1: 0})
            _train_word_dict[word][spam] += 1
    ham_count = dataset_to_train.spam.value_counts()[0]
    spam_count = dataset_to_train.spam.value_counts()[1]
    return _train_word_dict, spam_count, ham_count


def predict_dataset(_train_word_dict, _spam_count, _ham_count, data):
    """
    Test the algorithm on one mail.
    :param _train_word_dict: word-count dictionary
    :param _spam_count: number of spam mails in the training set
    :param _ham_count: number of ham mails in the training set
    :param data: one row (one mail) of the test set
    :return: 1 if the prediction matches the true label, 0 otherwise
    """
    total_count = _ham_count + _spam_count
    word_dict = data['word_dict']

    # Priors, already in log space
    ham_probability = math.log(float(_ham_count) / total_count)
    spam_probability = math.log(float(_spam_count) / total_count)

    for word in word_dict:
        word = word.strip()
        _train_word_dict.setdefault(word, {0: 0, 1: 0})

        # Joint probability: += log(...)
        # Laplace smoothing: +1 in the numerator, +2 in the denominator
        word_occurs_counts_ham = _train_word_dict[word][0]
        # number of ham mails containing this word / number of ham mails
        ham_probability += math.log((float(word_occurs_counts_ham) + 1) / (_ham_count + 2))

        word_occurs_counts_spam = _train_word_dict[word][1]
        # number of spam mails containing this word / number of spam mails
        spam_probability += math.log((float(word_occurs_counts_spam) + 1) / (_spam_count + 2))

    if spam_probability > ham_probability:
        is_spam = 1
    else:
        is_spam = 0

    # Return whether the prediction was correct
    if is_spam == data['spam']:
        return 1
    else:
        return 0


def save_train_word_dict(train_word_dict):
    # Save the trained word counts for inspection; serialize the dict to text before writing
    with codecs.open("train_word_dict", "w", encoding="gbk", errors="ignore") as f:
        f.write(str(train_word_dict))


if __name__ == '__main__':

    index_list = load_formatted_data()
    stop_words = load_stop_word()
    index_list['content'] = index_list.path.apply(lambda x: get_mail_content(x))
    index_list['word_dict'] = index_list.content.apply(lambda x: create_word_dict(x, stop_words))

    # First 80% of the data for training, remaining 20% for testing
    split = int(len(index_list) * 0.8)
    train_set = index_list.iloc[:split]
    test_set = index_list.iloc[split:]

    train_word_dict, spam_count, ham_count = train_dataset(train_set)

    save_train_word_dict(train_word_dict)

    test_mails_predict = test_set.apply(
        lambda x: predict_dataset(train_word_dict, spam_count, ham_count, x), axis=1)

    corr_count = 0
    false_count = 0
    for i in test_mails_predict.values.tolist():
        if i == 1:
            corr_count += 1
        if i == 0:
            false_count += 1

    print("Total test samples", (corr_count + false_count))
    print("Correct predictions", corr_count)
    print("Incorrect predictions", false_count)

    result = float(corr_count / (corr_count + false_count))
    print('Accuracy:', result)
--------------------------------------------------------------------------------
/report.doc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nado-dev/Naive-Bayes-classifier/88f9138baed07bdccb341a5e3cea9544210cd839/report.doc
--------------------------------------------------------------------------------
/stop:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nado-dev/Naive-Bayes-classifier/88f9138baed07bdccb341a5e3cea9544210cd839/stop
--------------------------------------------------------------------------------