├── .gitignore ├── baseline.py ├── img ├── Figure_1.png └── Figure_2.png ├── readme.md ├── spider.py ├── stopwords.py ├── svm.py ├── text2term.py ├── vectorizer.py └── viewer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Python: 2 | *.py[cod] 3 | *.so 4 | *.egg 5 | *.egg-info 6 | *.txt 7 | *.pkl 8 | 9 | -------------------------------------------------------------------------------- /baseline.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import time 6 | 7 | import numpy as np 8 | 9 | from math import log 10 | 11 | from sklearn.externals import joblib 12 | 13 | 14 | def trainNaiveBayesClassifier(): 15 | """ 16 | 训练朴素贝叶斯分类器 17 | """ 18 | # 训练集特征矩阵保存路径 19 | train_matrix_path = 'matrix/train/matrix.pkl' 20 | # 后验概率保存路径 21 | prob_clf_path = 'classifier/baseline.pkl' 22 | 23 | # 加载数据 24 | matrix = joblib.load(train_matrix_path) 25 | 26 | # 计算每一类中每个特征的后验概率 27 | prob_clf = [] 28 | for clf in range(0, 10): 29 | start_row = clf * 50000 30 | end_row = (clf+1) * 50000 31 | 32 | vector = matrix[start_row:end_row].sum(axis=0) 33 | total_word = matrix[start_row:end_row].sum() 34 | 35 | prob = np.log((vector+1)/float(total_word)) 36 | prob_clf.append(prob) 37 | 38 | # 保存后验概率 39 | joblib.dump(prob_clf, prob_clf_path) 40 | 41 | 42 | def testNaiveBayesClassifier(): 43 | """ 44 | 测试朴素贝叶斯分类器 45 | """ 46 | # 测试集特征矩阵保存路径 47 | test_matrix_path = 'matrix/test/matrix.pkl' 48 | # 后验概率保存路径 49 | prob_clf_path = 'classifier/baseline.pkl' 50 | 51 | # 加载数据 52 | matrix = joblib.load(test_matrix_path) 53 | target = np.array([x for x in range(10) for i in range(50000)]) 54 | 55 | # 加载贝叶斯每一类的后验概率 56 | prob_clf = joblib.load(prob_clf_path) 57 | 58 | # 预测 59 | confusion_matrix = np.zeros(shape=(10,10),dtype=int) 60 | for i in range(0, len(target)): 61 | max_value, predicted = -float('inf'), 0 62 | a = np.array(matrix[i].sum(axis=0))[0] 63 | 64 | for clf in range(0, 10): 65 | b = np.array(prob_clf[clf])[0] 66 | value = np.dot(a, b) 67 | if value > max_value: 68 | max_value = value 69 | predicted = clf 70 | 71 | confusion_matrix[target[i]][predicted] += 1 72 | 73 | joblib.dump(confusion_matrix, 'results/Bayes_confusion_matrix.pkl') 74 | 75 | # 统计 76 | recall_list, precision_list, f_list = [], [], [] 77 | correct = 0 78 | r = confusion_matrix.sum(axis=1) 79 | p = confusion_matrix.sum(axis=0) 80 | for clf in range(0, 10): 81 | recall = confusion_matrix[clf][clf] / float(r[clf]) 82 | precision = confusion_matrix[clf][clf] / float(p[clf]) 83 | f = 2*recall*precision/(recall+precision) 84 | recall_list.append(recall) 85 | precision_list.append(precision) 86 | f_list.append(f) 87 | correct += confusion_matrix[clf][clf] 88 | correct /= float(matrix.shape[0]) 89 | 90 | # 打印测试报告 91 | print confusion_matrix,'\n' 92 | 93 | print '{0:>14}\t{1:<10}\t{2:<10}\t{3:<10}'.format('classification','Recall','Precision','F1-Score') 94 | for i, target_name in enumerate(os.listdir('data/test/raw/')): 95 | print '{0:>14}\t{1:<10.4f}\t{2:<10.4f}\t{3:<10.4f}'.format(target_name, recall_list[i], precision_list[i], f_list[i]) 96 | print '' 97 | avg_r, avg_p, avg_f = 0.0, 0.0, 0.0 98 | for a,b,c in zip(recall_list,precision_list,f_list): 99 | avg_r += a 100 | avg_p += b 101 | avg_f += c 102 | print '{0:>14}\t{1:<10.4f}\t{2:<10.4f}\t{3:<10.4f}'.format('avg / total', avg_r/10, avg_p/10, avg_f/10) 103 | 104 | print '\n','Correct Rate:',correct 105 | 106 | 107 | if __name__ == '__main__': 108 | # 训练 109 | time_start = time.time() 110 | trainNaiveBayesClassifier() 111 | print 'Training time:', time.time()-time_start, 's' 112 | 113 | # 测试 114 | time_start = time.time() 115 | testNaiveBayesClassifier() 116 | print 'Testing time:', time.time()-time_start, 's' 117 | 118 | -------------------------------------------------------------------------------- /img/Figure_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qyfang/TextClassification/4c4817495a017a98af471d807ee51d8b80ba39b4/img/Figure_1.png -------------------------------------------------------------------------------- /img/Figure_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qyfang/TextClassification/4c4817495a017a98af471d807ee51d8b80ba39b4/img/Figure_2.png -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # 新浪新闻文本分类 2 | 3 | ## 语料库构建 4 | 5 | 本项目的语料来源新浪新闻网,通过 _spider.py_ 爬虫模块获得全部语料,总计获得10类新闻文本,每一类新闻文本有10w篇。 6 | 7 | * 借助新浪新闻网的一个api获取新闻文本,api的url为[http://api.roll.news.sina.com.cn/zt_list?](http://api.roll.news.sina.com.cn/zt_list?) 8 | 9 | * 使用进程池并发执行爬虫,加快抓取速度。 10 | 11 | ## 数据预处理 12 | 13 | 本项目的数据预处理包括:分词处理,去噪,向量化,由 _stopwords.py_ 模块、_text2term.py_ 模块、_vectorizer.py_ 模块实现。 14 | 15 | * 本项目借助第三方库 _jieba_ 完成文本的分词处理。 16 | 17 | * 通过停用词表去除中文停用词,通过正则表达式去除数字(中文数字&阿拉伯数字)。 18 | 19 | ```python 20 | filter_pattern = re.compile(ur'[-+]?[\w\d]+|零|一|二|三|四|五|六|七|八|九|十|百|千|万|亿') 21 | ``` 22 | 23 | * 使用进程池并发执行数据的分词和去噪,加快数据预处理的过程。 24 | 25 | * 把数据集1:1划分为训练集和测试集,各50w篇文档。 26 | 27 | * 借助scikit-learn提供的`CountVectorizer`类完成向量化,得到训练集和测试集两个文本的特征矩阵,矩阵类型为稀疏矩阵。 28 | 29 | * 去除文档中文档频率小于0.1%的特征,这些特征我们认为出现的频率实在太低同时也不可能为某类文档的局部特征,以此完成降维,最终特征矩阵的维度大约为19543维。 30 | 31 | ## 朴素贝叶斯分类 32 | 33 | 本项目使用朴素贝叶斯作为本项目文本分类的baseline,由 _baseline.py_ 模块实现。 34 | 35 | * 平滑处理 36 | 37 | * 处理零概率 38 | 39 | * 最终分类结果: 40 | 最高召回率:0.95 | 最低召回率:0.46 | 平均召回率:0.79 41 | 最高精确度:0.96 | 最低精确度:0.55 | 平均精确度:0.78 42 | 最高F1测度:0.93 | 最低F1测度:0.50 | 平均F1测度:0.79 43 | 44 | ## SVM分类 45 | 46 | 本项目使用SVM作为最终的文本分类器,由 _svm.py_ 模块实现其中SVM的核函数选用线性核,特征矩阵投入训练前经过词频加权. 47 | 48 | * 借助`TfidfTransformer`使用TF-IDF对词频进行加权 49 | 50 | * 选用线性核`LinearSVC` 51 | 52 | * 结合5折交叉验证和网格搜索`GridSearchCV`完成调参 53 | 54 | * 最终分类结果: 55 | 最高召回率:0.99 | 最低召回率:0.77 | 平均召回率:0.90 56 | 最高精确度:0.98 | 最低精确度:0.77 | 平均精确度:0.90 57 | 最高F1测度:0.99 | 最低F1测度:0.77 | 平均F1测度:0.90 58 | 59 | ## 可视化 60 | 61 | 比较SVM分类器和贝叶斯分类器的分类性能,通过可视化的方式比较两者的预测结果,由 _viewer.py_ 模块实现。 62 | 63 | ### 混淆矩阵热力图 64 | 65 | ![混淆矩阵热力图](img/Figure_1.png) 66 | 67 | ### 性能对比直方图 68 | 69 | ![性能对比直方图](img/Figure_2.png) 70 | 71 | -------------------------------------------------------------------------------- /spider.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | 3 | import re 4 | 5 | import sys 6 | 7 | import json 8 | 9 | import requests 10 | 11 | import multiprocessing 12 | 13 | from bs4 import BeautifulSoup 14 | 15 | 16 | def getNewsUrl(*params): 17 | url_param = { 18 | 'channel': params[0], 19 | 'page': str(params[1]), 20 | 'show_all': '1', 21 | # 'at_1': 'gnxw', 22 | 'show_num': '5000', 23 | 'tag': '1', 24 | 'format': 'json', 25 | } 26 | 27 | # 构造api的url 28 | api_url = 'http://api.roll.news.sina.com.cn/zt_list?' 29 | for key, value in url_param.iteritems(): 30 | api_url += key + '=' + value + '&' 31 | api_url = api_url[:-1] 32 | 33 | # 请求api数据 34 | response = requests.get(api_url) 35 | data = json.loads(response.content) 36 | data = data['result']['data'] 37 | 38 | # 提取新闻的url 39 | for term in data: 40 | url = term['url'] 41 | yield url 42 | 43 | 44 | def loadNews(*params): 45 | reload(sys) 46 | sys.setdefaultencoding('utf8') 47 | 48 | url = params[0] 49 | clsf = params[1] 50 | 51 | response = requests.get(url) 52 | soup = BeautifulSoup(response.content, 'html.parser') 53 | 54 | # 提取新闻id 55 | news_id = '-'.join(url.split('/')[-2:]) 56 | news_id = re.sub(r'\..*', '', news_id) 57 | 58 | # 提取新闻内容 59 | news_content = [p.text for p in soup.select('p') if not p.findChildren()] 60 | news_content = '\n'.join(news_content) 61 | 62 | # 存储新闻文本 63 | if len(news_content) >= 30: 64 | path = 'data/' + clsf + '/' + news_id + '.txt' 65 | with open(path, 'w') as f: 66 | f.write(news_content) 67 | 68 | 69 | def runSpider(clsf): 70 | pool = multiprocessing.Pool(6) 71 | for page in range(1,100): 72 | urls = getNewsUrl(*(clsf, page)) 73 | print 'Page',page 74 | for url in urls: 75 | # loadNews(*(url, clsf)) 76 | pool.apply_async(loadNews, (url, clsf)) 77 | pool.close() 78 | pool.join() 79 | 80 | 81 | if __name__ == '__main__': 82 | # sports tech finance edu ent games fashion mil (news) 83 | runSpider('games') 84 | -------------------------------------------------------------------------------- /stopwords.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | reload(sys) 5 | sys.setdefaultencoding('utf8') 6 | 7 | def deduplicateStopWords(filepath): 8 | deduplicated = [] 9 | i = 0 10 | with open(filepath, 'rb') as f: 11 | for line in f.readlines(): 12 | i += 1 13 | word = line.strip() 14 | if word not in deduplicated: 15 | deduplicated.append(word) 16 | else: 17 | print i 18 | 19 | return deduplicated 20 | 21 | def writeNewStopWordsList(filepath, stopwords): 22 | with open(filepath, 'w') as f: 23 | for word in stopwords: 24 | f.write(word + '\n') 25 | 26 | if __name__ == '__main__': 27 | stopwords = deduplicateStopWords('stopwords/stopwords.txt') 28 | print 'Completion Detection' 29 | # writeNewStopWordsList('stopwords.txt', stopwords) 30 | # print 'Completion Deduplicate' 31 | -------------------------------------------------------------------------------- /svm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import time 6 | 7 | import numpy as np 8 | 9 | from sklearn.feature_extraction.text import TfidfTransformer 10 | 11 | from sklearn.decomposition import LatentDirichletAllocation 12 | 13 | from sklearn.svm import LinearSVC, SVC 14 | 15 | from sklearn.pipeline import make_pipeline 16 | from sklearn.model_selection import GridSearchCV 17 | from sklearn.externals import joblib 18 | 19 | from sklearn import metrics 20 | 21 | 22 | def trainClassifier(): 23 | """ 24 | 训练分类器 25 | """ 26 | # 训练集特征矩阵保存路径 27 | train_matrix_path = 'matrix/train/matrix.pkl' 28 | # 分类器保存路径 29 | classifier_path = 'classifier/classifier.pkl' 30 | 31 | matrix = joblib.load(train_matrix_path) 32 | target = np.array([x for x in range(10) for i in range(50000)]) 33 | 34 | # 构造分类器 35 | estimators = ( 36 | TfidfTransformer(), 37 | LinearSVC() 38 | ) 39 | classifier_params = { 40 | 'tfidftransformer__sublinear_tf': True, 41 | } 42 | classifier = make_pipeline(*estimators) 43 | classifier.set_params(**classifier_params) 44 | 45 | # classifier.fit(matrix, target) 46 | # best_model = classifier 47 | 48 | parameters = { 49 | 'linearsvc__C': np.arange(0.7, 1.3, 0.1), 50 | 'linearsvc__class_weight': [{0:a, 4:b, 6:c, 7:d} for a in [0.8,1.2,1.6] for b in [0.8,1.2,1.6] for c in [0.8,1.2,1.6] for d in [0.8,1.2,1.6]], 51 | } 52 | grid = GridSearchCV(classifier, parameters, cv=5, n_jobs=3) 53 | grid.fit(matrix, target) 54 | 55 | print 'params',grid.best_params_ 56 | print 'score',grid.best_score_ 57 | 58 | # 保存分类器 59 | best_model = grid.best_estimator_ 60 | joblib.dump(best_model, classifier_path) 61 | 62 | 63 | def testClassifier(): 64 | """ 65 | 测试测试集 66 | """ 67 | # 测试集特征矩阵保存路径 68 | train_matrix_path = 'matrix/test/matrix.pkl' 69 | # 分类器保存路径 70 | classifier_path = 'classifier/classifier.pkl' 71 | 72 | matrix = joblib.load(train_matrix_path) 73 | target = np.array([x for x in range(10) for i in range(50000)]) 74 | 75 | # 读取分类器 76 | classifier = joblib.load(classifier_path) 77 | predicted = classifier.predict(matrix) 78 | 79 | term_file_folder_path = 'data/test/raw/' 80 | confusion_matrix = metrics.confusion_matrix(target, predicted) 81 | print confusion_matrix 82 | print metrics.classification_report(target, predicted, target_names=os.listdir(term_file_folder_path)) 83 | print metrics.accuracy_score(target, predicted) 84 | 85 | joblib.dump(confusion_matrix, 'results/SVM_confusion_matrix.pkl') 86 | 87 | 88 | if __name__ == '__main__': 89 | # 训练 90 | # time_start = time.time() 91 | # trainClassifier() 92 | # print 'Training time:', time.time()-time_start, 's' 93 | 94 | # 测试 95 | time_start = time.time() 96 | testClassifier() 97 | print 'Testing time:', time.time()-time_start, 's' 98 | -------------------------------------------------------------------------------- /text2term.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import sys 5 | import logging 6 | 7 | import pickle 8 | 9 | import multiprocessing 10 | 11 | import re 12 | import jieba 13 | 14 | 15 | def convertText2Term(*args): 16 | """ 17 | 将Text文本文件转为Term分词文件 18 | 此函数作为一个进程 19 | """ 20 | reload(sys) 21 | sys.setdefaultencoding('utf8') 22 | 23 | jieba.setLogLevel(logging.INFO) 24 | 25 | text_path = args[0] 26 | term_path = args[1] 27 | stopwords_list = args[2] 28 | 29 | # 读取Text文件 30 | with open(text_path, 'r') as f: 31 | text = f.read() 32 | 33 | # 分词 34 | term_list = [x for x in jieba.cut(text)] 35 | 36 | # 过滤分词 37 | filter_pattern = re.compile(ur'[-+]?[\w\d]+|零|一|二|三|四|五|六|七|八|九|十|百|千|万|亿') 38 | filtered_term_list = [] 39 | for term in term_list: 40 | # 被过滤的分词:长度小于2, 包含数字或字母或中文数词, 停用词 41 | if len(term)<2 or filter_pattern.search(term) or term in stopwords_list: 42 | pass 43 | else: 44 | filtered_term_list.append(term) 45 | 46 | # 存储分词 47 | if len(filtered_term_list) >= 10: 48 | with open(term_path, 'w') as f: 49 | pickle.dump(filtered_term_list, f) 50 | 51 | # print '|'.join(filtered_term_list),'\n' 52 | 53 | 54 | def processText(text_file_folder_path, term_file_folder_path): 55 | """ 56 | 处理指定路径下的所有Text文件 57 | """ 58 | # 创建进程池,参数为池中进程数 59 | pool = multiprocessing.Pool(6) 60 | 61 | # 获取停用词表 62 | with open('stopwords/stopwords.txt', 'rb') as f: 63 | stopwords_list = [line.strip() for line in f.readlines()] 64 | 65 | classification = os.listdir(text_file_folder_path)[-2:] 66 | for clsf in classification: 67 | print clsf 68 | for text_filename in os.listdir(text_file_folder_path+clsf): 69 | text_path = text_file_folder_path + clsf + '/' + text_filename 70 | term_path = term_file_folder_path + clsf + '/' + text_filename.split('.')[0] + '.pkl' 71 | 72 | args = (text_path, term_path, stopwords_list) 73 | 74 | # 调用文本转分词的进程 75 | pool.apply_async(convertText2Term, args) 76 | 77 | pool.close() 78 | pool.join() 79 | 80 | 81 | if __name__ == '__main__': 82 | # 训练集文本数据的文件夹路径 83 | text_file_folder_path = 'data/train/raw/' 84 | # 训练集分词数据的文件夹路径 85 | term_file_folder_path = 'data/train/term/' 86 | # 处理训练集文本 87 | processText(text_file_folder_path, term_file_folder_path) 88 | 89 | 90 | # 测试集文本数据的文件夹路径 91 | text_file_folder_path = 'data/test/raw/' 92 | # 测试集分词数据的文件夹路径 93 | term_file_folder_path = 'data/test/term/' 94 | # 处理测试集文本 95 | processText(text_file_folder_path, term_file_folder_path) 96 | -------------------------------------------------------------------------------- /vectorizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | import os 5 | 6 | import time 7 | 8 | import pickle 9 | 10 | import numpy as np 11 | 12 | from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer 13 | 14 | from sklearn.externals import joblib 15 | 16 | 17 | 18 | def readTerm(term_file_folder_path): 19 | """ 20 | 读取Term文件,返回Term字符串生成器和类别生产器 21 | """ 22 | def getTerm(): 23 | classification = os.listdir(term_file_folder_path) 24 | for num, clsf in enumerate(classification): 25 | print num,'/',len(classification) 26 | for term_filename in os.listdir(term_file_folder_path+clsf)[:50000]: 27 | path = term_file_folder_path + clsf + '/' + term_filename 28 | with open(path, 'r') as f: 29 | term_list = pickle.load(f) 30 | term = ' '.join(term_list) 31 | yield term 32 | 33 | def getTarget(): 34 | classification = os.listdir(term_file_folder_path) 35 | for num, clsf in enumerate(classification): 36 | for term_filename in os.listdir(term_file_folder_path+clsf)[:50000]: 37 | yield num 38 | 39 | # Term字符串生成器 40 | term_generator = getTerm() 41 | # Term的类别生成器 42 | target_generator = getTarget() 43 | 44 | return term_generator, target_generator 45 | 46 | 47 | def generateMatrix(): 48 | """ 49 | 生成特征矩阵 50 | """ 51 | vectorizer = CountVectorizer(min_df=0.001) 52 | 53 | for x in ['train', 'test']: 54 | # 分词数据的文件夹路径 55 | term_file_folder_path = 'data/%s/term/' % x 56 | # 特征矩阵保存路径 57 | matrix_path = 'matrix/%s/matrix.pkl' % x 58 | 59 | # 读取数据 60 | term_generator, target_generator = readTerm(term_file_folder_path) 61 | 62 | # 训练集拟合后转换为矩阵,测试集根据拟合好的矢量器直接转换为矩阵 63 | if x == 'train': 64 | matrix = vectorizer.fit_transform(term_generator) 65 | joblib.dump(vectorizer.vocabulary_, 'matrix/vocabulary.pkl') 66 | 67 | elif x == 'test': 68 | matrix = vectorizer.transform(term_generator) 69 | 70 | # 保存特征矩阵 71 | joblib.dump(matrix, matrix_path) 72 | 73 | 74 | 75 | if __name__ == '__main__': 76 | reload(sys) 77 | sys.setdefaultencoding('utf8') 78 | 79 | time_start = time.time() 80 | generateMatrix() 81 | print 'Transform time:', time.time()-time_start, 's' 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /viewer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import sys 6 | 7 | import pickle 8 | 9 | import numpy as np 10 | 11 | import pandas as pd 12 | 13 | import seaborn as sns 14 | 15 | from sklearn.externals import joblib 16 | 17 | import matplotlib.pyplot as plt 18 | 19 | 20 | def _calculate(*matrices): 21 | labels = ['culture', 'education', 'entertainment', 'estate', 'finance', 'game', 'gov', 'society', 'sport', 'technology'] #os.listdir('data/test/raw/') 22 | df = pd.DataFrame(columns=('Classifier', 'classification', 'recall', 'precision', 'f1score')) 23 | index = 0 24 | for t, confusion_matrix in enumerate(matrices): 25 | typ = 'Bayes' if t == 0 else 'SVM' 26 | correct = 0 27 | 28 | r = confusion_matrix.sum(axis=1) 29 | p = confusion_matrix.sum(axis=0) 30 | 31 | for clf in range(0, 10): 32 | recall = confusion_matrix[clf][clf] / float(r[clf]) 33 | precision = confusion_matrix[clf][clf] / float(p[clf]) 34 | f1score = 2*recall*precision/(recall+precision) 35 | df.loc[index] = [typ, labels[clf], recall, precision, f1score] 36 | index += 1 37 | 38 | return df 39 | 40 | 41 | def viewMatrix(): 42 | """ 43 | 查看数据矩阵 44 | """ 45 | train_matrix_path = 'matrix/train/matrix.pkl' 46 | test_matrix_path = 'matrix/test/matrix.pkl' 47 | 48 | train_matrix = joblib.load(train_matrix_path) 49 | test_matrix = joblib.load(test_matrix_path) 50 | 51 | print 'Train Matrix',train_matrix.shape 52 | print 'Test Matrix',test_matrix.shape 53 | print type(test_matrix) 54 | 55 | 56 | def viewVocabulary(): 57 | """ 58 | 查看数据词典 59 | """ 60 | reload(sys) 61 | sys.setdefaultencoding('utf8') 62 | 63 | vocabulary = joblib.load('matrix/vocabulary.pkl') 64 | for voc in vocabulary.keys(): 65 | print voc, 66 | 67 | 68 | def viewTestResult(): 69 | """ 70 | 查看测试结果 71 | """ 72 | bayes = joblib.load('results/Bayes_confusion_matrix.pkl') 73 | svm = joblib.load('results/SVM_confusion_matrix.pkl') 74 | labels = ['culture', 'education', 'entertainment', 'estate', 'finance', 'game', 'gov', 'society', 'sport' 'technology'] #os.listdir('data/test/raw/') 75 | 76 | # 绘制混淆矩阵的热力图 77 | plt.figure(figsize=(12,5)) 78 | 79 | plt.subplot(1,2,1) 80 | ax = sns.heatmap(bayes, cmap='YlGnBu', xticklabels=labels, yticklabels=labels) 81 | ax.set_title('Bayes') 82 | 83 | plt.subplot(1,2,2) 84 | ax = sns.heatmap(svm, cmap='YlGnBu', xticklabels=labels, yticklabels=labels) 85 | ax.set_title('SVM') 86 | 87 | plt.subplots_adjust(wspace=0.4, bottom=0.25, top=0.9, right=0.95) 88 | 89 | # 绘制每一类的召回率、精确度、F1测度的直方图 90 | df = _calculate(bayes, svm) 91 | plt.figure(figsize=(12,8)) 92 | 93 | plt.subplot(3,1,1) 94 | ax = sns.barplot(x='classification',y='recall',hue='Classifier',data=df) 95 | ax.set_title('Recall') 96 | ax.set_xlabel('') 97 | 98 | plt.subplot(3,1,2) 99 | ax = sns.barplot(x='classification',y='precision',hue='Classifier',data=df) 100 | ax.set_title('Precision') 101 | ax.set_xlabel('') 102 | 103 | 104 | plt.subplot(3,1,3) 105 | ax = sns.barplot(x='classification',y='f1score',hue='Classifier',data=df) 106 | ax.set_title('F1score') 107 | ax.set_xlabel('') 108 | 109 | plt.subplots_adjust(hspace=0.4, bottom=0.07, top=0.96, right=0.93) 110 | 111 | plt.show() 112 | 113 | 114 | 115 | if __name__ == '__main__': 116 | # viewMatrix() 117 | 118 | # viewVocabulary() 119 | 120 | viewTestResult() 121 | 122 | 123 | --------------------------------------------------------------------------------