├── README.md
├── word2vec_train
│   ├── separate_words.py
│   ├── remove_words.py
│   ├── process_wiki.py
│   ├── train_word2vec_100_5_5_model.py
│   └── train_word2vec_skip_ngram_200_5_5_model_wiki.py
├── tools.ipynb
├── clean_data.py
├── text-classification-deep-learning.py
└── clean_data.ipynb

/README.md:
--------------------------------------------------------------------------------
1 | # text_classification_with_deep_learning
2 | 
3 | Experimental code for news text classification with deep learning (test accuracy reaches roughly 94%).
4 | 
5 | 1. The code is written in Python; the deep learning networks are built with the Keras framework.
6 | 
7 | 2. This is experimental code: the file paths are the ones set up in the experiment environment, so adjust the paths before using the code.
8 | 
9 | 3. The experiments use two sets of data:
10 | (1) The Chinese Wikipedia corpus (https://dumps.wikimedia.org/zhwiki/latest/zhwiki-latest-pages-articles.xml.bz2)
11 | (2) THUCNews, the Chinese text classification dataset from the Tsinghua NLP Lab (http://thuctc.thunlp.org/)
12 | The Wikipedia corpus is used to train the word2vec model; the THUCNews dataset is used for the text classification task.
13 | 
14 | 
--------------------------------------------------------------------------------
/word2vec_train/separate_words.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | import os
5 | import logging
6 | import sys
7 | import jieba
8 | 
9 | if __name__=='__main__':
10 | 
11 |     program = os.path.basename(sys.argv[0])
12 |     logger = logging.getLogger(program)
13 | 
14 |     logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
15 |     logging.root.setLevel(level=logging.INFO)
16 | 
17 |     if len(sys.argv) < 3:
18 |         print(globals()['__doc__'] %locals())
19 |         sys.exit(1)
20 | 
21 |     inp, outp = sys.argv[1:3]
22 |     space = ' '
23 | 
24 |     output = open(outp, 'w')
25 |     inputer = open(inp, 'r')
26 | 
27 |     for line in inputer.readlines():
28 |         seg_list = jieba.cut(line)
29 |         output.write(space.join(seg_list) + '\n')
30 | 
--------------------------------------------------------------------------------
/word2vec_train/remove_words.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | import os
5 | import logging
6 | import sys
7 | import re
8 | 
9 | if __name__=='__main__':
10 | 
11 |     program = os.path.basename(sys.argv[0])
12 |     logger = logging.getLogger(program)
13 | 
14 |     logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
15 |     logging.root.setLevel(level=logging.INFO)
16 | 
17 |     if len(sys.argv) < 3:
18 |         print(globals()['__doc__'] %locals())
19 |         sys.exit(1)
20 | 
21 |     inp, outp = sys.argv[1:3]
22 | 
23 |     output = open(outp, 'w')
24 |     inputer = open(inp, 'r')
25 | 
26 |     for line in inputer.readlines():
27 |         ss = re.findall('[\u4e00-\u9fa5a-zA-Z0-9]',line)
28 |         # ss = re.findall('[\n\s*\r\u4e00-\u9fa5]', line)
29 |         output.write("".join(ss))
--------------------------------------------------------------------------------
/word2vec_train/process_wiki.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | import os
5 | import logging
6 | import sys
7 | 
8 | from gensim.corpora import WikiCorpus
9 | 
10 | if __name__=='__main__':
11 | 
12 |     program = os.path.basename(sys.argv[0])
13 |     logger = logging.getLogger(program)
14 | 
15 |     logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
16 |     logging.root.setLevel(level=logging.INFO)
17 | 
18 | 
19 |     if len(sys.argv) < 3:
20 |         print(globals()['__doc__'] %locals())
21 |         sys.exit(1)
22 | 
23 |     inp, outp = sys.argv[1:3]
24 |     space =b' '
25 |     i = 0
26 | 
27 |     output = open(outp, 'wb')
28 |     wiki = WikiCorpus(inp, lemmatize=False, dictionary={})
29 |     for text in wiki.get_texts():
30
| output.write(space.join(text) +b'\n') 31 | i = i + 1 32 | if i % 10000 == 0: 33 | logger.info('Saved ' + str(i) + ' articles') 34 | 35 | output.close() 36 | logger.info('Finished ' + str(i) + ' articles') 37 | -------------------------------------------------------------------------------- /word2vec_train/train_word2vec_100_5_5_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import logging 5 | import os.path 6 | import sys 7 | import multiprocessing 8 | 9 | from time import time 10 | from gensim.corpora import WikiCorpus 11 | from gensim.models import Word2Vec 12 | from gensim.models.word2vec import LineSentence 13 | from gensim.models.word2vec import BrownCorpus 14 | 15 | if __name__== "__main__": 16 | 17 | program = os.path.basename(sys.argv[0]) 18 | logger = logging.getLogger(program) 19 | 20 | logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s' ) 21 | logging.root.setLevel(level=logging.INFO) 22 | logger.info('running %s' % ' '.join(sys.argv)) 23 | 24 | if len(sys.argv) < 3: 25 | print globals()['__doc__'] % locals() 26 | sys.exit(1) 27 | 28 | inp, outp1, outp2= sys.argv[1:4] 29 | 30 | begin = time() 31 | # model = Word2Vec(BrownCorpus(inp), size=100, window=5, min_count=5, 32 | # workers=multiprocessing.cpu_count()) 33 | model = Word2Vec(LineSentence(inp),sg=0, size=100, window=5, min_count=2, 34 | workers=multiprocessing.cpu_count()) 35 | model.save(outp1) 36 | model.wv.save_word2vec_format(outp2, binary=True) 37 | 38 | end = time() 39 | print("total processing time:%d seconds" %(end-begin)) 40 | -------------------------------------------------------------------------------- /word2vec_train/train_word2vec_skip_ngram_200_5_5_model_wiki.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import logging 5 | import os.path 6 | import sys 7 | import multiprocessing 8 | 9 | from time import time 10 | from gensim.corpora import WikiCorpus 11 | from gensim.models import Word2Vec 12 | from gensim.models.word2vec import LineSentence 13 | from gensim.models.word2vec import BrownCorpus 14 | 15 | if __name__== "__main__": 16 | 17 | program = os.path.basename(sys.argv[0]) 18 | logger = logging.getLogger(program) 19 | 20 | logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s' ) 21 | logging.root.setLevel(level=logging.INFO) 22 | logger.info('running %s' % ' '.join(sys.argv)) 23 | 24 | if len(sys.argv) < 3: 25 | print globals()['__doc__'] % locals() 26 | sys.exit(1) 27 | 28 | inp, outp1, outp2= sys.argv[1:4] 29 | 30 | begin = time() 31 | # model = Word2Vec(BrownCorpus(inp), size=100, window=5, min_count=5, 32 | # workers=multiprocessing.cpu_count()) 33 | model = Word2Vec(LineSentence(inp),sg=1, size=200, window=5, min_count=5, 34 | workers=multiprocessing.cpu_count()) 35 | model.save(outp1) 36 | model.wv.save_word2vec_format(outp2, binary=True) 37 | 38 | end = time() 39 | print("total processing time:%d seconds" %(end-begin)) 40 | -------------------------------------------------------------------------------- /tools.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# coding=utf-8\n", 12 | "import os\n", 13 | "import sys\n", 14 | "import shutil " 15 | ] 16 | }, 17 | { 
18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": { 21 | "collapsed": true 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "\"\"\" 删除指定目录下目录名称中的含有指定字符串的子目录\"\"\"\n", 26 | "def delete_dir(path,s):\n", 27 | " for root,dirs,files in os.walk(path):\n", 28 | " for d in dirs:\n", 29 | " if d==s :\n", 30 | " print \"delete directory \"+os.path.join(root,d)\n", 31 | " shutil.rmtree(os.path.join(root,d))" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/体育/word_separated\n", 44 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/娱乐/word_separated\n", 45 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/家居/word_separated\n", 46 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/彩票/word_separated\n", 47 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/房产/word_separated\n", 48 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/教育/word_separated\n", 49 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/时尚/word_separated\n", 50 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/时政/word_separated\n", 51 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/星座/word_separated\n", 52 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/游戏/word_separated\n", 53 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/社会/word_separated\n", 54 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/科技/word_separated\n", 55 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/股票/word_separated\n", 56 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/财经/word_separated\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "basedir = \"/Users/zhangwei/Documents/paper/data/THUCNews/\"\n", 62 | "s = \"word_separated\"\n", 63 | "delete_dir(basedir,s)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 8, 69 | "metadata": { 70 | "collapsed": true 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "\"\"\"查看csv.zip文件内容\"\"\"\n", 75 | "import pandas as pd\n", 76 | "def lookat_csv_zip(filename):\n", 77 | " df = pd.read_csv(filename, compression='zip')\n", 78 | " columns = df.columns\n", 79 | " one_line = [df[x][0] for x in columns]\n", 80 | " return columns,one_line\n", 81 | " " 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 10, 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "data": { 91 | "text/html": [ 92 | "
\n", 93 | "\n", 106 | "\n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | "
typecontent
0互联网思科 周二 盘后 发布 财年 第三季度 财报 财报 显示 受 股票 期权 开支 收购 相关 ...
1互联网eNet 硅谷动力 国外 媒体报道 一名 德国 数据库安全 工程师 日前 撰文 指出 目前为...
2互联网作者 令狐 达 eNet 硅谷动力 国外 媒体报道 美国 高科技 市场调研 公司 M Met...
3互联网作者 令狐 达 eNet 硅谷动力 台湾 媒体 引述 笔记本 厂商 消息人士 透露 Turi...
4互联网作者 令狐 达 eNet 硅谷动力 台湾 媒体报道 液晶 显示器 大厂 宏基 公司 美国 分...
\n", 142 | "
" 143 | ], 144 | "text/plain": [ 145 | " type content\n", 146 | "0 互联网 思科 周二 盘后 发布 财年 第三季度 财报 财报 显示 受 股票 期权 开支 收购 相关 ...\n", 147 | "1 互联网 eNet 硅谷动力 国外 媒体报道 一名 德国 数据库安全 工程师 日前 撰文 指出 目前为...\n", 148 | "2 互联网 作者 令狐 达 eNet 硅谷动力 国外 媒体报道 美国 高科技 市场调研 公司 M Met...\n", 149 | "3 互联网 作者 令狐 达 eNet 硅谷动力 台湾 媒体 引述 笔记本 厂商 消息人士 透露 Turi...\n", 150 | "4 互联网 作者 令狐 达 eNet 硅谷动力 台湾 媒体报道 液晶 显示器 大厂 宏基 公司 美国 分..." 151 | ] 152 | }, 153 | "execution_count": 10, 154 | "metadata": {}, 155 | "output_type": "execute_result" 156 | } 157 | ], 158 | "source": [ 159 | "filename = \"../../data/sogou_resource/dataset/train.csv.zip\"\n", 160 | "# cols,line = lookat_csv_zip(filename)\n", 161 | "# print cols\n", 162 | "# print line\n", 163 | "\n", 164 | "# 调用api\n", 165 | "df = pd.read_csv(filename, compression='zip')\n", 166 | "df.head()" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 4, 172 | "metadata": { 173 | "collapsed": true 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "\"\"\"停用词表除重\"\"\"\n", 178 | "basedir = \"/Users/zhangwei/Documents/paper/data/sogou_resource/dataset/\"\n", 179 | "original_file = os.path.join(basedir,\"stopwords.txt\")\n", 180 | "new_file = os.path.join(basedir,\"new_stopwords.txt\")\n", 181 | "word_list = []\n", 182 | "with open(original_file) as f:\n", 183 | " for line in f.readlines():\n", 184 | " line = line.strip()\n", 185 | " word_list.append(line)\n", 186 | "word_list = list(set(word_list))\n", 187 | "nf = open(new_file,'w')\n", 188 | "for item in word_list:\n", 189 | " nf.write(item+'\\n')\n", 190 | "nf.close()" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": { 197 | "collapsed": true 198 | }, 199 | "outputs": [], 200 | "source": [] 201 | } 202 | ], 203 | "metadata": { 204 | "kernelspec": { 205 | "display_name": "Python 2", 206 | "language": "python", 207 | "name": "python2" 208 | }, 209 | "language_info": { 210 | "codemirror_mode": { 211 | "name": "ipython", 212 | "version": 2 213 | }, 214 | "file_extension": ".py", 215 | "mimetype": "text/x-python", 216 | "name": "python", 217 | "nbconvert_exporter": "python", 218 | "pygments_lexer": "ipython2", 219 | "version": "2.7.10" 220 | } 221 | }, 222 | "nbformat": 4, 223 | "nbformat_minor": 2 224 | } 225 | -------------------------------------------------------------------------------- /clean_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | import re 4 | import jieba 5 | import time 6 | import datetime 7 | 8 | """使用jieba分词进行中文分词""" 9 | separated_word_file_dir = "word_separated" 10 | # 清华新闻语料库 11 | types = ["体育", "娱乐", "家居", "彩票", "房产", "教育", "时尚", "时政", "星座", "游戏", "社会", "科技", "股票", "财经"] 12 | 13 | 14 | def ch_and_en_word_extraction(content_raw): 15 | """抽取中文和英文""" 16 | pattern = re.compile(u"([\u4e00-\u9fa5a-zA-Z0-9]+)") 17 | re_data = pattern.findall(content_raw) 18 | clean_content = ' '.join(re_data) 19 | return clean_content 20 | 21 | 22 | def clean_str(s): 23 | # s = s.strip('\n') # 换行符 24 | # s = re.sub("[\t\n\r]*", '', s) # tab, newline, return 25 | s = re.sub('\|+',' ',s) 26 | s = re.sub('\s+',' ',s) 27 | s = s.strip() # 前后的空格 28 | # s = re.sub("<(\S*?)[^>]*>.*?|<.*? 
/>", '', s) # html标签 29 | # s = re.sub(" +|<+|>+", '', s) # html中的空格符号,大于,小于 30 | # s = re.sub("[a-zA-z]+://[^\s]*", '', s) # URL 31 | # s = re.sub(r'([\w-]+(\.[\w-]+)*@[\w-]+(\.[\w-]+)+)', '', s) # email 32 | # 标点符号,需要先转utf-8,否则符号匹配不成功 33 | # s = re.sub(ur"([%s])+" % zhon.hanzi.punctuation, " ", s.decode('utf-8')) 34 | # 抽取中文和英文 35 | # s = ch_and_en_word_extraction(s) 36 | return s 37 | 38 | 39 | def separate_words(infile, outfile): 40 | try: 41 | outf = open(outfile, 'w') 42 | inf = open(infile, 'r') 43 | 44 | space = ' ' 45 | # print 'separate '+infile 46 | isFirstLine = True 47 | for line in inf.readlines(): 48 | line = clean_str(line) 49 | # 除空行 50 | if not len(line): 51 | continue 52 | seg_list = jieba.cut(line) 53 | """此处需要循环每个单词编码为utf-8,jieba.cut将结果转为了unicode编码, 54 | 直接write(space.join(seg_list))会报编码错误""" 55 | for word in seg_list: 56 | if not len(word.strip()): 57 | continue 58 | try: 59 | word = word.strip().encode('UTF-8') 60 | except: 61 | continue 62 | outf.write(word) 63 | outf.write(space) 64 | if isFirstLine: 65 | outf.write("。") 66 | isFirstLine = False 67 | outf.write('\n') 68 | # close file stream 69 | outf.close() 70 | inf.close() 71 | except: 72 | print "error occured when write to " + outfile 73 | 74 | 75 | def is_target_dir(path): 76 | if os.path.dirname(path).split("/")[-1] in types and not re.match(".DS_Store", os.path.basename(path)): 77 | return True 78 | else: 79 | return False 80 | 81 | 82 | def explore(dir): 83 | for root, dirs, files in os.walk(dir): 84 | for file in files: 85 | path = os.path.join(root, file) 86 | if is_target_dir(path): 87 | child_dir = os.path.join(root, separated_word_file_dir) 88 | if not os.path.exists(child_dir): 89 | os.mkdir(child_dir) 90 | print "make dir: " + child_dir 91 | separate_words(path, os.path.join(child_dir, file)) 92 | 93 | 94 | def do_batch_separate(path): 95 | if os.path.isfile(path) and is_target_dir(path): 96 | separate_words(path, os.path.join(root, separated_word_file_dir, path)) 97 | if os.path.isdir(path): 98 | explore(path) 99 | 100 | 101 | original_dir = "THUCNews_deal_title/" 102 | now = datetime.datetime.now() 103 | print "separate word begin time:", now 104 | begin_time = time.time() 105 | do_batch_separate(original_dir) 106 | end_time = time.time() 107 | now = datetime.datetime.now() 108 | print "separate word,end time:", now 109 | print "separate word,time used:" + str(end_time - begin_time) + "秒" 110 | 111 | 112 | """将所有语料,整合成csv类型文件,文件格式:type|content""" 113 | split_mark = '|' 114 | 115 | def combine_file(file, outfile): 116 | # the type of file ,file示例:xxx/互联网/xxx/xxx.txt 117 | label = os.path.dirname(file).split('/')[-2] 118 | content = open(file).read() 119 | # print "content:"+content 120 | # print "len:",len(content) 121 | if len(content) > 1: # 排除前面步骤中写文件时,内容为只写入一个空格的情况 122 | new_content = label + split_mark + content 123 | # print "new_content:\n " + new_content 124 | open(outfile, "a").write(new_content) 125 | 126 | 127 | def do_combine(dir, outfile): 128 | print "deal with dir: " + dir 129 | for root, dirs, files in os.walk(dir): 130 | for file in files: 131 | match = re.match(r'\d+\.txt', file) 132 | if match: 133 | path = os.path.join(root, file) 134 | # print "combine " + path 135 | combine_file(path, outfile) 136 | 137 | 138 | def create_csv_file(dir, filename): 139 | csv_title = "type"+ split_mark + "content\n" 140 | filepath = os.path.join(dir, filename + '.csv') 141 | open(filepath, 'w').write(csv_title) 142 | return filepath 143 | 144 | 145 | base_dir = "THUCNews_deal_title/" 146 | 
"""创建处理后的数据集的目录""" 147 | dataset_dir = os.path.join(base_dir, "dataset") 148 | if not os.path.exists(dataset_dir): 149 | os.mkdir(dataset_dir) 150 | 151 | """创建每个type目录对应的csv文件,并将一个type目录下的文件写到同一个对应的csv文件""" 152 | # 清华新闻语料库 153 | type_name_list = ["体育", "娱乐", "家居", "彩票", "房产", "教育", "时尚", "时政", "星座", "游戏", "社会", "科技", "股票", "财经"] 154 | 155 | combine_begin_time = datetime.datetime.now() 156 | print "combine begin time:",combine_begin_time 157 | for name in type_name_list: 158 | path = create_csv_file(dataset_dir, name) 159 | print "going to combine file to " + path 160 | do_combine(os.path.join(base_dir, name, "word_separated"), path) 161 | 162 | combine_end_time = datetime.datetime.now() 163 | print "combine end time:",combine_end_time 164 | 165 | """随机采样每个类别的约20%作为测试集,80%作为训练集""" 166 | import random 167 | 168 | def extract_test_and_train_set(filepath, train_file, test_file): 169 | try: 170 | test_f = open(test_file, 'a') 171 | train_f = open(train_file, 'a') 172 | try: 173 | with open(filepath) as f: 174 | is_title_line = True 175 | for line in f.readlines(): 176 | if is_title_line: 177 | is_title_line = False 178 | continue 179 | if not len(line): 180 | continue 181 | if random.random() <= 0.2: 182 | test_f.write(line) 183 | else: 184 | train_f.write(line) 185 | except: 186 | print "IO ERROR" 187 | finally: 188 | test_f.close() 189 | train_f.close() 190 | except: 191 | print "can not open file" 192 | 193 | 194 | def do_extract(source_dir, train_f, test_f): 195 | for root, dirs, files in os.walk(source_dir): 196 | for file in files: 197 | if re.match("test|train\.csv", file) or not re.match(".*\.csv", file): 198 | continue 199 | path = os.path.join(root, file) 200 | print "extract file: " + path 201 | extract_test_and_train_set(path, train_f, test_f) 202 | 203 | 204 | # do extract 205 | dataset_dir = "THUCNews_deal_title/dataset/" 206 | train_dataset = os.path.join(dataset_dir, "train.csv") 207 | test_dataset = os.path.join(dataset_dir, "test.csv") 208 | if not os.path.exists(train_dataset): 209 | print "create file: " + train_dataset 210 | open(train_dataset, 'w').write("type"+ split_mark+"content\n") 211 | if not os.path.exists(test_dataset): 212 | print "create file:" + test_dataset 213 | open(test_dataset, 'w').write("type"+split_mark+"content\n") 214 | 215 | do_extract(dataset_dir, train_dataset, test_dataset) 216 | 217 | 218 | """清洗数据,除掉停用词,剔除坏样本""" 219 | 220 | # 221 | # def clean_stopwords(content_raw, stopwords_set): 222 | # content_list = [x for x in re.split(' +|\t+',content_raw) if x != ''] 223 | # common_set = set(content_list) & stopwords_set 224 | # new_content = filter(lambda x: x not in common_set, content_list) 225 | # return new_content 226 | # 227 | # 228 | # def do_clean_stopwords(content_file, stopwords_file, newfile): 229 | # print "clean stopwords in " + content_file 230 | # stopwords = [] 231 | # # 获取停用词 232 | # with open(stopwords_file) as fi: 233 | # for line in fi.readlines(): 234 | # stopwords.append(line.strip()) 235 | # newf = open(newfile, 'w') 236 | # with open(content_file) as f: 237 | # for line in f.readlines(): 238 | # type_content = line.split(split_mark) 239 | # content_raw = type_content[1] 240 | # new_cont = clean_stopwords(content_raw, set(stopwords)) 241 | # new_line = type_content[0] + split_mark + ' '.join(new_cont).strip() 242 | # newf.write(new_line) 243 | # newf.write('\n') 244 | # newf.close() 245 | # 246 | # test_file = "THUCNews_deal_title/dataset/test.csv" 247 | # train_file = "THUCNews_deal_title/dataset/train.csv" 248 | # new_test_file 
= "THUCNews_deal_title/dataset/cleaned_test.csv" 249 | # new_train_file = "THUCNews_deal_title/dataset/cleaned_train.csv" 250 | # stop_words_file = "THUCNews_deal_title/dataset/news.stopwords.txt" 251 | # do_clean_stopwords(test_file,stop_words_file,new_test_file) 252 | # do_clean_stopwords(train_file,stop_words_file,new_train_file) 253 | -------------------------------------------------------------------------------- /text-classification-deep-learning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # 常量 4 | MAX_SEQUENCE_LENGTH = 300 # 每条新闻最大长度 5 | EMBEDDING_DIM = 200 # 词向量空间维度 6 | VALIDATION_SPLIT = 0.20 # 验证集比例 7 | SENTENCE_NUM = 30 # 句子的数目 8 | model_filepath = "./CNN_word2vec_1_layer_model" 9 | 10 | from keras.preprocessing.text import Tokenizer 11 | import pandas as pd 12 | from keras.models import Sequential 13 | from keras.layers import Dense, Input, Flatten, Dropout 14 | from keras.layers import Conv1D, MaxPooling1D, Embedding,GRU 15 | from keras.optimizers import Adam 16 | from keras import regularizers 17 | import gensim 18 | from time import time 19 | import keras.callbacks 20 | from keras.layers import LSTM,Bidirectional 21 | from keras.preprocessing.sequence import pad_sequences 22 | import numpy as np 23 | 24 | # input data 25 | train_filename = "./THUCNews_deal_title/dataset/sentenced_deal_stop_word_train.csv.zip" 26 | test_filename = "./THUCNews_deal_title/dataset/sentenced_deal_stop_word_test.csv.zip" 27 | train_df = pd.read_csv(train_filename,sep='|',compression = 'zip',error_bad_lines=False) 28 | test_df = pd.read_csv(test_filename,sep='|',compression='zip',error_bad_lines=False) 29 | content_df = train_df.append(test_df, ignore_index=True) 30 | # shuffle data 31 | from sklearn.utils import shuffle 32 | content_df = shuffle(content_df) 33 | 34 | all_texts = content_df['content'] 35 | all_labels = content_df['type'] 36 | print "新闻文本数量:", len(all_texts), len(all_labels) 37 | print "每类新闻的数量:\n", all_labels.value_counts() 38 | 39 | all_texts = all_texts.tolist() 40 | all_labels = all_labels.tolist() 41 | 42 | original_labels = list(set(all_labels)) 43 | num_labels = len(original_labels) 44 | print "label counts:", num_labels 45 | # one_hot encode label 46 | one_hot = np.zeros((num_labels, num_labels), int) 47 | np.fill_diagonal(one_hot, 1) 48 | label_dict = dict(zip(original_labels, one_hot)) 49 | 50 | tokenizer = Tokenizer() 51 | tokenizer.fit_on_texts(all_texts) 52 | sequences = tokenizer.texts_to_sequences(all_texts) 53 | word_index = tokenizer.word_index 54 | print "Found %s unique tokens." 
% len(word_index)
55 | data = pad_sequences(sequences, maxlen=MAX_SEQUENCE_LENGTH)
56 | labels = np.asarray([label_dict[x] for x in all_labels])
57 | print "shape of data tensor:", data.shape
58 | print "shape of label tensor:", labels.shape
59 | 
60 | # split into training and validation sets
61 | p_train = int(len(data) * (1 - VALIDATION_SPLIT))
62 | x_train = data[:p_train]
63 | y_train = labels[:p_train]
64 | x_val = data[p_train:]
65 | y_val = labels[p_train:]
66 | print 'train docs:' + str(len(x_train))
67 | print 'validate docs:' + str(len(x_val))
68 | 
69 | # build the models
70 | def cnn_model(embedding_layer=None):
71 |     model = Sequential()
72 |     if embedding_layer:
73 |         # embedding layer use pre_trained word2vec model
74 |         model.add(embedding_layer)
75 |     else:
76 |         # random word vector
77 |         model.add(Embedding(len(word_index)+1,EMBEDDING_DIM,input_length=MAX_SEQUENCE_LENGTH))
78 |     model.add(Conv1D(128, 5, padding='valid', activation='relu'))
79 |     model.add(MaxPooling1D(5))
80 |     # model.add(Conv1D(128, 5, padding='valid', activation='relu'))
81 |     # model.add(MaxPooling1D(5))
82 |     # model.add(Conv1D(128, 5, padding='valid', activation='relu'))
83 |     # model.add(MaxPooling1D(5))
84 |     model.add(Flatten())
85 |     return model
86 | 
87 | def lstm_model(embedding_layer=None):
88 |     model = Sequential()
89 |     if embedding_layer:
90 |         # embedding layer use pre_trained word2vec model
91 |         model.add(embedding_layer)
92 |     else:
93 |         # random word vector
94 |         model.add(Embedding(len(word_index)+1,EMBEDDING_DIM,input_length=MAX_SEQUENCE_LENGTH))
95 |     # Native LSTM
96 |     model.add(LSTM(200,dropout=0.2,recurrent_dropout=0.2))
97 |     return model
98 | 
99 | def gru_model(embedding_layer=None):
100 |     model = Sequential()
101 |     if embedding_layer:
102 |         # embedding layer use pre_trained word2vec model
103 |         model.add(embedding_layer)
104 |     else:
105 |         # random word vector
106 |         model.add(Embedding(len(word_index)+1,EMBEDDING_DIM,input_length=MAX_SEQUENCE_LENGTH))
107 |     # GRU
108 |     model.add(GRU(200,dropout=0.2,recurrent_dropout=0.2))
109 |     return model
110 | 
111 | def bidirectional_lstm_model(embedding_layer=None):
112 |     model = Sequential()
113 |     if embedding_layer:
114 |         # embedding layer use pre_trained word2vec model
115 |         model.add(embedding_layer)
116 |     else:
117 |         # random word vector
118 |         model.add(Embedding(len(word_index)+1,EMBEDDING_DIM,input_length=MAX_SEQUENCE_LENGTH))
119 |     # Bidirection LSTM
120 |     model.add(Bidirectional(LSTM(200,dropout=0.2,recurrent_dropout=0.2)))
121 |     return model
122 | 
123 | def cnn_lstm_model(embedding_layer=None):
124 |     model = Sequential()
125 |     if embedding_layer:
126 |         # embedding layer use pre_trained word2vec model
127 |         model.add(embedding_layer)
128 |     else:
129 |         # random word vector
130 |         model.add(Embedding(len(word_index)+1,EMBEDDING_DIM,input_length=MAX_SEQUENCE_LENGTH))
131 |     model.add(Conv1D(128, 5, padding='valid', activation='relu'))
132 |     model.add(MaxPooling1D(5))
133 |     # model.add(Dropout(0.5))
134 |     # model.add(GRU(128,dropout=0.2,recurrent_dropout=0.1,return_sequences = True))
135 |     model.add(GRU(128,dropout=0.2,recurrent_dropout=0.1))
136 |     return model
137 | 
138 | 
139 | # load word2vec model
140 | word2vec_model_file = "/home/zwei/workspace/nlp_study/word2vec_wiki_study/word2vec_train_wiki/WIKI_word2vec_model/word2vec_skip_ngram_200_10_5_model.bin"
141 | word2vec_model = gensim.models.KeyedVectors.load_word2vec_format(word2vec_model_file, binary=True)
142 | # print "test word 奥巴马:",word2vec_model['奥巴马'.decode('utf-8')]
143 | 
144 | # construct embedding layer
145 | embedding_matrix =
np.zeros((len(word_index)+1, EMBEDDING_DIM)) 146 | 147 | no_word_file = open("no_word_file.txt",'w') 148 | 149 | for word, i in word_index.items(): 150 | if word.decode('utf-8') in word2vec_model: 151 | embedding_matrix[i] = np.asarray(word2vec_model[word.decode('utf-8')], dtype='float32') 152 | else: 153 | # print "word not found in word2vec:", word 154 | no_word_file.write(word+"\n") 155 | embedding_matrix[i] = np.random.random(size=EMBEDDING_DIM) 156 | 157 | no_word_file.close() 158 | 159 | embedding_layer = Embedding(len(word_index)+1, 160 | EMBEDDING_DIM, 161 | weights=[embedding_matrix], 162 | input_length=MAX_SEQUENCE_LENGTH, 163 | trainable=False) 164 | 165 | # CNN + word2vec 166 | # model = cnn_model(embedding_layer) 167 | 168 | # CNN + radom word vector 169 | # model = cnn_model() 170 | 171 | # LSTM 172 | # model = lstm_model(embedding_layer) 173 | 174 | # GRU 175 | model = gru_model(embedding_layer) 176 | 177 | # CNN + LSTM 178 | # model = cnn_lstm_model(embedding_layer) 179 | 180 | # common 181 | model.add(Dense(EMBEDDING_DIM, activation='relu')) 182 | model.add(Dropout(0.5)) 183 | model.add(Dense(labels.shape[1], activation='softmax',kernel_regularizer=regularizers.l2(0.1))) 184 | model.summary() 185 | 186 | # adam = Adam(lr=0.01) 187 | 188 | # tensorboard 189 | tensorboard = keras.callbacks.TensorBoard(log_dir="lstm_word2vec_log_THUCNews_deal_title/{}".format(time())) 190 | 191 | # function 192 | from keras import backend as K 193 | def recall(y_true, y_pred): 194 | true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) 195 | possible_positives = K.sum(K.round(K.clip(y_true, 0, 1))) 196 | recall = true_positives / (possible_positives + K.epsilon()) 197 | return recall 198 | def precision(y_true, y_pred): 199 | true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1))) 200 | predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1))) 201 | precision = true_positives / (predicted_positives + K.epsilon()) 202 | return precision 203 | 204 | 205 | model.compile(loss='categorical_crossentropy', optimizer='Adadelta', metrics=['acc']) 206 | model.fit(x_train, y_train, epochs=30, batch_size=256,callbacks=[tensorboard], validation_split=0.1) 207 | 208 | # model.fit(x_train, y_train, epochs=20, batch_size=256, validation_data=(x_val, y_val)) 209 | # score = model.evaluate(x_val, y_val, batch_size=128) 210 | # print('Test score:', score[0]) 211 | # print('Test accuracy:', score[1]) 212 | 213 | # # sklearn matrics 214 | # from sklearn.metrics import confusion_matrix 215 | # y_pred = model.predict(x_val,batch_size=64) 216 | # y_pred_label = [c.index(max(c)) for c in y_pred.tolist()] 217 | # y_true_label = [c.index(max(c)) for c in y_val.tolist()] 218 | # y_pred_label = [original_labels[i] for i in y_pred_label] 219 | # y_true_label = [original_labels[i] for i in y_true_label] 220 | # matrix = confusion_matrix(y_true_label, y_pred_label,original_labels) 221 | # # matplotlib 222 | # import matplotlib.pyplot as plt 223 | # plt.matshow(matrix) 224 | # plt.colorbar() 225 | # plt.xlabel('Prediction') 226 | # plt.ylabel('True') 227 | # plt.xticks(matrix[1],label_dict.keys()) 228 | # plt.yticks(matrix[1],label_dict.keys()) 229 | # # plt.show() 230 | # plt.savefig("confusion_matrix.jpg") 231 | # 232 | # # classification_report 233 | # from sklearn.metrics import classification_report 234 | # print "classification_report(left: labels):" 235 | # for key,value in label_dict.iteritems(): 236 | # print "dict[%s]="%key,value 237 | # print classification_report(y_val, y_pred) 238 | # 239 | # 
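# A minimal sketch (hedged, not part of the script as written): the custom
# precision/recall helpers defined earlier are never passed to the model,
# which is compiled with metrics=['acc'] only. Reusing the objects already
# defined in this script, they could be reported on the held-out split
# roughly like this:
# model.compile(loss='categorical_crossentropy', optimizer='Adadelta',
#               metrics=['acc', precision, recall])
# score = model.evaluate(x_val, y_val, batch_size=128)
# print "val loss / acc / precision / recall:", score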
240 | # show model 241 | # from keras.utils import plot_model 242 | # 243 | # model.save_weights(model_filepath) 244 | # plot_model(model, to_file='model.png') 245 | -------------------------------------------------------------------------------- /clean_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# coding=utf-8\n", 12 | "import zipfile\n", 13 | "import os\n", 14 | "import sys\n", 15 | "import logging\n", 16 | "import re" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 3, 22 | "metadata": { 23 | "collapsed": true 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "split_type_content_mark = '|'\n", 28 | "base_data_dir = \"../../data/sogou_resource/\"" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 4, 34 | "metadata": { 35 | "collapsed": true 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "\"\"\"转换指定目录下文件的编码为utf-8\"\"\"\n", 40 | "def convert(filename,from_encode=\"GBK\",to_encode=\"UTF-8\"):\n", 41 | " try:\n", 42 | "# print \"convert \"+ filename\n", 43 | " content = open(filename,'rb').read()\n", 44 | "# print \"content:\"+content\n", 45 | " new_content = content.decode(from_encode).encode(to_encode)\n", 46 | "# print \"new_content:\"+new_content\n", 47 | " open(filename,\"w\").write(new_content)\n", 48 | " print \"done\"\n", 49 | " except:\n", 50 | " print \"error\"\n", 51 | "\n", 52 | "def explore(dir):\n", 53 | " for root,dirs,files in os.walk(dir):\n", 54 | " for file in files:\n", 55 | " path = os.path.join(root,file)\n", 56 | " convert(path)\n", 57 | "\n", 58 | "def do_convert(path):\n", 59 | " if os.path.isfile(path):\n", 60 | " convert(path)\n", 61 | " if os.path.isdir(path):\n", 62 | " explore(path)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 5, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "path = \"../../data/sogou_resource/\"\n", 72 | "do_convert(path)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 16, 78 | "metadata": { 79 | "collapsed": true 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "\"\"\"使用jieba分词进行中文分词\"\"\"\n", 84 | "import jieba\n", 85 | "import zhon\n", 86 | "\n", 87 | "separated_word_file_dir = \"word_separated\"\n", 88 | "# 清华新闻语料库\n", 89 | "types = [\"体育\",\"娱乐\",\"家居\",\"彩票\",\"房产\",\"教育\",\"时尚\",\"时政\",\"星座\",\"游戏\",\"社会\",\"科技\",\"股票\",\"财经\"]\n", 90 | "# types = [\"体育\"]\n", 91 | "\n", 92 | "\n", 93 | "# 搜狗实验室新闻语料类别\n", 94 | "# types = [\"体育\",\"健康\",\"军事\",\"招聘\",\"文化\",\"教育\",\"财经\",\"旅游\",\"互联网\"]\n", 95 | "\n", 96 | "def ch_and_en_word_extraction(content_raw):\n", 97 | " \"\"\"抽取中文和英文\"\"\"\n", 98 | " pattern = re.compile(u\"([\\u4e00-\\u9fa5a-zA-Z0-9]+)\")\n", 99 | " re_data = pattern.findall(content_raw)\n", 100 | " clean_content = ' '.join(re_data)\n", 101 | " return clean_content\n", 102 | "\n", 103 | "def clean_str(s):\n", 104 | " s = s.strip() # 前后的空格\n", 105 | " s = s.strip('\\n') #换行符\n", 106 | " s = re.sub(\"[ \\t\\n\\r]*\",'',s) # tab, newline, return\n", 107 | " s = re.sub(\"<(\\S*?)[^>]*>.*?|<.*? 
/>\",'',s) # html标签\n", 108 | " s = re.sub(\" +|<+|>+\",'',s) # html中的空格符号,大于,小于\n", 109 | " s = re.sub(\"[a-zA-z]+://[^\\s]*\",'',s) # URL\n", 110 | " s = re.sub(r'([\\w-]+(\\.[\\w-]+)*@[\\w-]+(\\.[\\w-]+)+)','',s) # email\n", 111 | " # 标点符号,需要先转utf-8,否则符号匹配不成功\n", 112 | " s = re.sub(ur\"([%s])+\" %zhon.hanzi.punctuation,\" \",s.decode('utf-8'))\n", 113 | " # 抽取中文和英文\n", 114 | " s = ch_and_en_word_extraction(s)\n", 115 | " return s\n", 116 | "\n", 117 | "def separate_words(infile,outfile):\n", 118 | " try:\n", 119 | " outf = open(outfile,'w')\n", 120 | " inf = open(infile,'r')\n", 121 | "\n", 122 | " space = ' '\n", 123 | "# print 'separate '+infile\n", 124 | " for line in inf.readlines():\n", 125 | "# line = clean_str(line)\n", 126 | " # 除空行\n", 127 | " if not len(line.strip()):\n", 128 | " continue\n", 129 | " seg_list = jieba.cut(line)\n", 130 | " \"\"\"此处需要循环每个单词编码为utf-8,jieba.cut将结果转为了unicode编码,\n", 131 | " 直接write(space.join(seg_list))会报编码错误\"\"\"\n", 132 | " for word in seg_list:\n", 133 | " try:\n", 134 | " word = word.encode('UTF-8')\n", 135 | " except:\n", 136 | " continue\n", 137 | " outf.write(word)\n", 138 | " outf.write(' ')\n", 139 | " outf.write('\\n')\n", 140 | " # close file stream\n", 141 | " outf.close()\n", 142 | " inf.close()\n", 143 | " except:\n", 144 | " pass\n", 145 | "# print \"error occured when write to \"+outfile\n", 146 | "\n", 147 | "\n", 148 | "\n", 149 | "def is_target_dir(path):\n", 150 | " if os.path.dirname(path).split(\"/\")[-1] in types and not re.match(\".DS_Store\",os.path.basename(path)):\n", 151 | " return True\n", 152 | " else:\n", 153 | " return False\n", 154 | " \n", 155 | "def explore(dir):\n", 156 | " for root,dirs,files in os.walk(dir):\n", 157 | " for file in files:\n", 158 | " path = os.path.join(root,file)\n", 159 | " if is_target_dir(path):\n", 160 | " child_dir = os.path.join(root,separated_word_file_dir)\n", 161 | " if not os.path.exists(child_dir):\n", 162 | " os.mkdir(child_dir)\n", 163 | " print \"make dir: \"+child_dir\n", 164 | " separate_words(path,os.path.join(child_dir,file))\n", 165 | " \n", 166 | "\n", 167 | "def do_batch_separate(path):\n", 168 | " if os.path.isfile(path) and is_target_dir(path):\n", 169 | " separate_words(path,os.path.join(root,separated_word_file_dir,path))\n", 170 | " if os.path.isdir(path):\n", 171 | " explore(path)\n", 172 | " " 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": { 179 | "scrolled": true 180 | }, 181 | "outputs": [ 182 | { 183 | "name": "stdout", 184 | "output_type": "stream", 185 | "text": [ 186 | "begin time: 2017-11-11 00:32:09.642573\n", 187 | "make dir: ../../data/THUCNews/体育/word_separated\n", 188 | "make dir: ../../data/THUCNews/娱乐/word_separated\n" 189 | ] 190 | } 191 | ], 192 | "source": [ 193 | "import time\n", 194 | "import datetime\n", 195 | "\n", 196 | "original_dir = \"../../data/THUCNews/\"\n", 197 | "now = datetime.datetime.now()\n", 198 | "print \"begin time:\",now\n", 199 | "begin_time = time.time()\n", 200 | "do_batch_separate(original_dir)\n", 201 | "end_time = time.time()\n", 202 | "now = datetime.datetime.now()\n", 203 | "print \"end time:\",now\n", 204 | "print \"time used:\"+str(end_time-begin_time)+\"秒\"" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 410, 210 | "metadata": { 211 | "collapsed": true 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "\"\"\"将所有语料,整合成csv类型文件,文件格式:type,content\"\"\"\n", 216 | "def combine_file(file,outfile):\n", 217 | " # the type of file 
,file示例:xxx/互联网/xxx/xxx.txt\n", 218 | " label = os.path.dirname(file).split('/')[-2]\n", 219 | " content = open(file).read()\n", 220 | "# print \"content:\"+content\n", 221 | "# print \"len:\",len(content)\n", 222 | " if len(content)>1: #排除前面步骤中写文件时,内容为只写入一个空格的情况\n", 223 | " new_content = label+\",\"+content\n", 224 | "# print \"new_content:\\n \" + new_content\n", 225 | " open(outfile,\"a\").write(new_content)\n", 226 | "\n", 227 | "def do_combine(dir,outfile):\n", 228 | " print \"deal with dir: \"+ dir\n", 229 | " for root,dirs,files in os.walk(dir):\n", 230 | " for file in files:\n", 231 | " match = re.match(r'\\d+\\.txt',file)\n", 232 | " if match:\n", 233 | " path = os.path.join(root,file)\n", 234 | " print \"combine \"+ path\n", 235 | " combine_file(path,outfile)\n", 236 | " \n", 237 | "def create_csv_file(dir,filename):\n", 238 | " csv_title = \"type,content\\n\"\n", 239 | " filepath = os.path.join(dir,filename+'.csv')\n", 240 | " open(filepath,'w').write(csv_title)\n", 241 | " return filepath" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 6, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "base_dir = \"../../data/THUCNews/\"\n", 251 | "\"\"\"创建处理后的数据集的目录\"\"\"\n", 252 | "dataset_dir = os.path.join(base_dir,\"dataset\")\n", 253 | "if not os.path.exists(dataset_dir):\n", 254 | " os.mkdir(dataset_dir)\n", 255 | " \n", 256 | "\"\"\"创建每个type目录对应的csv文件,并将一个type目录下的文件写到同一个对应的csv文件\"\"\"\n", 257 | "# 清华新闻语料库\n", 258 | "type_name_list = [\"体育\",\"娱乐\",\"家居\",\"彩票\",\"房产\",\"教育\",\"时尚\",\"时政\",\"星座\",\"游戏\",\"社会\",\"科技\",\"股票\",\"财经\"]\n", 259 | "\n", 260 | "# 搜狗新闻语料库\n", 261 | "# type_name_list = [\"体育\",\"健康\",\"军事\",\"招聘\",\"文化\",\"教育\",\"财经\",\"旅游\",\"互联网\"]\n", 262 | "\n", 263 | "for name in type_name_list:\n", 264 | " path = create_csv_file(dataset_dir,name)\n", 265 | " print \"going to combine file to \" + path\n", 266 | " do_combine(os.path.join(base_dir,name,\"word_separated\"),path)" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 413, 272 | "metadata": { 273 | "collapsed": true 274 | }, 275 | "outputs": [], 276 | "source": [ 277 | "\"\"\"随机采样每个类别的约20%作为测试集,80%作为训练集\"\"\"\n", 278 | "import random\n", 279 | " \n", 280 | "def extract_test_and_train_set(filepath,train_file,test_file):\n", 281 | " try:\n", 282 | " test_f = open(test_file,'a')\n", 283 | " train_f = open(train_file,'a')\n", 284 | " try:\n", 285 | " with open(filepath) as f:\n", 286 | " is_title_line = True\n", 287 | " for line in f.readlines():\n", 288 | " if is_title_line:\n", 289 | " is_title_line = False\n", 290 | " continue\n", 291 | " if not len(line):\n", 292 | " continue\n", 293 | " if random.random() <= 0.2:\n", 294 | " test_f.write(line)\n", 295 | " else:\n", 296 | " train_f.write(line)\n", 297 | " except:\n", 298 | " print \"IO ERROR\"\n", 299 | " finally:\n", 300 | " test_f.close()\n", 301 | " train_f.close()\n", 302 | " except:\n", 303 | " print \"can not open file\"\n", 304 | "\n", 305 | "def do_extract(source_dir,train_f,test_f):\n", 306 | " for root,dirs,files in os.walk(source_dir):\n", 307 | " for file in files:\n", 308 | " if re.match(\"test|train\\.csv\",file) or not re.match(\".*\\.csv\",file):\n", 309 | " continue\n", 310 | " path = os.path.join(root,file)\n", 311 | " print \"extract file: \"+ path\n", 312 | " extract_test_and_train_set(path,train_f,test_f)" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 414, 318 | "metadata": {}, 319 | "outputs": [ 320 | { 321 | "name": "stdout", 322 | 
"output_type": "stream", 323 | "text": [ 324 | "create file: ../../data/sogou_resource/dataset/train.csv\n", 325 | "create file:../../data/sogou_resource/dataset/test.csv\n", 326 | "extract file: ../../data/sogou_resource/dataset/互联网.csv\n", 327 | "extract file: ../../data/sogou_resource/dataset/体育.csv\n", 328 | "extract file: ../../data/sogou_resource/dataset/健康.csv\n", 329 | "extract file: ../../data/sogou_resource/dataset/军事.csv\n", 330 | "extract file: ../../data/sogou_resource/dataset/招聘.csv\n", 331 | "extract file: ../../data/sogou_resource/dataset/教育.csv\n", 332 | "extract file: ../../data/sogou_resource/dataset/文化.csv\n", 333 | "extract file: ../../data/sogou_resource/dataset/旅游.csv\n", 334 | "extract file: ../../data/sogou_resource/dataset/财经.csv\n" 335 | ] 336 | } 337 | ], 338 | "source": [ 339 | "# do extract\n", 340 | "dataset_dir = \"../../data/sogou_resource/dataset/\"\n", 341 | "train_dataset = os.path.join(dataset_dir,\"train.csv\")\n", 342 | "test_dataset = os.path.join(dataset_dir,\"test.csv\")\n", 343 | "if not os.path.exists(train_dataset):\n", 344 | " print \"create file: \"+train_dataset\n", 345 | " open(train_dataset,'w').write(\"type,content\\n\")\n", 346 | "if not os.path.exists(test_dataset):\n", 347 | " print \"create file:\"+test_dataset\n", 348 | " open(test_dataset,'w').write(\"type,content\\n\")\n", 349 | " \n", 350 | "do_extract(dataset_dir,train_dataset,test_dataset)" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": 76, 356 | "metadata": { 357 | "collapsed": true 358 | }, 359 | "outputs": [], 360 | "source": [ 361 | "# -*- encoding: utf-8 -*-\n", 362 | "\n", 363 | "# 提取命名实体和名词、动词\n", 364 | "import requests\n", 365 | "from bosonnlp import BosonNLP\n", 366 | "\n", 367 | "Token = 'alM_aH0F.11177.NKoYB68fl19N'\n", 368 | "noun = ['n','nr','ns','nt','nz','nl','vd','vi','vl','nx']\n", 369 | "nlp = BosonNLP(Token)\n", 370 | "\n", 371 | "def extract_entity_and_nouns(content_raw):\n", 372 | " result = nlp.ner(content_raw)[0]\n", 373 | " words = result['word']\n", 374 | " entities = result['entity']\n", 375 | " tags = result['tag']\n", 376 | " entity_list = [words[it[0]:it[1]] for it in entities]\n", 377 | " con_entity = reduce(lambda x,y:x+y,entity_list)\n", 378 | " con_noun = [it[0] for it in zip(words,tags) if it[1] in noun]\n", 379 | " entity_noun_union = set(con_entity).union(set(con_noun))\n", 380 | " \n", 381 | " content = [word for word in [s.strip() for s in content_raw.split(' ')] if word in entity_noun_union]\n", 382 | " return content\n", 383 | "\n", 384 | "def extract_nouns(content_raw):\n", 385 | " result = nlp.tag(content_raw,space_mode=1)[0]\n", 386 | " content = [d[0] for d in zip(result['word'],result['tag']) if d[1] in noun]\n", 387 | " # 转成utf-8编码\n", 388 | " content = map(lambda x:x.encode('utf-8'),content)\n", 389 | " return content\n", 390 | "\n", 391 | "def do_extract_entity_and_nouns(content_file,newfile):\n", 392 | " print \"do_extract_entity_and_nouns in \"+content_file\n", 393 | " newf = open(newfile,'w')\n", 394 | " istitle = True\n", 395 | " with open(content_file) as f:\n", 396 | " for line in f.readlines():\n", 397 | " if istitle:\n", 398 | " istitle = False\n", 399 | " newf.write(line)\n", 400 | " continue\n", 401 | " type_content = line.split(\",\")\n", 402 | " content_raw = type_content[1]\n", 403 | " new_cont = extract_nouns(content_raw)\n", 404 | " new_line = type_content[0]+','+' '.join(new_cont)+'\\n'\n", 405 | " print new_line\n", 406 | " newf.write(new_line)\n", 407 | " newf.close()" 408 | ] 409 | }, 
410 | { 411 | "cell_type": "code", 412 | "execution_count": 7, 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "test_file = \"../../data/sogou_resource/dataset/test.csv\"\n", 417 | "train_file = \"../../data/sogou_resource/dataset/train.csv\"\n", 418 | "new_test_file = \"../../data/sogou_resource/dataset/extract_test.csv\"\n", 419 | "new_train_file = \"../../data/sogou_resource/dataset/extract_train.csv\"\n", 420 | "do_extract_entity_and_nouns(test_file,new_test_file)\n", 421 | "do_extract_entity_and_nouns(train_file,new_train_file)\n" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": 78, 427 | "metadata": { 428 | "collapsed": true 429 | }, 430 | "outputs": [], 431 | "source": [ 432 | "\"\"\"清洗数据,除掉停用词,剔除坏样本\"\"\"\n", 433 | "def clean_stopwords(content_raw,stopwords_set):\n", 434 | " content_list = [x for x in content_raw.split(\" \") if x!='']\n", 435 | " common_set = set(content_list) & stopwords_set\n", 436 | " new_content = filter(lambda x:x not in common_set,content_list)\n", 437 | " return new_content\n", 438 | "\n", 439 | "def do_clean_stopwords(content_file,stopwords_file,newfile):\n", 440 | " print \"clean stopwords in \"+content_file\n", 441 | " stopwords = []\n", 442 | " # 获取停用词\n", 443 | " with open(stopwords_file) as fi:\n", 444 | " for line in fi.readlines():\n", 445 | " stopwords.append(line.strip())\n", 446 | " newf = open(newfile,'w')\n", 447 | " with open(content_file) as f:\n", 448 | " for line in f.readlines():\n", 449 | " type_content = line.split(split_type_content_mark)\n", 450 | " content_raw = type_content[1]\n", 451 | " new_cont = clean_stopwords(content_raw,set(stopwords))\n", 452 | " new_line = type_content[0]+ split_type_content_mark +' '.join(new_cont)\n", 453 | " newf.write(new_line)\n", 454 | " newf.close()\n", 455 | " " 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": 79, 461 | "metadata": {}, 462 | "outputs": [ 463 | { 464 | "name": "stdout", 465 | "output_type": "stream", 466 | "text": [ 467 | "clean stopwords in ../../data/sogou_resource/dataset/extract_test.csv\n", 468 | "clean stopwords in ../../data/sogou_resource/dataset/extract_train.csv\n" 469 | ] 470 | } 471 | ], 472 | "source": [ 473 | "test_file = \"../../data/sogou_resource/dataset/extract_test.csv\"\n", 474 | "train_file = \"../../data/sogou_resource/dataset/extract_train.csv\"\n", 475 | "new_test_file = \"../../data/sogou_resource/dataset/cleaned_extract_test.csv\"\n", 476 | "new_train_file = \"../../data/sogou_resource/dataset/cleaned_extract_train.csv\"\n", 477 | "stop_words_file = \"../../data/sogou_resource/dataset/news.stopwords.txt\"\n", 478 | "do_clean_stopwords(test_file,stop_words_file,new_test_file)\n", 479 | "do_clean_stopwords(train_file,stop_words_file,new_train_file)" 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": 61, 485 | "metadata": { 486 | "collapsed": true 487 | }, 488 | "outputs": [], 489 | "source": [ 490 | "\"\"\"将文本句子化分割\"\"\"\n", 491 | "import zhon.hanzi\n", 492 | "import re\n", 493 | "\n", 494 | "\n", 495 | "def do_sentence_doc(doc_file,new_file):\n", 496 | " print \"to sentence file:\",doc_file\n", 497 | " newf = open(new_file,'w')\n", 498 | " with open(doc_file) as f:\n", 499 | " for line in f.readlines():\n", 500 | " type_content = line.split(split_type_content_mark)\n", 501 | " label = type_content[0]\n", 502 | " content_raw = type_content[1]\n", 503 | " new_content = re.sub(ur\"([%s])+\" %zhon.hanzi.non_stops,\" \",content_raw.decode('utf-8'))\n", 504 | " 
new_content = re.sub(ur\"([%s])+\" %zhon.hanzi.stops,\" \",content_raw.decode('utf-8'))\n", 505 | " new_type_content = label+split_type_content_mark+new_content.encode('utf-8')\n", 506 | " newf.write(new_type_content)\n", 507 | " newf.close()\n", 508 | " " 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": 62, 514 | "metadata": {}, 515 | "outputs": [ 516 | { 517 | "name": "stdout", 518 | "output_type": "stream", 519 | "text": [ 520 | "to sentence file: ../../data/sogou_resource/dataset/test_sentence_file.csv\n" 521 | ] 522 | } 523 | ], 524 | "source": [ 525 | "test_file = base_data_dir + \"dataset/test_sentence_file.csv\"\n", 526 | "new_file = base_data_dir + \"dataset/sentence_file.csv\"\n", 527 | "do_sentence_doc(test_file,new_file)\n" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": null, 533 | "metadata": { 534 | "collapsed": true 535 | }, 536 | "outputs": [], 537 | "source": [] 538 | } 539 | ], 540 | "metadata": { 541 | "kernelspec": { 542 | "display_name": "Python 2", 543 | "language": "python", 544 | "name": "python2" 545 | }, 546 | "language_info": { 547 | "codemirror_mode": { 548 | "name": "ipython", 549 | "version": 2 550 | }, 551 | "file_extension": ".py", 552 | "mimetype": "text/x-python", 553 | "name": "python", 554 | "nbconvert_exporter": "python", 555 | "pygments_lexer": "ipython2", 556 | "version": "2.7.10" 557 | } 558 | }, 559 | "nbformat": 4, 560 | "nbformat_minor": 2 561 | } 562 | --------------------------------------------------------------------------------