├── README.md
├── word2vec_train
├── separate_words.py
├── remove_words.py
├── process_wiki.py
├── train_word2vec_100_5_5_model.py
└── train_word2vec_skip_ngram_200_5_5_model_wiki.py
├── tools.ipynb
├── clean_data.py
├── text-classification-deep-learning.py
└── clean_data.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # text_classification_with_deep_learning
2 |
3 | Experimental code for classifying news text with deep learning (test accuracy reaches roughly 94%).
4 |
5 | 1. The code is written in Python; the deep learning networks are built with the Keras framework.
6 |
7 | 2. This is experimental code: the file paths are the ones used in the experiment environment, so adjust them before running.
8 |
9 | 3. The experiments use two datasets:
10 |    (1) the Chinese Wikipedia corpus (https://dumps.wikimedia.org/zhwiki/latest/zhwiki-latest-pages-articles.xml.bz2)
11 |    (2) THUCNews, the Chinese text classification dataset from the Tsinghua NLP lab (http://thuctc.thunlp.org/)
12 |    The Wikipedia corpus is used to train the word2vec model; THUCNews is the dataset for the text classification task (see the example below).
13 |
14 |
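15 | 4. Example: loading the trained word vectors. A minimal sketch for sanity-checking the word2vec model with gensim; the file name below is a placeholder for wherever the train_word2vec_*.py script saved the binary vectors:
16 |
17 | ```python
18 | import gensim
19 |
20 | # placeholder path: the binary vectors written by one of the train_word2vec_*.py scripts
21 | wv = gensim.models.KeyedVectors.load_word2vec_format("word2vec_model.bin", binary=True)
22 | print(wv.most_similar(u"北京", topn=5))
23 | ```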
--------------------------------------------------------------------------------
/word2vec_train/separate_words.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import logging
6 | import sys
7 | import jieba
8 |
9 | if __name__=='__main__':
10 |
11 | program = os.path.basename(sys.argv[0])
12 | logger = logging.getLogger(program)
13 |
14 | logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
15 | logging.root.setLevel(level=logging.INFO)
16 |
17 | if len(sys.argv) < 3:
18 |         print('usage: python separate_words.py <input_file> <output_file>')
19 | sys.exit(1)
20 |
21 | inp, outp = sys.argv[1:3]
22 | space = ' '
23 |
24 | output = open(outp, 'w')
25 | inputer = open(inp, 'r')
26 |
27 | for line in inputer.readlines():
28 | seg_list = jieba.cut(line)
29 | output.write(space.join(seg_list) + '\n')
30 | 
31 |     output.close()
32 |     inputer.close()
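33 | 
34 | # Example invocation (file names are placeholders):
35 | #     python separate_words.py wiki.zh.txt wiki.zh.seg.txt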
--------------------------------------------------------------------------------
/word2vec_train/remove_words.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import logging
6 | import sys
7 | import re
8 |
9 | if __name__=='__main__':
10 |
11 | program = os.path.basename(sys.argv[0])
12 | logger = logging.getLogger(program)
13 |
14 | logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
15 | logging.root.setLevel(level=logging.INFO)
16 |
17 | if len(sys.argv) < 3:
18 |         print('usage: python remove_words.py <input_file> <output_file>')
19 | sys.exit(1)
20 |
21 | inp, outp = sys.argv[1:3]
22 |
23 | output = open(outp, 'w')
24 | inputer = open(inp, 'r')
25 |
26 | for line in inputer.readlines():
27 | ss = re.findall('[\u4e00-\u9fa5a-zA-Z0-9]',line)
28 | # ss = re.findall('[\n\s*\r\u4e00-\u9fa5]', line)
29 | output.write("".join(ss))
30 | 
31 |     output.close()
32 |     inputer.close()
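33 | 
34 | # Note: the character-level pattern above is meant to keep only Chinese characters, ASCII
35 | # letters and digits, so whitespace (including the newlines between input lines) is not
36 | # preserved in the output.
37 | # Example invocation (file names are placeholders):
38 | #     python remove_words.py wiki.zh.txt wiki.zh.chars.txt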
--------------------------------------------------------------------------------
/word2vec_train/process_wiki.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import logging
6 | import sys
7 |
8 | from gensim.corpora import WikiCorpus
9 |
10 | if __name__=='__main__':
11 |
12 | program = os.path.basename(sys.argv[0])
13 | logger = logging.getLogger(program)
14 |
15 | logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
16 | logging.root.setLevel(level=logging.INFO)
17 |
18 |
19 | if len(sys.argv) < 3:
20 |         print('usage: python process_wiki.py <zhwiki-dump.xml.bz2> <output_text_file>')
21 | sys.exit(1)
22 |
23 | inp, outp = sys.argv[1:3]
24 | space =b' '
25 | i = 0
26 |
27 | output = open(outp, 'wb')
28 | wiki = WikiCorpus(inp, lemmatize=False, dictionary={})
29 | for text in wiki.get_texts():
30 | output.write(space.join(text) +b'\n')
31 | i = i + 1
32 | if i % 10000 == 0:
33 | logger.info('Saved ' + str(i) + ' articles')
34 |
35 | output.close()
36 | logger.info('Finished ' + str(i) + ' articles')
37 |
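38 | # Example invocation (file names are placeholders; the dump is the zhwiki archive listed in the README):
39 | #     python process_wiki.py zhwiki-latest-pages-articles.xml.bz2 wiki.zh.txt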
--------------------------------------------------------------------------------
/word2vec_train/train_word2vec_100_5_5_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import logging
5 | import os.path
6 | import sys
7 | import multiprocessing
8 |
9 | from time import time
10 | from gensim.corpora import WikiCorpus
11 | from gensim.models import Word2Vec
12 | from gensim.models.word2vec import LineSentence
13 | from gensim.models.word2vec import BrownCorpus
14 |
15 | if __name__== "__main__":
16 |
17 | program = os.path.basename(sys.argv[0])
18 | logger = logging.getLogger(program)
19 |
20 | logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s' )
21 | logging.root.setLevel(level=logging.INFO)
22 | logger.info('running %s' % ' '.join(sys.argv))
23 |
24 |     if len(sys.argv) < 4:
25 |         print('usage: python train_word2vec_100_5_5_model.py <segmented_corpus> <model_output> <binary_vectors_output>')
26 |         sys.exit(1)
27 |
28 | inp, outp1, outp2= sys.argv[1:4]
29 |
30 | begin = time()
31 | # model = Word2Vec(BrownCorpus(inp), size=100, window=5, min_count=5,
32 | # workers=multiprocessing.cpu_count())
33 | model = Word2Vec(LineSentence(inp),sg=0, size=100, window=5, min_count=2,
34 | workers=multiprocessing.cpu_count())
35 | model.save(outp1)
36 | model.wv.save_word2vec_format(outp2, binary=True)
37 |
38 | end = time()
39 | print("total processing time:%d seconds" %(end-begin))
40 |
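41 | # Example invocation (file names are placeholders):
42 | #     python train_word2vec_100_5_5_model.py wiki.zh.seg.txt cbow_100.model cbow_100.bin
43 | # Note: this trains a CBOW model (sg=0, size=100, window=5); min_count is set to 2 here, not 5.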
--------------------------------------------------------------------------------
/word2vec_train/train_word2vec_skip_ngram_200_5_5_model_wiki.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import logging
5 | import os.path
6 | import sys
7 | import multiprocessing
8 |
9 | from time import time
10 | from gensim.corpora import WikiCorpus
11 | from gensim.models import Word2Vec
12 | from gensim.models.word2vec import LineSentence
13 | from gensim.models.word2vec import BrownCorpus
14 |
15 | if __name__== "__main__":
16 |
17 | program = os.path.basename(sys.argv[0])
18 | logger = logging.getLogger(program)
19 |
20 | logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s' )
21 | logging.root.setLevel(level=logging.INFO)
22 | logger.info('running %s' % ' '.join(sys.argv))
23 |
24 |     if len(sys.argv) < 4:
25 |         print('usage: python train_word2vec_skip_ngram_200_5_5_model_wiki.py <segmented_corpus> <model_output> <binary_vectors_output>')
26 |         sys.exit(1)
27 |
28 | inp, outp1, outp2= sys.argv[1:4]
29 |
30 | begin = time()
31 | # model = Word2Vec(BrownCorpus(inp), size=100, window=5, min_count=5,
32 | # workers=multiprocessing.cpu_count())
33 | model = Word2Vec(LineSentence(inp),sg=1, size=200, window=5, min_count=5,
34 | workers=multiprocessing.cpu_count())
35 | model.save(outp1)
36 | model.wv.save_word2vec_format(outp2, binary=True)
37 |
38 | end = time()
39 | print("total processing time:%d seconds" %(end-begin))
40 |
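41 | # Example invocation (file names are placeholders):
42 | #     python train_word2vec_skip_ngram_200_5_5_model_wiki.py wiki.zh.seg.txt skip_gram_200.model skip_gram_200.bin
43 | # The binary vectors file is the kind of model that text-classification-deep-learning.py
44 | # loads via gensim.models.KeyedVectors.load_word2vec_format(..., binary=True).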
--------------------------------------------------------------------------------
/tools.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "# coding=utf-8\n",
12 | "import os\n",
13 | "import sys\n",
14 | "import shutil "
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 2,
20 | "metadata": {
21 | "collapsed": true
22 | },
23 | "outputs": [],
24 | "source": [
25 | "\"\"\"Delete sub-directories under the given path whose name matches the given string\"\"\"\n",
26 | "def delete_dir(path,s):\n",
27 | " for root,dirs,files in os.walk(path):\n",
28 | " for d in dirs:\n",
29 | " if d==s :\n",
30 | " print \"delete directory \"+os.path.join(root,d)\n",
31 | " shutil.rmtree(os.path.join(root,d))"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 3,
37 | "metadata": {},
38 | "outputs": [
39 | {
40 | "name": "stdout",
41 | "output_type": "stream",
42 | "text": [
43 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/体育/word_separated\n",
44 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/娱乐/word_separated\n",
45 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/家居/word_separated\n",
46 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/彩票/word_separated\n",
47 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/房产/word_separated\n",
48 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/教育/word_separated\n",
49 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/时尚/word_separated\n",
50 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/时政/word_separated\n",
51 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/星座/word_separated\n",
52 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/游戏/word_separated\n",
53 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/社会/word_separated\n",
54 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/科技/word_separated\n",
55 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/股票/word_separated\n",
56 | "delete directory /Users/zhangwei/Documents/paper/data/THUCNews/财经/word_separated\n"
57 | ]
58 | }
59 | ],
60 | "source": [
61 | "basedir = \"/Users/zhangwei/Documents/paper/data/THUCNews/\"\n",
62 | "s = \"word_separated\"\n",
63 | "delete_dir(basedir,s)"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 8,
69 | "metadata": {
70 | "collapsed": true
71 | },
72 | "outputs": [],
73 | "source": [
74 | "\"\"\"Inspect the contents of a csv.zip file\"\"\"\n",
75 | "import pandas as pd\n",
76 | "def lookat_csv_zip(filename):\n",
77 | " df = pd.read_csv(filename, compression='zip')\n",
78 | " columns = df.columns\n",
79 | " one_line = [df[x][0] for x in columns]\n",
80 | " return columns,one_line\n",
81 | " "
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 10,
87 | "metadata": {},
88 | "outputs": [
89 | {
90 | "data": {
144 | "text/plain": [
145 | " type content\n",
146 | "0 互联网 思科 周二 盘后 发布 财年 第三季度 财报 财报 显示 受 股票 期权 开支 收购 相关 ...\n",
147 | "1 互联网 eNet 硅谷动力 国外 媒体报道 一名 德国 数据库安全 工程师 日前 撰文 指出 目前为...\n",
148 | "2 互联网 作者 令狐 达 eNet 硅谷动力 国外 媒体报道 美国 高科技 市场调研 公司 M Met...\n",
149 | "3 互联网 作者 令狐 达 eNet 硅谷动力 台湾 媒体 引述 笔记本 厂商 消息人士 透露 Turi...\n",
150 | "4 互联网 作者 令狐 达 eNet 硅谷动力 台湾 媒体报道 液晶 显示器 大厂 宏基 公司 美国 分..."
151 | ]
152 | },
153 | "execution_count": 10,
154 | "metadata": {},
155 | "output_type": "execute_result"
156 | }
157 | ],
158 | "source": [
159 | "filename = \"../../data/sogou_resource/dataset/train.csv.zip\"\n",
160 | "# cols,line = lookat_csv_zip(filename)\n",
161 | "# print cols\n",
162 | "# print line\n",
163 | "\n",
164 | "# call the pandas API directly\n",
165 | "df = pd.read_csv(filename, compression='zip')\n",
166 | "df.head()"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 4,
172 | "metadata": {
173 | "collapsed": true
174 | },
175 | "outputs": [],
176 | "source": [
177 | "\"\"\"Deduplicate the stop-word list\"\"\"\n",
178 | "basedir = \"/Users/zhangwei/Documents/paper/data/sogou_resource/dataset/\"\n",
179 | "original_file = os.path.join(basedir,\"stopwords.txt\")\n",
180 | "new_file = os.path.join(basedir,\"new_stopwords.txt\")\n",
181 | "word_list = []\n",
182 | "with open(original_file) as f:\n",
183 | " for line in f.readlines():\n",
184 | " line = line.strip()\n",
185 | " word_list.append(line)\n",
186 | "word_list = list(set(word_list))\n",
187 | "nf = open(new_file,'w')\n",
188 | "for item in word_list:\n",
189 | " nf.write(item+'\\n')\n",
190 | "nf.close()"
191 | ]
192 | },
193 | {
194 | "cell_type": "code",
195 | "execution_count": null,
196 | "metadata": {
197 | "collapsed": true
198 | },
199 | "outputs": [],
200 | "source": []
201 | }
202 | ],
203 | "metadata": {
204 | "kernelspec": {
205 | "display_name": "Python 2",
206 | "language": "python",
207 | "name": "python2"
208 | },
209 | "language_info": {
210 | "codemirror_mode": {
211 | "name": "ipython",
212 | "version": 2
213 | },
214 | "file_extension": ".py",
215 | "mimetype": "text/x-python",
216 | "name": "python",
217 | "nbconvert_exporter": "python",
218 | "pygments_lexer": "ipython2",
219 | "version": "2.7.10"
220 | }
221 | },
222 | "nbformat": 4,
223 | "nbformat_minor": 2
224 | }
225 |
--------------------------------------------------------------------------------
/clean_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import os
3 | import re
4 | import jieba
5 | import time
6 | import datetime
7 |
8 | """Chinese word segmentation with jieba"""
9 | separated_word_file_dir = "word_separated"
10 | # THUCNews (Tsinghua news corpus) categories
11 | types = ["体育", "娱乐", "家居", "彩票", "房产", "教育", "时尚", "时政", "星座", "游戏", "社会", "科技", "股票", "财经"]
12 |
13 |
14 | def ch_and_en_word_extraction(content_raw):
15 |     """Extract Chinese and English characters (and digits)"""
16 | pattern = re.compile(u"([\u4e00-\u9fa5a-zA-Z0-9]+)")
17 | re_data = pattern.findall(content_raw)
18 | clean_content = ' '.join(re_data)
19 | return clean_content
20 |
21 |
22 | def clean_str(s):
23 |     # s = s.strip('\n')  # newline
24 |     # s = re.sub("[\t\n\r]*", '', s)  # tab, newline, return
25 |     s = re.sub('\|+',' ',s)
26 |     s = re.sub('\s+',' ',s)
27 |     s = s.strip()  # leading/trailing whitespace
28 |     # s = re.sub("<(\S*?)[^>]*>.*?\1>|<.*? />", '', s)  # HTML tags
29 |     # s = re.sub(" +|<+|>+", '', s)  # HTML space / greater-than / less-than leftovers
30 |     # s = re.sub("[a-zA-z]+://[^\s]*", '', s)  # URLs
31 |     # s = re.sub(r'([\w-]+(\.[\w-]+)*@[\w-]+(\.[\w-]+)+)', '', s)  # email addresses
32 |     # punctuation: the string must be decoded to unicode first, otherwise the symbols do not match
33 |     # s = re.sub(ur"([%s])+" % zhon.hanzi.punctuation, " ", s.decode('utf-8'))
34 |     # extract Chinese and English characters
35 |     # s = ch_and_en_word_extraction(s)
36 | return s
37 |
38 |
39 | def separate_words(infile, outfile):
40 | try:
41 | outf = open(outfile, 'w')
42 | inf = open(infile, 'r')
43 |
44 | space = ' '
45 | # print 'separate '+infile
46 | isFirstLine = True
47 | for line in inf.readlines():
48 | line = clean_str(line)
49 |             # skip empty lines
50 | if not len(line):
51 | continue
52 | seg_list = jieba.cut(line)
53 |             """Each word has to be encoded to UTF-8 individually here: jieba.cut returns unicode,
54 |             and writing space.join(seg_list) directly raises an encoding error"""
55 | for word in seg_list:
56 | if not len(word.strip()):
57 | continue
58 | try:
59 | word = word.strip().encode('UTF-8')
60 | except:
61 | continue
62 | outf.write(word)
63 | outf.write(space)
64 | if isFirstLine:
65 | outf.write("。")
66 | isFirstLine = False
67 | outf.write('\n')
68 | # close file stream
69 | outf.close()
70 | inf.close()
71 | except:
72 |         print "error occurred when writing to " + outfile
73 |
74 |
75 | def is_target_dir(path):
76 | if os.path.dirname(path).split("/")[-1] in types and not re.match(".DS_Store", os.path.basename(path)):
77 | return True
78 | else:
79 | return False
80 |
81 |
82 | def explore(dir):
83 | for root, dirs, files in os.walk(dir):
84 | for file in files:
85 | path = os.path.join(root, file)
86 | if is_target_dir(path):
87 | child_dir = os.path.join(root, separated_word_file_dir)
88 | if not os.path.exists(child_dir):
89 | os.mkdir(child_dir)
90 | print "make dir: " + child_dir
91 | separate_words(path, os.path.join(child_dir, file))
92 |
93 |
94 | def do_batch_separate(path):
95 | if os.path.isfile(path) and is_target_dir(path):
96 |         separate_words(path, os.path.join(os.path.dirname(path), separated_word_file_dir, os.path.basename(path)))
97 | if os.path.isdir(path):
98 | explore(path)
99 |
100 |
101 | original_dir = "THUCNews_deal_title/"
102 | now = datetime.datetime.now()
103 | print "separate word begin time:", now
104 | begin_time = time.time()
105 | do_batch_separate(original_dir)
106 | end_time = time.time()
107 | now = datetime.datetime.now()
108 | print "separate word, end time:", now
109 | print "separate word, time used: " + str(end_time - begin_time) + " seconds"
110 |
111 |
112 | """Combine all corpus files into csv-style files with the format: type|content"""
113 | split_mark = '|'
114 |
115 | def combine_file(file, outfile):
116 |     # derive the label from the file path, e.g. xxx/互联网/xxx/xxx.txt
117 | label = os.path.dirname(file).split('/')[-2]
118 | content = open(file).read()
119 | # print "content:"+content
120 | # print "len:",len(content)
121 |     if len(content) > 1:  # skip files where the previous step wrote nothing but a single space
122 | new_content = label + split_mark + content
123 | # print "new_content:\n " + new_content
124 | open(outfile, "a").write(new_content)
125 |
126 |
127 | def do_combine(dir, outfile):
128 | print "deal with dir: " + dir
129 | for root, dirs, files in os.walk(dir):
130 | for file in files:
131 | match = re.match(r'\d+\.txt', file)
132 | if match:
133 | path = os.path.join(root, file)
134 | # print "combine " + path
135 | combine_file(path, outfile)
136 |
137 |
138 | def create_csv_file(dir, filename):
139 | csv_title = "type"+ split_mark + "content\n"
140 | filepath = os.path.join(dir, filename + '.csv')
141 | open(filepath, 'w').write(csv_title)
142 | return filepath
143 |
144 |
145 | base_dir = "THUCNews_deal_title/"
146 | """Create the directory for the processed dataset"""
147 | dataset_dir = os.path.join(base_dir, "dataset")
148 | if not os.path.exists(dataset_dir):
149 | os.mkdir(dataset_dir)
150 |
151 | """Create one csv file per type and write all files under that type directory into it"""
152 | # THUCNews categories
153 | type_name_list = ["体育", "娱乐", "家居", "彩票", "房产", "教育", "时尚", "时政", "星座", "游戏", "社会", "科技", "股票", "财经"]
154 |
155 | combine_begin_time = datetime.datetime.now()
156 | print "combine begin time:",combine_begin_time
157 | for name in type_name_list:
158 | path = create_csv_file(dataset_dir, name)
159 | print "going to combine file to " + path
160 | do_combine(os.path.join(base_dir, name, "word_separated"), path)
161 |
162 | combine_end_time = datetime.datetime.now()
163 | print "combine end time:",combine_end_time
164 |
165 | """Randomly sample about 20% of each category as the test set and the remaining 80% as the training set"""
166 | import random
167 |
168 | def extract_test_and_train_set(filepath, train_file, test_file):
169 | try:
170 | test_f = open(test_file, 'a')
171 | train_f = open(train_file, 'a')
172 | try:
173 | with open(filepath) as f:
174 | is_title_line = True
175 | for line in f.readlines():
176 | if is_title_line:
177 | is_title_line = False
178 | continue
179 | if not len(line):
180 | continue
181 | if random.random() <= 0.2:
182 | test_f.write(line)
183 | else:
184 | train_f.write(line)
185 | except:
186 | print "IO ERROR"
187 | finally:
188 | test_f.close()
189 | train_f.close()
190 | except:
191 | print "can not open file"
192 |
193 |
194 | def do_extract(source_dir, train_f, test_f):
195 | for root, dirs, files in os.walk(source_dir):
196 | for file in files:
197 | if re.match("test|train\.csv", file) or not re.match(".*\.csv", file):
198 | continue
199 | path = os.path.join(root, file)
200 | print "extract file: " + path
201 | extract_test_and_train_set(path, train_f, test_f)
202 |
203 |
204 | # do extract
205 | dataset_dir = "THUCNews_deal_title/dataset/"
206 | train_dataset = os.path.join(dataset_dir, "train.csv")
207 | test_dataset = os.path.join(dataset_dir, "test.csv")
208 | if not os.path.exists(train_dataset):
209 | print "create file: " + train_dataset
210 | open(train_dataset, 'w').write("type"+ split_mark+"content\n")
211 | if not os.path.exists(test_dataset):
212 | print "create file:" + test_dataset
213 | open(test_dataset, 'w').write("type"+split_mark+"content\n")
214 |
215 | do_extract(dataset_dir, train_dataset, test_dataset)
216 |
217 |
218 | """Clean the data: remove stop words and drop bad samples"""
219 |
220 | #
221 | # def clean_stopwords(content_raw, stopwords_set):
222 | # content_list = [x for x in re.split(' +|\t+',content_raw) if x != '']
223 | # common_set = set(content_list) & stopwords_set
224 | # new_content = filter(lambda x: x not in common_set, content_list)
225 | # return new_content
226 | #
227 | #
228 | # def do_clean_stopwords(content_file, stopwords_file, newfile):
229 | # print "clean stopwords in " + content_file
230 | # stopwords = []
231 | #     # load the stop words
232 | # with open(stopwords_file) as fi:
233 | # for line in fi.readlines():
234 | # stopwords.append(line.strip())
235 | # newf = open(newfile, 'w')
236 | # with open(content_file) as f:
237 | # for line in f.readlines():
238 | # type_content = line.split(split_mark)
239 | # content_raw = type_content[1]
240 | # new_cont = clean_stopwords(content_raw, set(stopwords))
241 | # new_line = type_content[0] + split_mark + ' '.join(new_cont).strip()
242 | # newf.write(new_line)
243 | # newf.write('\n')
244 | # newf.close()
245 | #
246 | # test_file = "THUCNews_deal_title/dataset/test.csv"
247 | # train_file = "THUCNews_deal_title/dataset/train.csv"
248 | # new_test_file = "THUCNews_deal_title/dataset/cleaned_test.csv"
249 | # new_train_file = "THUCNews_deal_title/dataset/cleaned_train.csv"
250 | # stop_words_file = "THUCNews_deal_title/dataset/news.stopwords.txt"
251 | # do_clean_stopwords(test_file,stop_words_file,new_test_file)
252 | # do_clean_stopwords(train_file,stop_words_file,new_train_file)
253 |
--------------------------------------------------------------------------------
/text-classification-deep-learning.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | # constants
4 | MAX_SEQUENCE_LENGTH = 300  # maximum length (in words) of each news item
5 | EMBEDDING_DIM = 200  # dimensionality of the word vector space
6 | VALIDATION_SPLIT = 0.20  # fraction of the data held out for validation
7 | SENTENCE_NUM = 30  # number of sentences
8 | model_filepath = "./CNN_word2vec_1_layer_model"
9 |
10 | from keras.preprocessing.text import Tokenizer
11 | import pandas as pd
12 | from keras.models import Sequential
13 | from keras.layers import Dense, Input, Flatten, Dropout
14 | from keras.layers import Conv1D, MaxPooling1D, Embedding,GRU
15 | from keras.optimizers import Adam
16 | from keras import regularizers
17 | import gensim
18 | from time import time
19 | import keras.callbacks
20 | from keras.layers import LSTM,Bidirectional
21 | from keras.preprocessing.sequence import pad_sequences
22 | import numpy as np
23 |
24 | # input data
25 | train_filename = "./THUCNews_deal_title/dataset/sentenced_deal_stop_word_train.csv.zip"
26 | test_filename = "./THUCNews_deal_title/dataset/sentenced_deal_stop_word_test.csv.zip"
27 | train_df = pd.read_csv(train_filename,sep='|',compression = 'zip',error_bad_lines=False)
28 | test_df = pd.read_csv(test_filename,sep='|',compression='zip',error_bad_lines=False)
29 | content_df = train_df.append(test_df, ignore_index=True)
30 | # shuffle data
31 | from sklearn.utils import shuffle
32 | content_df = shuffle(content_df)
33 |
34 | all_texts = content_df['content']
35 | all_labels = content_df['type']
36 | print "number of news documents:", len(all_texts), len(all_labels)
37 | print "documents per class:\n", all_labels.value_counts()
38 |
39 | all_texts = all_texts.tolist()
40 | all_labels = all_labels.tolist()
41 |
42 | original_labels = list(set(all_labels))
43 | num_labels = len(original_labels)
44 | print "label counts:", num_labels
45 | # one_hot encode label
46 | one_hot = np.zeros((num_labels, num_labels), int)
47 | np.fill_diagonal(one_hot, 1)
48 | label_dict = dict(zip(original_labels, one_hot))
49 |
50 | tokenizer = Tokenizer()
51 | tokenizer.fit_on_texts(all_texts)
52 | sequences = tokenizer.texts_to_sequences(all_texts)
53 | word_index = tokenizer.word_index
54 | print "Found %s unique tokens." % len(word_index)
55 | data = pad_sequences(sequences, maxlen=MAX_SEQUENCE_LENGTH)
56 | labels = np.asarray([label_dict[x] for x in all_labels])
57 | print "shape of data tensor:", data.shape
58 | print "shape of label tensor:", labels.shape
59 |
60 | # split into training and validation sets
61 | p_train = int(len(data) * (1 - VALIDATION_SPLIT))
62 | x_train = data[:p_train]
63 | y_train = labels[:p_train]
64 | x_val = data[p_train:]
65 | y_val = labels[p_train:]
66 | print 'train docs:' + str(len(x_train))
67 | print 'validate docs:' + str(len(x_val))
68 |
69 | # model definitions
70 | def cnn_model(embedding_layer=None):
71 | model = Sequential()
72 | if embedding_layer:
73 | # embedding layer use pre_trained word2vec model
74 | model.add(embedding_layer)
75 | else:
76 | # random word vector
77 | model.add(Embedding(len(word_index)+1,EMBEDDING_DIM,input_length=MAX_SEQUENCE_LENGTH))
78 | model.add(Conv1D(128, 5, padding='valid', activation='relu'))
79 | model.add(MaxPooling1D(5))
80 | # model.add(Conv1D(128, 5, padding='valid', activation='relu'))
81 | # model.add(MaxPooling1D(5))
82 | # model.add(Conv1D(128, 5, padding='valid', activation='relu'))
83 | # model.add(MaxPooling1D(5))
84 | model.add(Flatten())
85 | return model
86 |
87 | def lstm_model(embedding_layer=None):
88 | model = Sequential()
89 | if embedding_layer:
90 | # embedding layer use pre_trained word2vec model
91 | model.add(embedding_layer)
92 | else:
93 | # random word vector
94 | model.add(Embedding(len(word_index)+1,EMBEDDING_DIM,input_length=MAX_SEQUENCE_LENGTH))
95 | # Native LSTM
96 | model.add(LSTM(200,dropout=0.2,recurrent_dropout=0.2))
97 | return model
98 |
99 | def gru_model(embedding_layer=None):
100 | model = Sequential()
101 | if embedding_layer:
102 | # embedding layer use pre_trained word2vec model
103 | model.add(embedding_layer)
104 | else:
105 | # random word vector
106 | model.add(Embedding(len(word_index)+1,EMBEDDING_DIM,input_length=MAX_SEQUENCE_LENGTH))
107 | # GRU
108 | model.add(GRU(200,dropout=0.2,recurrent_dropout=0.2))
109 | return model
110 |
111 | def bidirectional_lstm_model(embedding_layer=None):
112 | model = Sequential()
113 | if embedding_layer:
114 | # embedding layer use pre_trained word2vec model
115 | model.add(embedding_layer)
116 | else:
117 | # random word vector
118 | model.add(Embedding(len(word_index)+1,EMBEDDING_DIM,input_length=MAX_SEQUENCE_LENGTH))
119 |     # Bidirectional LSTM
120 |     model.add(Bidirectional(LSTM(200, dropout=0.2, recurrent_dropout=0.2)))
121 | return model
122 |
123 | def cnn_lstm_model(embedding_layer=None):
124 | model = Sequential()
125 | if embedding_layer:
126 | # embedding layer use pre_trained word2vec model
127 | model.add(embedding_layer)
128 | else:
129 | # random word vector
130 | model.add(Embedding(len(word_index)+1,EMBEDDING_DIM,input_length=MAX_SEQUENCE_LENGTH))
131 | model.add(Conv1D(128, 5, padding='valid', activation='relu'))
132 | model.add(MaxPooling1D(5))
133 | # model.add(Dropout(0.5))
134 | # model.add(GRU(128,dropout=0.2,recurrent_dropout=0.1,return_sequences = True))
135 | model.add(GRU(128,dropout=0.2,recurrent_dropout=0.1))
136 | return model
137 |
138 |
139 | # load word2vec model
140 | word2vec_model_file = "/home/zwei/workspace/nlp_study/word2vec_wiki_study/word2vec_train_wiki/WIKI_word2vec_model/word2vec_skip_ngram_200_10_5_model.bin"
141 | word2vec_model = gensim.models.KeyedVectors.load_word2vec_format(word2vec_model_file, binary=True)
142 | # print "test word 奥巴马:",word2vec_model['奥巴马'.decode('utf-8')]
143 |
144 | # build the embedding matrix: words found in the wiki word2vec model keep their pre-trained vector, missing words get a random vector (and are logged to no_word_file.txt)
145 | embedding_matrix = np.zeros((len(word_index)+1, EMBEDDING_DIM))
146 |
147 | no_word_file = open("no_word_file.txt",'w')
148 |
149 | for word, i in word_index.items():
150 | if word.decode('utf-8') in word2vec_model:
151 | embedding_matrix[i] = np.asarray(word2vec_model[word.decode('utf-8')], dtype='float32')
152 | else:
153 | # print "word not found in word2vec:", word
154 | no_word_file.write(word+"\n")
155 | embedding_matrix[i] = np.random.random(size=EMBEDDING_DIM)
156 |
157 | no_word_file.close()
158 |
159 | embedding_layer = Embedding(len(word_index)+1,
160 | EMBEDDING_DIM,
161 | weights=[embedding_matrix],
162 | input_length=MAX_SEQUENCE_LENGTH,
163 | trainable=False)
164 |
165 | # CNN + word2vec
166 | # model = cnn_model(embedding_layer)
167 |
168 | # CNN + random word vectors
169 | # model = cnn_model()
170 |
171 | # LSTM
172 | # model = lstm_model(embedding_layer)
173 |
174 | # GRU
175 | model = gru_model(embedding_layer)
176 |
177 | # CNN + LSTM
178 | # model = cnn_lstm_model(embedding_layer)
179 |
180 | # common classification head
181 | model.add(Dense(EMBEDDING_DIM, activation='relu'))
182 | model.add(Dropout(0.5))
183 | model.add(Dense(labels.shape[1], activation='softmax',kernel_regularizer=regularizers.l2(0.1)))
184 | model.summary()
185 |
186 | # adam = Adam(lr=0.01)
187 |
188 | # tensorboard
189 | tensorboard = keras.callbacks.TensorBoard(log_dir="lstm_word2vec_log_THUCNews_deal_title/{}".format(time()))
190 |
191 | # custom recall / precision metrics (defined here but not passed to model.compile below)
192 | from keras import backend as K
193 | def recall(y_true, y_pred):
194 | true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
195 | possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
196 | recall = true_positives / (possible_positives + K.epsilon())
197 | return recall
198 | def precision(y_true, y_pred):
199 | true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
200 | predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
201 | precision = true_positives / (predicted_positives + K.epsilon())
202 | return precision
203 |
204 |
205 | model.compile(loss='categorical_crossentropy', optimizer='Adadelta', metrics=['acc'])
206 | model.fit(x_train, y_train, epochs=30, batch_size=256,callbacks=[tensorboard], validation_split=0.1)
207 |
208 | # model.fit(x_train, y_train, epochs=20, batch_size=256, validation_data=(x_val, y_val))
209 | # score = model.evaluate(x_val, y_val, batch_size=128)
210 | # print('Test score:', score[0])
211 | # print('Test accuracy:', score[1])
212 |
213 | # # sklearn metrics
214 | # from sklearn.metrics import confusion_matrix
215 | # y_pred = model.predict(x_val,batch_size=64)
216 | # y_pred_label = [c.index(max(c)) for c in y_pred.tolist()]
217 | # y_true_label = [c.index(max(c)) for c in y_val.tolist()]
218 | # y_pred_label = [original_labels[i] for i in y_pred_label]
219 | # y_true_label = [original_labels[i] for i in y_true_label]
220 | # matrix = confusion_matrix(y_true_label, y_pred_label,original_labels)
221 | # # matplotlib
222 | # import matplotlib.pyplot as plt
223 | # plt.matshow(matrix)
224 | # plt.colorbar()
225 | # plt.xlabel('Prediction')
226 | # plt.ylabel('True')
227 | # plt.xticks(matrix[1],label_dict.keys())
228 | # plt.yticks(matrix[1],label_dict.keys())
229 | # # plt.show()
230 | # plt.savefig("confusion_matrix.jpg")
231 | #
232 | # # classification_report
233 | # from sklearn.metrics import classification_report
234 | # print "classification_report(left: labels):"
235 | # for key,value in label_dict.iteritems():
236 | # print "dict[%s]="%key,value
237 | # print classification_report(y_val, y_pred)
238 | #
239 | #
240 | # show model
241 | # from keras.utils import plot_model
242 | #
243 | # model.save_weights(model_filepath)
244 | # plot_model(model, to_file='model.png')
245 |
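246 | # A rough sketch (not part of the original training run) of how a new, already word-segmented
247 | # document could be classified with the objects built above; new_doc is a placeholder.
248 | # new_doc = u"..."  # space-separated words, e.g. the output of jieba segmentation
249 | # seq = tokenizer.texts_to_sequences([new_doc])
250 | # padded = pad_sequences(seq, maxlen=MAX_SEQUENCE_LENGTH)
251 | # pred = model.predict(padded)
252 | # print original_labels[np.argmax(pred[0])]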
--------------------------------------------------------------------------------
/clean_data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "# coding=utf-8\n",
12 | "import zipfile\n",
13 | "import os\n",
14 | "import sys\n",
15 | "import logging\n",
16 | "import re"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 3,
22 | "metadata": {
23 | "collapsed": true
24 | },
25 | "outputs": [],
26 | "source": [
27 | "split_type_content_mark = '|'\n",
28 | "base_data_dir = \"../../data/sogou_resource/\""
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 4,
34 | "metadata": {
35 | "collapsed": true
36 | },
37 | "outputs": [],
38 | "source": [
39 | "\"\"\"Convert the encoding of files under the given directory to utf-8\"\"\"\n",
40 | "def convert(filename,from_encode=\"GBK\",to_encode=\"UTF-8\"):\n",
41 | " try:\n",
42 | "# print \"convert \"+ filename\n",
43 | " content = open(filename,'rb').read()\n",
44 | "# print \"content:\"+content\n",
45 | " new_content = content.decode(from_encode).encode(to_encode)\n",
46 | "# print \"new_content:\"+new_content\n",
47 | " open(filename,\"w\").write(new_content)\n",
48 | " print \"done\"\n",
49 | " except:\n",
50 | " print \"error\"\n",
51 | "\n",
52 | "def explore(dir):\n",
53 | " for root,dirs,files in os.walk(dir):\n",
54 | " for file in files:\n",
55 | " path = os.path.join(root,file)\n",
56 | " convert(path)\n",
57 | "\n",
58 | "def do_convert(path):\n",
59 | " if os.path.isfile(path):\n",
60 | " convert(path)\n",
61 | " if os.path.isdir(path):\n",
62 | " explore(path)"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 5,
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "path = \"../../data/sogou_resource/\"\n",
72 | "do_convert(path)"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": 16,
78 | "metadata": {
79 | "collapsed": true
80 | },
81 | "outputs": [],
82 | "source": [
83 | "\"\"\"Chinese word segmentation with jieba\"\"\"\n",
84 | "import jieba\n",
85 | "import zhon\n",
86 | "\n",
87 | "separated_word_file_dir = \"word_separated\"\n",
88 | "# THUCNews (Tsinghua news corpus) categories\n",
89 | "types = [\"体育\",\"娱乐\",\"家居\",\"彩票\",\"房产\",\"教育\",\"时尚\",\"时政\",\"星座\",\"游戏\",\"社会\",\"科技\",\"股票\",\"财经\"]\n",
90 | "# types = [\"体育\"]\n",
91 | "\n",
92 | "\n",
93 | "# Sogou Lab news corpus categories\n",
94 | "# types = [\"体育\",\"健康\",\"军事\",\"招聘\",\"文化\",\"教育\",\"财经\",\"旅游\",\"互联网\"]\n",
95 | "\n",
96 | "def ch_and_en_word_extraction(content_raw):\n",
97 | "    \"\"\"Extract Chinese and English characters\"\"\"\n",
98 | " pattern = re.compile(u\"([\\u4e00-\\u9fa5a-zA-Z0-9]+)\")\n",
99 | " re_data = pattern.findall(content_raw)\n",
100 | " clean_content = ' '.join(re_data)\n",
101 | " return clean_content\n",
102 | "\n",
103 | "def clean_str(s):\n",
104 | "    s = s.strip()  # leading/trailing whitespace\n",
105 | "    s = s.strip('\\n')  # newlines\n",
106 | "    s = re.sub(\"[ \\t\\n\\r]*\",'',s)  # tab, newline, return\n",
107 | "    s = re.sub(\"<(\\S*?)[^>]*>.*?\\1>|<.*? />\",'',s)  # HTML tags\n",
108 | "    s = re.sub(\" +|<+|>+\",'',s)  # HTML space / greater-than / less-than leftovers\n",
109 | "    s = re.sub(\"[a-zA-z]+://[^\\s]*\",'',s)  # URLs\n",
110 | "    s = re.sub(r'([\\w-]+(\\.[\\w-]+)*@[\\w-]+(\\.[\\w-]+)+)','',s)  # email addresses\n",
111 | "    # punctuation: decode to unicode first, otherwise the symbols do not match\n",
112 | "    s = re.sub(ur\"([%s])+\" %zhon.hanzi.punctuation,\" \",s.decode('utf-8'))\n",
113 | "    # extract Chinese and English characters\n",
114 | " s = ch_and_en_word_extraction(s)\n",
115 | " return s\n",
116 | "\n",
117 | "def separate_words(infile,outfile):\n",
118 | " try:\n",
119 | " outf = open(outfile,'w')\n",
120 | " inf = open(infile,'r')\n",
121 | "\n",
122 | " space = ' '\n",
123 | "# print 'separate '+infile\n",
124 | " for line in inf.readlines():\n",
125 | "# line = clean_str(line)\n",
126 | "            # skip empty lines\n",
127 | " if not len(line.strip()):\n",
128 | " continue\n",
129 | " seg_list = jieba.cut(line)\n",
130 | "            \"\"\"Each word has to be encoded to UTF-8 in a loop here: jieba.cut returns unicode,\n",
131 | "            and writing space.join(seg_list) directly raises an encoding error\"\"\"\n",
132 | " for word in seg_list:\n",
133 | " try:\n",
134 | " word = word.encode('UTF-8')\n",
135 | " except:\n",
136 | " continue\n",
137 | " outf.write(word)\n",
138 | " outf.write(' ')\n",
139 | " outf.write('\\n')\n",
140 | " # close file stream\n",
141 | " outf.close()\n",
142 | " inf.close()\n",
143 | " except:\n",
144 | " pass\n",
145 | "# print \"error occured when write to \"+outfile\n",
146 | "\n",
147 | "\n",
148 | "\n",
149 | "def is_target_dir(path):\n",
150 | " if os.path.dirname(path).split(\"/\")[-1] in types and not re.match(\".DS_Store\",os.path.basename(path)):\n",
151 | " return True\n",
152 | " else:\n",
153 | " return False\n",
154 | " \n",
155 | "def explore(dir):\n",
156 | " for root,dirs,files in os.walk(dir):\n",
157 | " for file in files:\n",
158 | " path = os.path.join(root,file)\n",
159 | " if is_target_dir(path):\n",
160 | " child_dir = os.path.join(root,separated_word_file_dir)\n",
161 | " if not os.path.exists(child_dir):\n",
162 | " os.mkdir(child_dir)\n",
163 | " print \"make dir: \"+child_dir\n",
164 | " separate_words(path,os.path.join(child_dir,file))\n",
165 | " \n",
166 | "\n",
167 | "def do_batch_separate(path):\n",
168 | " if os.path.isfile(path) and is_target_dir(path):\n",
169 | "        separate_words(path,os.path.join(os.path.dirname(path),separated_word_file_dir,os.path.basename(path)))\n",
170 | " if os.path.isdir(path):\n",
171 | " explore(path)\n",
172 | " "
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": null,
178 | "metadata": {
179 | "scrolled": true
180 | },
181 | "outputs": [
182 | {
183 | "name": "stdout",
184 | "output_type": "stream",
185 | "text": [
186 | "begin time: 2017-11-11 00:32:09.642573\n",
187 | "make dir: ../../data/THUCNews/体育/word_separated\n",
188 | "make dir: ../../data/THUCNews/娱乐/word_separated\n"
189 | ]
190 | }
191 | ],
192 | "source": [
193 | "import time\n",
194 | "import datetime\n",
195 | "\n",
196 | "original_dir = \"../../data/THUCNews/\"\n",
197 | "now = datetime.datetime.now()\n",
198 | "print \"begin time:\",now\n",
199 | "begin_time = time.time()\n",
200 | "do_batch_separate(original_dir)\n",
201 | "end_time = time.time()\n",
202 | "now = datetime.datetime.now()\n",
203 | "print \"end time:\",now\n",
204 | "print \"time used:\"+str(end_time-begin_time)+\" seconds\""
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": 410,
210 | "metadata": {
211 | "collapsed": true
212 | },
213 | "outputs": [],
214 | "source": [
215 | "\"\"\"Combine all corpus files into csv-style files with the format: type,content\"\"\"\n",
216 | "def combine_file(file,outfile):\n",
217 | "    # derive the label from the file path, e.g. xxx/互联网/xxx/xxx.txt\n",
218 | " label = os.path.dirname(file).split('/')[-2]\n",
219 | " content = open(file).read()\n",
220 | "# print \"content:\"+content\n",
221 | "# print \"len:\",len(content)\n",
222 | "    if len(content)>1: # skip files where the previous step wrote nothing but a single space\n",
223 | " new_content = label+\",\"+content\n",
224 | "# print \"new_content:\\n \" + new_content\n",
225 | " open(outfile,\"a\").write(new_content)\n",
226 | "\n",
227 | "def do_combine(dir,outfile):\n",
228 | " print \"deal with dir: \"+ dir\n",
229 | " for root,dirs,files in os.walk(dir):\n",
230 | " for file in files:\n",
231 | " match = re.match(r'\\d+\\.txt',file)\n",
232 | " if match:\n",
233 | " path = os.path.join(root,file)\n",
234 | " print \"combine \"+ path\n",
235 | " combine_file(path,outfile)\n",
236 | " \n",
237 | "def create_csv_file(dir,filename):\n",
238 | " csv_title = \"type,content\\n\"\n",
239 | " filepath = os.path.join(dir,filename+'.csv')\n",
240 | " open(filepath,'w').write(csv_title)\n",
241 | " return filepath"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": 6,
247 | "metadata": {},
248 | "outputs": [],
249 | "source": [
250 | "base_dir = \"../../data/THUCNews/\"\n",
251 | "\"\"\"Create the directory for the processed dataset\"\"\"\n",
252 | "dataset_dir = os.path.join(base_dir,\"dataset\")\n",
253 | "if not os.path.exists(dataset_dir):\n",
254 | " os.mkdir(dataset_dir)\n",
255 | " \n",
256 | "\"\"\"Create one csv file per type and write all files under that type directory into it\"\"\"\n",
257 | "# THUCNews categories\n",
258 | "type_name_list = [\"体育\",\"娱乐\",\"家居\",\"彩票\",\"房产\",\"教育\",\"时尚\",\"时政\",\"星座\",\"游戏\",\"社会\",\"科技\",\"股票\",\"财经\"]\n",
259 | "\n",
260 | "# Sogou news corpus categories\n",
261 | "# type_name_list = [\"体育\",\"健康\",\"军事\",\"招聘\",\"文化\",\"教育\",\"财经\",\"旅游\",\"互联网\"]\n",
262 | "\n",
263 | "for name in type_name_list:\n",
264 | " path = create_csv_file(dataset_dir,name)\n",
265 | " print \"going to combine file to \" + path\n",
266 | " do_combine(os.path.join(base_dir,name,\"word_separated\"),path)"
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "execution_count": 413,
272 | "metadata": {
273 | "collapsed": true
274 | },
275 | "outputs": [],
276 | "source": [
277 | "\"\"\"Randomly sample about 20% of each category as the test set and the remaining 80% as the training set\"\"\"\n",
278 | "import random\n",
279 | " \n",
280 | "def extract_test_and_train_set(filepath,train_file,test_file):\n",
281 | " try:\n",
282 | " test_f = open(test_file,'a')\n",
283 | " train_f = open(train_file,'a')\n",
284 | " try:\n",
285 | " with open(filepath) as f:\n",
286 | " is_title_line = True\n",
287 | " for line in f.readlines():\n",
288 | " if is_title_line:\n",
289 | " is_title_line = False\n",
290 | " continue\n",
291 | " if not len(line):\n",
292 | " continue\n",
293 | " if random.random() <= 0.2:\n",
294 | " test_f.write(line)\n",
295 | " else:\n",
296 | " train_f.write(line)\n",
297 | " except:\n",
298 | " print \"IO ERROR\"\n",
299 | " finally:\n",
300 | " test_f.close()\n",
301 | " train_f.close()\n",
302 | " except:\n",
303 | " print \"can not open file\"\n",
304 | "\n",
305 | "def do_extract(source_dir,train_f,test_f):\n",
306 | " for root,dirs,files in os.walk(source_dir):\n",
307 | " for file in files:\n",
308 | " if re.match(\"test|train\\.csv\",file) or not re.match(\".*\\.csv\",file):\n",
309 | " continue\n",
310 | " path = os.path.join(root,file)\n",
311 | " print \"extract file: \"+ path\n",
312 | " extract_test_and_train_set(path,train_f,test_f)"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": 414,
318 | "metadata": {},
319 | "outputs": [
320 | {
321 | "name": "stdout",
322 | "output_type": "stream",
323 | "text": [
324 | "create file: ../../data/sogou_resource/dataset/train.csv\n",
325 | "create file:../../data/sogou_resource/dataset/test.csv\n",
326 | "extract file: ../../data/sogou_resource/dataset/互联网.csv\n",
327 | "extract file: ../../data/sogou_resource/dataset/体育.csv\n",
328 | "extract file: ../../data/sogou_resource/dataset/健康.csv\n",
329 | "extract file: ../../data/sogou_resource/dataset/军事.csv\n",
330 | "extract file: ../../data/sogou_resource/dataset/招聘.csv\n",
331 | "extract file: ../../data/sogou_resource/dataset/教育.csv\n",
332 | "extract file: ../../data/sogou_resource/dataset/文化.csv\n",
333 | "extract file: ../../data/sogou_resource/dataset/旅游.csv\n",
334 | "extract file: ../../data/sogou_resource/dataset/财经.csv\n"
335 | ]
336 | }
337 | ],
338 | "source": [
339 | "# do extract\n",
340 | "dataset_dir = \"../../data/sogou_resource/dataset/\"\n",
341 | "train_dataset = os.path.join(dataset_dir,\"train.csv\")\n",
342 | "test_dataset = os.path.join(dataset_dir,\"test.csv\")\n",
343 | "if not os.path.exists(train_dataset):\n",
344 | " print \"create file: \"+train_dataset\n",
345 | " open(train_dataset,'w').write(\"type,content\\n\")\n",
346 | "if not os.path.exists(test_dataset):\n",
347 | " print \"create file:\"+test_dataset\n",
348 | " open(test_dataset,'w').write(\"type,content\\n\")\n",
349 | " \n",
350 | "do_extract(dataset_dir,train_dataset,test_dataset)"
351 | ]
352 | },
353 | {
354 | "cell_type": "code",
355 | "execution_count": 76,
356 | "metadata": {
357 | "collapsed": true
358 | },
359 | "outputs": [],
360 | "source": [
361 | "# -*- encoding: utf-8 -*-\n",
362 | "\n",
363 | "# extract named entities, nouns and verbs\n",
364 | "import requests\n",
365 | "from bosonnlp import BosonNLP\n",
366 | "\n",
367 | "Token = 'alM_aH0F.11177.NKoYB68fl19N'\n",
368 | "noun = ['n','nr','ns','nt','nz','nl','vd','vi','vl','nx']\n",
369 | "nlp = BosonNLP(Token)\n",
370 | "\n",
371 | "def extract_entity_and_nouns(content_raw):\n",
372 | " result = nlp.ner(content_raw)[0]\n",
373 | " words = result['word']\n",
374 | " entities = result['entity']\n",
375 | " tags = result['tag']\n",
376 | " entity_list = [words[it[0]:it[1]] for it in entities]\n",
377 | " con_entity = reduce(lambda x,y:x+y,entity_list)\n",
378 | " con_noun = [it[0] for it in zip(words,tags) if it[1] in noun]\n",
379 | " entity_noun_union = set(con_entity).union(set(con_noun))\n",
380 | " \n",
381 | " content = [word for word in [s.strip() for s in content_raw.split(' ')] if word in entity_noun_union]\n",
382 | " return content\n",
383 | "\n",
384 | "def extract_nouns(content_raw):\n",
385 | " result = nlp.tag(content_raw,space_mode=1)[0]\n",
386 | " content = [d[0] for d in zip(result['word'],result['tag']) if d[1] in noun]\n",
387 | "    # convert to utf-8 encoding\n",
388 | " content = map(lambda x:x.encode('utf-8'),content)\n",
389 | " return content\n",
390 | "\n",
391 | "def do_extract_entity_and_nouns(content_file,newfile):\n",
392 | " print \"do_extract_entity_and_nouns in \"+content_file\n",
393 | " newf = open(newfile,'w')\n",
394 | " istitle = True\n",
395 | " with open(content_file) as f:\n",
396 | " for line in f.readlines():\n",
397 | " if istitle:\n",
398 | " istitle = False\n",
399 | " newf.write(line)\n",
400 | " continue\n",
401 | " type_content = line.split(\",\")\n",
402 | " content_raw = type_content[1]\n",
403 | " new_cont = extract_nouns(content_raw)\n",
404 | " new_line = type_content[0]+','+' '.join(new_cont)+'\\n'\n",
405 | " print new_line\n",
406 | " newf.write(new_line)\n",
407 | " newf.close()"
408 | ]
409 | },
410 | {
411 | "cell_type": "code",
412 | "execution_count": 7,
413 | "metadata": {},
414 | "outputs": [],
415 | "source": [
416 | "test_file = \"../../data/sogou_resource/dataset/test.csv\"\n",
417 | "train_file = \"../../data/sogou_resource/dataset/train.csv\"\n",
418 | "new_test_file = \"../../data/sogou_resource/dataset/extract_test.csv\"\n",
419 | "new_train_file = \"../../data/sogou_resource/dataset/extract_train.csv\"\n",
420 | "do_extract_entity_and_nouns(test_file,new_test_file)\n",
421 | "do_extract_entity_and_nouns(train_file,new_train_file)\n"
422 | ]
423 | },
424 | {
425 | "cell_type": "code",
426 | "execution_count": 78,
427 | "metadata": {
428 | "collapsed": true
429 | },
430 | "outputs": [],
431 | "source": [
432 | "\"\"\"Clean the data: remove stop words and drop bad samples\"\"\"\n",
433 | "def clean_stopwords(content_raw,stopwords_set):\n",
434 | " content_list = [x for x in content_raw.split(\" \") if x!='']\n",
435 | " common_set = set(content_list) & stopwords_set\n",
436 | " new_content = filter(lambda x:x not in common_set,content_list)\n",
437 | " return new_content\n",
438 | "\n",
439 | "def do_clean_stopwords(content_file,stopwords_file,newfile):\n",
440 | " print \"clean stopwords in \"+content_file\n",
441 | " stopwords = []\n",
442 | "    # load the stop words\n",
443 | " with open(stopwords_file) as fi:\n",
444 | " for line in fi.readlines():\n",
445 | " stopwords.append(line.strip())\n",
446 | " newf = open(newfile,'w')\n",
447 | " with open(content_file) as f:\n",
448 | " for line in f.readlines():\n",
449 | " type_content = line.split(split_type_content_mark)\n",
450 | " content_raw = type_content[1]\n",
451 | " new_cont = clean_stopwords(content_raw,set(stopwords))\n",
452 | " new_line = type_content[0]+ split_type_content_mark +' '.join(new_cont)\n",
453 | " newf.write(new_line)\n",
454 | " newf.close()\n",
455 | " "
456 | ]
457 | },
458 | {
459 | "cell_type": "code",
460 | "execution_count": 79,
461 | "metadata": {},
462 | "outputs": [
463 | {
464 | "name": "stdout",
465 | "output_type": "stream",
466 | "text": [
467 | "clean stopwords in ../../data/sogou_resource/dataset/extract_test.csv\n",
468 | "clean stopwords in ../../data/sogou_resource/dataset/extract_train.csv\n"
469 | ]
470 | }
471 | ],
472 | "source": [
473 | "test_file = \"../../data/sogou_resource/dataset/extract_test.csv\"\n",
474 | "train_file = \"../../data/sogou_resource/dataset/extract_train.csv\"\n",
475 | "new_test_file = \"../../data/sogou_resource/dataset/cleaned_extract_test.csv\"\n",
476 | "new_train_file = \"../../data/sogou_resource/dataset/cleaned_extract_train.csv\"\n",
477 | "stop_words_file = \"../../data/sogou_resource/dataset/news.stopwords.txt\"\n",
478 | "do_clean_stopwords(test_file,stop_words_file,new_test_file)\n",
479 | "do_clean_stopwords(train_file,stop_words_file,new_train_file)"
480 | ]
481 | },
482 | {
483 | "cell_type": "code",
484 | "execution_count": 61,
485 | "metadata": {
486 | "collapsed": true
487 | },
488 | "outputs": [],
489 | "source": [
490 | "\"\"\"Split the text into sentences\"\"\"\n",
491 | "import zhon.hanzi\n",
492 | "import re\n",
493 | "\n",
494 | "\n",
495 | "def do_sentence_doc(doc_file,new_file):\n",
496 | " print \"to sentence file:\",doc_file\n",
497 | " newf = open(new_file,'w')\n",
498 | " with open(doc_file) as f:\n",
499 | " for line in f.readlines():\n",
500 | " type_content = line.split(split_type_content_mark)\n",
501 | " label = type_content[0]\n",
502 | " content_raw = type_content[1]\n",
503 | " new_content = re.sub(ur\"([%s])+\" %zhon.hanzi.non_stops,\" \",content_raw.decode('utf-8'))\n",
504 | "            new_content = re.sub(ur\"([%s])+\" %zhon.hanzi.stops,\" \",new_content)\n",
505 | " new_type_content = label+split_type_content_mark+new_content.encode('utf-8')\n",
506 | " newf.write(new_type_content)\n",
507 | " newf.close()\n",
508 | " "
509 | ]
510 | },
511 | {
512 | "cell_type": "code",
513 | "execution_count": 62,
514 | "metadata": {},
515 | "outputs": [
516 | {
517 | "name": "stdout",
518 | "output_type": "stream",
519 | "text": [
520 | "to sentence file: ../../data/sogou_resource/dataset/test_sentence_file.csv\n"
521 | ]
522 | }
523 | ],
524 | "source": [
525 | "test_file = base_data_dir + \"dataset/test_sentence_file.csv\"\n",
526 | "new_file = base_data_dir + \"dataset/sentence_file.csv\"\n",
527 | "do_sentence_doc(test_file,new_file)\n"
528 | ]
529 | },
530 | {
531 | "cell_type": "code",
532 | "execution_count": null,
533 | "metadata": {
534 | "collapsed": true
535 | },
536 | "outputs": [],
537 | "source": []
538 | }
539 | ],
540 | "metadata": {
541 | "kernelspec": {
542 | "display_name": "Python 2",
543 | "language": "python",
544 | "name": "python2"
545 | },
546 | "language_info": {
547 | "codemirror_mode": {
548 | "name": "ipython",
549 | "version": 2
550 | },
551 | "file_extension": ".py",
552 | "mimetype": "text/x-python",
553 | "name": "python",
554 | "nbconvert_exporter": "python",
555 | "pygments_lexer": "ipython2",
556 | "version": "2.7.10"
557 | }
558 | },
559 | "nbformat": 4,
560 | "nbformat_minor": 2
561 | }
562 |
--------------------------------------------------------------------------------