├── preprocessed_data.npz
├── 用户意图分类APP
│   ├── index_2_label_dict.pkl
│   ├── word_2_index_dict.pkl
│   ├── model_structure_json.pkl
│   ├── SMP2018_GlobalAveragePooling1D_model(F1_86).h5
│   ├── app.py
│   ├── APP说明和使用APP.ipynb
│   └── SMP应用详细代码.ipynb
├── README.md
├── SMP2018_EDA_and_Baseline_Model(Keras).ipynb
└── SMP2018_EDA_and_Baseline_Model.ipynb
/preprocessed_data.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/SMP2018/HEAD/preprocessed_data.npz
--------------------------------------------------------------------------------
/用户意图分类APP/index_2_label_dict.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/SMP2018/HEAD/用户意图分类APP/index_2_label_dict.pkl
--------------------------------------------------------------------------------
/用户意图分类APP/word_2_index_dict.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/SMP2018/HEAD/用户意图分类APP/word_2_index_dict.pkl
--------------------------------------------------------------------------------
/用户意图分类APP/model_structure_json.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/SMP2018/HEAD/用户意图分类APP/model_structure_json.pkl
--------------------------------------------------------------------------------
/用户意图分类APP/SMP2018_GlobalAveragePooling1D_model(F1_86).h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/SMP2018/HEAD/用户意图分类APP/SMP2018_GlobalAveragePooling1D_model(F1_86).h5
--------------------------------------------------------------------------------
/用户意图分类APP/app.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from keras.models import model_from_json
4 | from keras.preprocessing.sequence import pad_sequences
5 | import jieba
6 | import pickle
7 |
8 | # Helper for loading a pickled object
9 | def load_obj(name):
10 |     with open(name + '.pkl', 'rb') as f:
11 |         return pickle.load(f)
12 |
13 | # Final per-query sequence length fed to the model
14 | max_cut_query_lenth = 26
15 |
16 | # Load the dict mapping query words to IDs
17 | word_2_index_dict = load_obj('word_2_index_dict')
18 | # Load the dict mapping model output IDs to labels (classes)
19 | index_2_label_dict = load_obj('index_2_label_dict')
20 | # Load the model structure
21 | model_structure_json = load_obj('model_structure_json')
22 | model = model_from_json(model_structure_json)
23 | # Load the model weights
24 | model.load_weights('SMP2018_GlobalAveragePooling1D_model(F1_86).h5')
25 |
26 | def query_2_label(query_sentence):
27 |     '''
28 |     input query: "从中山到西安的汽车。"
29 |     return label: "bus"
30 |     '''
31 |     x_input = []
32 |     # Segment the query with jieba, e.g. ['从', '中山', '到', '西安', '的', '汽车', '。']
33 |     query_sentence_list = list(jieba.cut(query_sentence))
34 |     # Map each segmented word to its index (out-of-vocabulary words map to 0)
35 |     x = [word_2_index_dict.get(w, 0) for w in query_sentence_list]
36 |     # Pad with zeros on the left to the fixed length max_cut_query_lenth,
37 |     # giving an int array of shape (1, 26)
38 |     x_input.append(x)
39 |     x_input = pad_sequences(x_input, maxlen=max_cut_query_lenth)
40 |     # Predict class probabilities
41 |     y_hat = model.predict(x_input)
42 |     # Index of the highest-probability class
43 |     pred_y_index = np.argmax(y_hat)
44 |     # Look up the label (class) for that index
45 |     label = index_2_label_dict[pred_y_index]
46 |     return label
47 |
48 | if __name__ == "__main__":
49 |     query_sentence = '狐臭怎么治?'
50 |     print(query_2_label(query_sentence))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SMP2018
2 | > SMP2018 as a worked example of the general approach to Chinese text classification, in particular the [combined use of Keras and the Chinese word segmenter jieba](SMP2018_EDA_and_Baseline_Model(Keras).ipynb)
3 |
4 | The SMP2018 Chinese Human-Computer Dialogue Technology Evaluation was organized by the Social Media Processing Technical Committee of the Chinese Information Processing Society of China and hosted by Harbin Institute of Technology and iFLYTEK Co., Ltd., with iFLYTEK providing the data and Huawei providing the prize money. It aims to advance research on Chinese human-computer dialogue systems and to offer a forum for academic researchers and industry practitioners in dialogue technology. The organizing committee cordially invited all interested institutions to take part in the evaluation!
5 |
6 | # User Intent Domain Classification
7 |
8 | When a human-computer dialogue system is in use, a user may express many kinds of intent, each triggering a different domain in the system: task-oriented vertical domains (such as querying flights, hotels, or buses), knowledge-based question answering, chit-chat, and so on. A key task for a dialogue system is therefore to classify the user's input into the correct domain, since only then can it return the right kind of response. The examples below, and the usage sketch after them, illustrate this.
9 |
10 | **Examples**
11 |
12 | 1) 你好啊,很高兴见到你! ("Hi, nice to meet you!") — chit-chat
13 |
14 | 2) 我想订一张去北京的机票。 ("I'd like to book a flight to Beijing.") — task-oriented vertical (flight booking)
15 |
16 | 3) 我想找一家五道口附近便宜干净的快捷酒店 ("I'm looking for a cheap, clean budget hotel near Wudaokou") — task-oriented vertical (hotel booking)
17 |
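The trained baseline is also packaged as a small application: [app.py](用户意图分类APP/app.py) loads the saved model and exposes a single function, `query_2_label`, which maps a raw Chinese query to one of the 31 domain labels. A minimal usage sketch (run from inside the `用户意图分类APP` directory, where the model and dictionary files live):

```python
from app import query_2_label

print(query_2_label('我喜欢你'))       # -> 'chat' in the demo notebook
print(query_2_label('怎么治疗感冒?'))  # -> 'health' in the demo notebook
```
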
18 | ## Related resources
19 |
20 | |Title|Notes|
21 | |-|-|
22 | |[CodaLab evaluation homepage](https://worksheets.codalab.org/worksheets/0x27203f932f8341b79841d50ce0fd684f/)|[Data download](https://worksheets.codalab.org/worksheets/0x27203f932f8341b79841d50ce0fd684f/#)|
23 | |[CodaLab evaluation tutorial](https://worksheets.codalab.org/worksheets/0x1a7d7d33243c476984ff3d151c4977d4/)|20181010|
24 | |[Evaluation leaderboard](https://smp2018ecdt.github.io/Leader-board/)||
25 | |[SMP2018-ECDT evaluation homepage](http://smp2018.cips-smp.org/ecdt_index.html)||
26 | |[SMP2018-ECDT results announcement](https://mp.weixin.qq.com/s/_VHEuXzR7oXRTo5loqJp8A)||
27 |
28 |
29 | # [SMP2018 Chinese Human-Computer Dialogue Technology Evaluation (ECDT)](http://smp2018.cips-smp.org/ecdt_index.html)
30 |
31 | 1. This repository contains a complete experiment for the [SMP2018 Chinese Human-Computer Dialogue Technology Evaluation (ECDT)](http://smp2018.cips-smp.org/ecdt_index.html); the baseline model trained here reaches top-three level on the evaluation leaderboard.
32 | 2. Working through the experiment teaches the general workflow for processing natural-language text data.
33 | 3. You are encouraged to modify the notebooks to get better results, for example by changing the hyperparameters below (a sketch of where they plug into the model follows the code block):
34 |
35 | ```python
36 | # Word-embedding dimension
37 | embedding_word_dims = 32
38 | # Batch size
39 | batch_size = 30
40 | # Number of epochs
41 | epochs = 20
42 | ```
43 |
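For orientation, the sketch below condenses the notebooks' `create_SMP2018_lstm_model` to show where these three hyperparameters enter (`build_baseline` is a name used here for illustration; in the notebooks the embedding dimension is hard-coded as `output_dim=32`):

```python
from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense

embedding_word_dims = 32  # word-embedding dimension
batch_size = 30           # batch size
epochs = 20               # number of epochs

def build_baseline(max_features, max_cut_query_lenth, label_numbers):
    model = Sequential()
    # The embedding dimension sets the width of each word vector
    model.add(Embedding(input_dim=max_features + 1,
                        output_dim=embedding_word_dims,
                        input_length=max_cut_query_lenth))
    model.add(LSTM(units=64, dropout=0.2, recurrent_dropout=0.2))
    model.add(Dense(label_numbers, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    return model

# batch_size and epochs are used when fitting:
# model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs)
```
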
44 | ## Ways this experiment could be improved
45 |
46 | 1. Use a different word-segmentation tool in preprocessing
47 | 2. Combine character vectors with word vectors
48 | 3. Use pre-trained word vectors (a sketch follows this list)
49 | 4. Change the model architecture
50 | 5. Change the model hyperparameters
51 |
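A minimal sketch of improvement 3, assuming a pre-trained lookup `pretrained_vectors` (a hypothetical `dict` mapping each word to a `numpy` vector of length `embedding_word_dims`) plus the fitted `tokenizer`, `max_features`, and `max_cut_query_lenth` from the notebooks:

```python
import numpy as np
from keras.layers import Embedding

# Weight matrix aligned with the tokenizer's word indices;
# words without a pre-trained vector keep an all-zero row.
embedding_matrix = np.zeros((max_features + 1, embedding_word_dims))
for word, index in tokenizer.word_index.items():
    vector = pretrained_vectors.get(word)
    if vector is not None:
        embedding_matrix[index] = vector

embedding_layer = Embedding(input_dim=max_features + 1,
                            output_dim=embedding_word_dims,
                            weights=[embedding_matrix],
                            input_length=max_cut_query_lenth,
                            trainable=False)  # freeze, or True to fine-tune
```
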
52 | ## Resource overview
53 |
54 | + [SMP2018_EDA_and_Baseline_Model.ipynb](SMP2018_EDA_and_Baseline_Model.ipynb) contains the complete data-analysis and model-building code
55 | + [app.py](用户意图分类APP/app.py) is a user-intent classification app built on the trained model
56 | + The remaining resources are self-explanatory from their names
57 |
--------------------------------------------------------------------------------
/用户意图分类APP/APP说明和使用APP.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 用户意图领域分类\n",
8 | " 在人机对话系统的应用过程中,用户可能会有多种意图,相应地会触发人机对话系统中的多个领域(domain) ,其中包括任务型垂直领域(如查询机票、酒店、公交车等)、知识型问答以及闲聊等。因而,人机对话系统的一个关键任务就是正确地将用户的输入分类到相应的领域(domain)中,从而才能返回正确的回复结果。"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "metadata": {},
14 | "source": [
15 | "## 分类的类别说明"
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {},
21 | "source": [
22 | "+ 包含闲聊和垂类两大类,其中垂类又细分为30个垂直领域。\n",
23 | "+ 本次评测任务1中,仅考虑针对单轮对话用户意图的领域分类,多轮对话整体意图的领域分类不在此次评测范围之内。"
24 | ]
25 | },
26 | {
27 | "cell_type": "raw",
28 | "metadata": {},
29 | "source": [
30 | "类别 = ['website', 'tvchannel', 'lottery', 'chat', 'match',\n",
31 | " 'datetime', 'weather', 'bus', 'novel', 'video', 'riddle',\n",
32 | " 'calc', 'telephone', 'health', 'contacts', 'epg', 'app', 'music',\n",
33 | " 'cookbook', 'stock', 'map', 'message', 'poetry', 'cinemas', 'news',\n",
34 | " 'flight', 'translation', 'train', 'schedule', 'radio', 'email']"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {},
40 | "source": [
41 | "# 开始使用"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 1,
47 | "metadata": {},
48 | "outputs": [
49 | {
50 | "name": "stderr",
51 | "output_type": "stream",
52 | "text": [
53 | "Using TensorFlow backend.\n"
54 | ]
55 | }
56 | ],
57 | "source": [
58 | "from app import query_2_label"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 2,
64 | "metadata": {},
65 | "outputs": [
66 | {
67 | "name": "stderr",
68 | "output_type": "stream",
69 | "text": [
70 | "Building prefix dict from the default dictionary ...\n",
71 | "Loading model from cache /tmp/jieba.cache\n",
72 | "Loading model cost 0.945 seconds.\n",
73 | "Prefix dict has been built succesfully.\n"
74 | ]
75 | },
76 | {
77 | "data": {
78 | "text/plain": [
79 | "'chat'"
80 | ]
81 | },
82 | "execution_count": 2,
83 | "metadata": {},
84 | "output_type": "execute_result"
85 | }
86 | ],
87 | "source": [
88 | "query_2_label('我喜欢你')"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "metadata": {},
94 | "source": [
95 | "# 运行下面代码进行查询,输入 0 结束查询"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "metadata": {},
102 | "outputs": [
103 | {
104 | "name": "stdin",
105 | "output_type": "stream",
106 | "text": [
107 | " 今天东莞天气如何\n"
108 | ]
109 | },
110 | {
111 | "name": "stdout",
112 | "output_type": "stream",
113 | "text": [
114 | "----------\n",
115 | "predict label:\t datetime\n",
116 | "----------\n"
117 | ]
118 | },
119 | {
120 | "name": "stdin",
121 | "output_type": "stream",
122 | "text": [
123 | " 怎么治疗感冒?\n"
124 | ]
125 | },
126 | {
127 | "name": "stdout",
128 | "output_type": "stream",
129 | "text": [
130 | "----------\n",
131 | "predict label:\t health\n",
132 | "----------\n"
133 | ]
134 | },
135 | {
136 | "name": "stdin",
137 | "output_type": "stream",
138 | "text": [
139 | " 你好?\n"
140 | ]
141 | },
142 | {
143 | "name": "stdout",
144 | "output_type": "stream",
145 | "text": [
146 | "----------\n",
147 | "predict label:\t chat\n",
148 | "----------\n"
149 | ]
150 | }
151 | ],
152 | "source": [
153 | "while True:\n",
154 | " your_query_sentence = input()\n",
155 | " print('-'*10)\n",
156 | " label = query_2_label(your_query_sentence)\n",
157 | " print('predict label:\\t', label)\n",
158 | " print('-'*10)\n",
159 | " if your_query_sentence=='0':\n",
160 | " break"
161 | ]
162 | }
163 | ],
164 | "metadata": {
165 | "kernelspec": {
166 | "display_name": "Python 3",
167 | "language": "python",
168 | "name": "python3"
169 | },
170 | "language_info": {
171 | "codemirror_mode": {
172 | "name": "ipython",
173 | "version": 3
174 | },
175 | "file_extension": ".py",
176 | "mimetype": "text/x-python",
177 | "name": "python",
178 | "nbconvert_exporter": "python",
179 | "pygments_lexer": "ipython3",
180 | "version": "3.6.5"
181 | }
182 | },
183 | "nbformat": 4,
184 | "nbformat_minor": 2
185 | }
186 |
--------------------------------------------------------------------------------
/用户意图分类APP/SMP应用详细代码.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 导入相关库"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 37,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import numpy as np\n",
17 | "import pandas as pd\n",
18 | "from keras.models import model_from_json\n",
19 | "from keras.preprocessing.sequence import pad_sequences\n",
20 | "import jieba\n",
21 | "import pickle"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "# 加载模型 SMP2018_model(F1_86)"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 38,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "# 加载 pickle 对象的函数\n",
38 | "def load_obj(name ):\n",
39 | " with open(name + '.pkl', 'rb') as f:\n",
40 | " return pickle.load(f)\n",
41 | " \n",
42 | "# 输入模型的最终单句长度\n",
43 | "max_cut_query_lenth = 26\n",
44 | "\n",
45 | "# 加载查询词汇和对应 ID 的字典\n",
46 | "word_2_index_dict = load_obj('word_2_index_dict')\n",
47 | "# 加载模型输出 ID 和对应标签(种类)的字典\n",
48 | "index_2_label_dict = load_obj('index_2_label_dict')\n",
49 | "# 加载模型结构\n",
50 | "model_structure_json = load_obj('model_structure_json')\n",
51 | "model = model_from_json(model_structure_json)\n",
52 | "# 加载模型权重\n",
53 | "model.load_weights('SMP2018_GlobalAveragePooling1D_model(F1_86).h5')"
54 | ]
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "metadata": {},
59 | "source": [
60 | "# 使用模型的函数"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 39,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "def query_2_label(query_sentence):\n",
70 | " '''\n",
71 | " input query: \"从中山到西安的汽车。\"\n",
72 | " return label: \"bus\"\n",
73 | " '''\n",
74 | " x_input = []\n",
75 | " # 分词 ['从', '中山', '到', '西安', '的', '汽车', '。']\n",
76 | " query_sentence_list = list(jieba.cut(query_sentence))\n",
77 | " # 序列化 [54, 717, 0, 8, 0, 0, 1, 0, 183, 2]\n",
78 | " x = [word_2_index_dict.get(w, 0) for w in query_sentence]\n",
79 | " # 填充 array([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
80 | " # 0, 0, 0, 54, 717, 0, 8, 0, 0, 1, 0, 183, 2]], dtype=int32)\n",
81 | " x_input.append(x)\n",
82 | " x_input = pad_sequences(x_input, maxlen=max_cut_query_lenth)\n",
83 | " # 预测\n",
84 | " y_hat = model.predict(x_input)\n",
85 | " # 取最大值所在的序号 11\n",
86 | " pred_y_index = np.argmax(y_hat)\n",
87 | " # 查找序号所对应标签(类别)\n",
88 | " label = index_2_label_dict[pred_y_index]\n",
89 | " return label"
90 | ]
91 | },
92 | {
93 | "cell_type": "markdown",
94 | "metadata": {},
95 | "source": [
96 | "# 使用例子"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 49,
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "query_sentence = '狐臭怎么治?'\n",
106 | "\n",
107 | "print(query_2_label(query_sentence))"
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "metadata": {},
113 | "source": [
114 | "# 对 2299 条数据进行预测演示"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "metadata": {},
120 | "source": [
121 | "## 获取数据"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 51,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "def get_json_data(path):\n",
131 | " # read data\n",
132 | " data_df = pd.read_json(path)\n",
133 | " # change row and colunm\n",
134 | " data_df = data_df.transpose()\n",
135 | " # change colunm order\n",
136 | " data_df = data_df[['query', 'label']]\n",
137 | " return data_df"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 52,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "data_df = get_json_data(path=\"../data/train.json\")"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": 53,
152 | "metadata": {},
153 | "outputs": [
154 | {
155 | "data": {
156 | "text/html": [
157 | "
\n",
158 | "\n",
171 | "
\n",
172 | " \n",
173 | " \n",
174 | " | \n",
175 | " query | \n",
176 | " label | \n",
177 | "
\n",
178 | " \n",
179 | " \n",
180 | " \n",
181 | " | count | \n",
182 | " 2299 | \n",
183 | " 2299 | \n",
184 | "
\n",
185 | " \n",
186 | " | unique | \n",
187 | " 2299 | \n",
188 | " 31 | \n",
189 | "
\n",
190 | " \n",
191 | " | top | \n",
192 | " 还是想知道你能做些什么 | \n",
193 | " chat | \n",
194 | "
\n",
195 | " \n",
196 | " | freq | \n",
197 | " 1 | \n",
198 | " 455 | \n",
199 | "
\n",
200 | " \n",
201 | "
\n",
202 | "
"
203 | ],
204 | "text/plain": [
205 | " query label\n",
206 | "count 2299 2299\n",
207 | "unique 2299 31\n",
208 | "top 还是想知道你能做些什么 chat\n",
209 | "freq 1 455"
210 | ]
211 | },
212 | "execution_count": 53,
213 | "metadata": {},
214 | "output_type": "execute_result"
215 | }
216 | ],
217 | "source": [
218 | "data_df.describe()"
219 | ]
220 | },
221 | {
222 | "cell_type": "markdown",
223 | "metadata": {},
224 | "source": [
225 | "## 查看前 10 条数据 "
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": 54,
231 | "metadata": {},
232 | "outputs": [
233 | {
234 | "data": {
235 | "text/html": [
236 | "\n",
237 | "\n",
250 | "
\n",
251 | " \n",
252 | " \n",
253 | " | \n",
254 | " query | \n",
255 | " label | \n",
256 | "
\n",
257 | " \n",
258 | " \n",
259 | " \n",
260 | " | 0 | \n",
261 | " 今天东莞天气如何 | \n",
262 | " weather | \n",
263 | "
\n",
264 | " \n",
265 | " | 1 | \n",
266 | " 从观音桥到重庆市图书馆怎么走 | \n",
267 | " map | \n",
268 | "
\n",
269 | " \n",
270 | " | 2 | \n",
271 | " 鸭蛋怎么腌? | \n",
272 | " cookbook | \n",
273 | "
\n",
274 | " \n",
275 | " | 3 | \n",
276 | " 怎么治疗牛皮癣 | \n",
277 | " health | \n",
278 | "
\n",
279 | " \n",
280 | " | 4 | \n",
281 | " 唠什么 | \n",
282 | " chat | \n",
283 | "
\n",
284 | " \n",
285 | " | 5 | \n",
286 | " 阳澄湖大闸蟹的做法。 | \n",
287 | " cookbook | \n",
288 | "
\n",
289 | " \n",
290 | " | 6 | \n",
291 | " 昆山大润发在哪里 | \n",
292 | " map | \n",
293 | "
\n",
294 | " \n",
295 | " | 7 | \n",
296 | " 红烧肉怎么做?嗯? | \n",
297 | " cookbook | \n",
298 | "
\n",
299 | " \n",
300 | " | 8 | \n",
301 | " 南京到厦门的火车票 | \n",
302 | " train | \n",
303 | "
\n",
304 | " \n",
305 | " | 9 | \n",
306 | " 6的平方 | \n",
307 | " calc | \n",
308 | "
\n",
309 | " \n",
310 | "
\n",
311 | "
"
312 | ],
313 | "text/plain": [
314 | " query label\n",
315 | "0 今天东莞天气如何 weather\n",
316 | "1 从观音桥到重庆市图书馆怎么走 map\n",
317 | "2 鸭蛋怎么腌? cookbook\n",
318 | "3 怎么治疗牛皮癣 health\n",
319 | "4 唠什么 chat\n",
320 | "5 阳澄湖大闸蟹的做法。 cookbook\n",
321 | "6 昆山大润发在哪里 map\n",
322 | "7 红烧肉怎么做?嗯? cookbook\n",
323 | "8 南京到厦门的火车票 train\n",
324 | "9 6的平方 calc"
325 | ]
326 | },
327 | "execution_count": 54,
328 | "metadata": {},
329 | "output_type": "execute_result"
330 | }
331 | ],
332 | "source": [
333 | "data_df.head(10)"
334 | ]
335 | },
336 | {
337 | "cell_type": "markdown",
338 | "metadata": {},
339 | "source": [
340 | "## 模型预测,并查看前 10 条数据"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": 55,
346 | "metadata": {},
347 | "outputs": [
348 | {
349 | "data": {
350 | "text/html": [
351 | "\n",
352 | "\n",
365 | "
\n",
366 | " \n",
367 | " \n",
368 | " | \n",
369 | " query | \n",
370 | " label | \n",
371 | " model_prediction_label | \n",
372 | "
\n",
373 | " \n",
374 | " \n",
375 | " \n",
376 | " | 0 | \n",
377 | " 今天东莞天气如何 | \n",
378 | " weather | \n",
379 | " datetime | \n",
380 | "
\n",
381 | " \n",
382 | " | 1 | \n",
383 | " 从观音桥到重庆市图书馆怎么走 | \n",
384 | " map | \n",
385 | " map | \n",
386 | "
\n",
387 | " \n",
388 | " | 2 | \n",
389 | " 鸭蛋怎么腌? | \n",
390 | " cookbook | \n",
391 | " cookbook | \n",
392 | "
\n",
393 | " \n",
394 | " | 3 | \n",
395 | " 怎么治疗牛皮癣 | \n",
396 | " health | \n",
397 | " chat | \n",
398 | "
\n",
399 | " \n",
400 | " | 4 | \n",
401 | " 唠什么 | \n",
402 | " chat | \n",
403 | " chat | \n",
404 | "
\n",
405 | " \n",
406 | " | 5 | \n",
407 | " 阳澄湖大闸蟹的做法。 | \n",
408 | " cookbook | \n",
409 | " cookbook | \n",
410 | "
\n",
411 | " \n",
412 | " | 6 | \n",
413 | " 昆山大润发在哪里 | \n",
414 | " map | \n",
415 | " chat | \n",
416 | "
\n",
417 | " \n",
418 | " | 7 | \n",
419 | " 红烧肉怎么做?嗯? | \n",
420 | " cookbook | \n",
421 | " cookbook | \n",
422 | "
\n",
423 | " \n",
424 | " | 8 | \n",
425 | " 南京到厦门的火车票 | \n",
426 | " train | \n",
427 | " bus | \n",
428 | "
\n",
429 | " \n",
430 | " | 9 | \n",
431 | " 6的平方 | \n",
432 | " calc | \n",
433 | " calc | \n",
434 | "
\n",
435 | " \n",
436 | "
\n",
437 | "
"
438 | ],
439 | "text/plain": [
440 | " query label model_prediction_label\n",
441 | "0 今天东莞天气如何 weather datetime\n",
442 | "1 从观音桥到重庆市图书馆怎么走 map map\n",
443 | "2 鸭蛋怎么腌? cookbook cookbook\n",
444 | "3 怎么治疗牛皮癣 health chat\n",
445 | "4 唠什么 chat chat\n",
446 | "5 阳澄湖大闸蟹的做法。 cookbook cookbook\n",
447 | "6 昆山大润发在哪里 map chat\n",
448 | "7 红烧肉怎么做?嗯? cookbook cookbook\n",
449 | "8 南京到厦门的火车票 train bus\n",
450 | "9 6的平方 calc calc"
451 | ]
452 | },
453 | "execution_count": 55,
454 | "metadata": {},
455 | "output_type": "execute_result"
456 | }
457 | ],
458 | "source": [
459 | "data_df['model_prediction_label'] = data_df['query'].apply(query_2_label)\n",
460 | "\n",
461 | "data_df.head(10)"
462 | ]
463 | }
464 | ],
465 | "metadata": {
466 | "kernelspec": {
467 | "display_name": "Python 3",
468 | "language": "python",
469 | "name": "python3"
470 | },
471 | "language_info": {
472 | "codemirror_mode": {
473 | "name": "ipython",
474 | "version": 3
475 | },
476 | "file_extension": ".py",
477 | "mimetype": "text/x-python",
478 | "name": "python",
479 | "nbconvert_exporter": "python",
480 | "pygments_lexer": "ipython3",
481 | "version": "3.6.5"
482 | }
483 | },
484 | "nbformat": 4,
485 | "nbformat_minor": 2
486 | }
487 |
--------------------------------------------------------------------------------
/SMP2018_EDA_and_Baseline_Model(Keras).ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# $$SMP2018中文人机对话技术评测(ECDT)$$"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "1. 下面是一个完整的针对 [SMP2018中文人机对话技术评测(ECDT)](http://smp2018.cips-smp.org/ecdt_index.html) 的实验,由该实验训练的基线模型能达到评测排行榜的前三的水平。\n",
15 | "2. 通过本实验,可以掌握处理自然语言文本数据的一般方法。\n",
16 | "3. 推荐自己修改此文件,达到更好的实验效果,比如改变以下几个超参数 "
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | "```python\n",
24 | "# 词嵌入的维度\n",
25 | "embedding_word_dims = 32\n",
26 | "# 批次大小\n",
27 | "batch_size = 30\n",
28 | "# 周期\n",
29 | "epochs = 20\n",
30 | "```"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "# 本实验还可以改进的地方举例 "
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "1. 预处理阶段使用其它的分词工具\n",
45 | "2. 采用字符向量和词向量结合的方式\n",
46 | "3. 使用预先训练好的词向量\n",
47 | "4. 改变模型结构\n",
48 | "5. 改变模型超参数"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "# 导入依赖库"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 1,
61 | "metadata": {},
62 | "outputs": [
63 | {
64 | "name": "stderr",
65 | "output_type": "stream",
66 | "text": [
67 | "Using TensorFlow backend.\n"
68 | ]
69 | }
70 | ],
71 | "source": [
72 | "import numpy as np\n",
73 | "import pandas as pd\n",
74 | "import collections\n",
75 | "import jieba\n",
76 | "from keras.preprocessing.text import Tokenizer\n",
77 | "from keras.preprocessing.sequence import pad_sequences\n",
78 | "from keras.models import Sequential\n",
79 | "from keras.layers import Embedding, LSTM, Dense\n",
80 | "from keras.utils import to_categorical,plot_model\n",
81 | "from keras.callbacks import TensorBoard, Callback\n",
82 | "\n",
83 | "from sklearn.metrics import classification_report\n",
84 | "\n",
85 | "import requests \n",
86 | "\n",
87 | "import time\n",
88 | "\n",
89 | "import os"
90 | ]
91 | },
92 | {
93 | "cell_type": "markdown",
94 | "metadata": {},
95 | "source": [
96 | "# 辅助函数"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 2,
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "from keras import backend as K\n",
106 | "\n",
107 | "# 计算 F1 值的函数\n",
108 | "def f1(y_true, y_pred):\n",
109 | " def recall(y_true, y_pred):\n",
110 | " \"\"\"Recall metric.\n",
111 | "\n",
112 | " Only computes a batch-wise average of recall.\n",
113 | "\n",
114 | " Computes the recall, a metric for multi-label classification of\n",
115 | " how many relevant items are selected.\n",
116 | " \"\"\"\n",
117 | " true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))\n",
118 | " possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))\n",
119 | " recall = true_positives / (possible_positives + K.epsilon())\n",
120 | " return recall\n",
121 | "\n",
122 | " def precision(y_true, y_pred):\n",
123 | " \"\"\"Precision metric.\n",
124 | "\n",
125 | " Only computes a batch-wise average of precision.\n",
126 | "\n",
127 | " Computes the precision, a metric for multi-label classification of\n",
128 | " how many selected items are relevant.\n",
129 | " \"\"\"\n",
130 | " true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))\n",
131 | " predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))\n",
132 | " precision = true_positives / (predicted_positives + K.epsilon())\n",
133 | " return precision\n",
134 | " precision = precision(y_true, y_pred)\n",
135 | " recall = recall(y_true, y_pred)\n",
136 | " return 2*((precision*recall)/(precision+recall+K.epsilon()))"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 3,
142 | "metadata": {},
143 | "outputs": [],
144 | "source": [
145 | "# 获取自定义时间格式的字符串\n",
146 | "def get_customization_time():\n",
147 | " # return '2018_10_10_18_11_45' 年月日时分秒\n",
148 | " time_tuple = time.localtime(time.time())\n",
149 | " customization_time = \"{}_{}_{}_{}_{}_{}\".format(time_tuple[0], time_tuple[1], time_tuple[2], time_tuple[3], time_tuple[4], time_tuple[5])\n",
150 | " return customization_time"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "metadata": {},
156 | "source": [
157 | "# 准备数据"
158 | ]
159 | },
160 | {
161 | "cell_type": "markdown",
162 | "metadata": {},
163 | "source": [
164 | "## [下载SMP2018官方数据](https://worksheets.codalab.org/worksheets/0x27203f932f8341b79841d50ce0fd684f/)"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": 4,
170 | "metadata": {},
171 | "outputs": [],
172 | "source": [
173 | "raw_train_data_url = \"https://worksheets.codalab.org/rest/bundles/0x0161fd2fb40d4dd48541c2643d04b0b8/contents/blob/\"\n",
174 | "raw_test_data_url = \"https://worksheets.codalab.org/rest/bundles/0x1f96bc12222641209ad057e762910252/contents/blob/\"\n",
175 | "\n",
176 | "# 如果不存在 SMP2018 数据,则下载\n",
177 | "if (not os.path.exists('./data/train.json')) or (not os.path.exists('./data/dev.json')):\n",
178 | " raw_train = requests.get(raw_train_data_url) \n",
179 | " raw_test = requests.get(raw_test_data_url) \n",
180 | " if not os.path.exists('./data'):\n",
181 | " os.makedirs('./data')\n",
182 | " with open(\"./data/train.json\", \"wb\") as code:\n",
183 | " code.write(raw_train.content)\n",
184 | " with open(\"./data/dev.json\", \"wb\") as code:\n",
185 | " code.write(raw_test.content)"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 5,
191 | "metadata": {},
192 | "outputs": [],
193 | "source": [
194 | "def get_json_data(path):\n",
195 | " # read data\n",
196 | " data_df = pd.read_json(path)\n",
197 | " # change row and colunm\n",
198 | " data_df = data_df.transpose()\n",
199 | " # change colunm order\n",
200 | " data_df = data_df[['query', 'label']]\n",
201 | " return data_df"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": 6,
207 | "metadata": {},
208 | "outputs": [],
209 | "source": [
210 | "train_data_df = get_json_data(path=\"data/train.json\")\n",
211 | "\n",
212 | "test_data_df = get_json_data(path=\"data/dev.json\")"
213 | ]
214 | },
215 | {
216 | "cell_type": "code",
217 | "execution_count": 7,
218 | "metadata": {},
219 | "outputs": [
220 | {
221 | "data": {
222 | "text/html": [
223 | "\n",
224 | "\n",
237 | "
\n",
238 | " \n",
239 | " \n",
240 | " | \n",
241 | " query | \n",
242 | " label | \n",
243 | "
\n",
244 | " \n",
245 | " \n",
246 | " \n",
247 | " | 0 | \n",
248 | " 今天东莞天气如何 | \n",
249 | " weather | \n",
250 | "
\n",
251 | " \n",
252 | " | 1 | \n",
253 | " 从观音桥到重庆市图书馆怎么走 | \n",
254 | " map | \n",
255 | "
\n",
256 | " \n",
257 | " | 2 | \n",
258 | " 鸭蛋怎么腌? | \n",
259 | " cookbook | \n",
260 | "
\n",
261 | " \n",
262 | " | 3 | \n",
263 | " 怎么治疗牛皮癣 | \n",
264 | " health | \n",
265 | "
\n",
266 | " \n",
267 | " | 4 | \n",
268 | " 唠什么 | \n",
269 | " chat | \n",
270 | "
\n",
271 | " \n",
272 | "
\n",
273 | "
"
274 | ],
275 | "text/plain": [
276 | " query label\n",
277 | "0 今天东莞天气如何 weather\n",
278 | "1 从观音桥到重庆市图书馆怎么走 map\n",
279 | "2 鸭蛋怎么腌? cookbook\n",
280 | "3 怎么治疗牛皮癣 health\n",
281 | "4 唠什么 chat"
282 | ]
283 | },
284 | "execution_count": 7,
285 | "metadata": {},
286 | "output_type": "execute_result"
287 | }
288 | ],
289 | "source": [
290 | "train_data_df.head()"
291 | ]
292 | },
293 | {
294 | "cell_type": "markdown",
295 | "metadata": {},
296 | "source": [
297 | "---"
298 | ]
299 | },
300 | {
301 | "cell_type": "markdown",
302 | "metadata": {},
303 | "source": [
304 | "## [结巴分词](https://github.com/fxsjy/jieba)示例,下面将使用结巴分词对原数据进行处理"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": 8,
310 | "metadata": {},
311 | "outputs": [
312 | {
313 | "name": "stderr",
314 | "output_type": "stream",
315 | "text": [
316 | "Building prefix dict from the default dictionary ...\n",
317 | "Loading model from cache /tmp/jieba.cache\n",
318 | "Loading model cost 1.022 seconds.\n",
319 | "Prefix dict has been built succesfully.\n"
320 | ]
321 | },
322 | {
323 | "name": "stdout",
324 | "output_type": "stream",
325 | "text": [
326 | "['他', '来到', '了', '网易', '杭研', '大厦']\n"
327 | ]
328 | }
329 | ],
330 | "source": [
331 | "seg_list = jieba.cut(\"他来到了网易杭研大厦\") # 默认是精确模式\n",
332 | "print(list(seg_list))"
333 | ]
334 | },
335 | {
336 | "cell_type": "markdown",
337 | "metadata": {},
338 | "source": [
339 | "---"
340 | ]
341 | },
342 | {
343 | "cell_type": "markdown",
344 | "metadata": {},
345 | "source": [
346 | "# 序列化"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": 9,
352 | "metadata": {},
353 | "outputs": [],
354 | "source": [
355 | "def use_jieba_cut(a_sentence):\n",
356 | " return list(jieba.cut(a_sentence))\n",
357 | "\n",
358 | "train_data_df['cut_query'] = train_data_df['query'].apply(use_jieba_cut)\n",
359 | "test_data_df['cut_query'] = test_data_df['query'].apply(use_jieba_cut)"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": 10,
365 | "metadata": {},
366 | "outputs": [
367 | {
368 | "data": {
369 | "text/html": [
370 | "\n",
371 | "\n",
384 | "
\n",
385 | " \n",
386 | " \n",
387 | " | \n",
388 | " query | \n",
389 | " label | \n",
390 | " cut_query | \n",
391 | "
\n",
392 | " \n",
393 | " \n",
394 | " \n",
395 | " | 0 | \n",
396 | " 今天东莞天气如何 | \n",
397 | " weather | \n",
398 | " [今天, 东莞, 天气, 如何] | \n",
399 | "
\n",
400 | " \n",
401 | " | 1 | \n",
402 | " 从观音桥到重庆市图书馆怎么走 | \n",
403 | " map | \n",
404 | " [从, 观音桥, 到, 重庆市, 图书馆, 怎么, 走] | \n",
405 | "
\n",
406 | " \n",
407 | " | 2 | \n",
408 | " 鸭蛋怎么腌? | \n",
409 | " cookbook | \n",
410 | " [鸭蛋, 怎么, 腌, ?] | \n",
411 | "
\n",
412 | " \n",
413 | " | 3 | \n",
414 | " 怎么治疗牛皮癣 | \n",
415 | " health | \n",
416 | " [怎么, 治疗, 牛皮癣] | \n",
417 | "
\n",
418 | " \n",
419 | " | 4 | \n",
420 | " 唠什么 | \n",
421 | " chat | \n",
422 | " [唠, 什么] | \n",
423 | "
\n",
424 | " \n",
425 | " | 5 | \n",
426 | " 阳澄湖大闸蟹的做法。 | \n",
427 | " cookbook | \n",
428 | " [阳澄湖, 大闸蟹, 的, 做法, 。] | \n",
429 | "
\n",
430 | " \n",
431 | " | 6 | \n",
432 | " 昆山大润发在哪里 | \n",
433 | " map | \n",
434 | " [昆山, 大润发, 在, 哪里] | \n",
435 | "
\n",
436 | " \n",
437 | " | 7 | \n",
438 | " 红烧肉怎么做?嗯? | \n",
439 | " cookbook | \n",
440 | " [红烧肉, 怎么, 做, ?, 嗯, ?] | \n",
441 | "
\n",
442 | " \n",
443 | " | 8 | \n",
444 | " 南京到厦门的火车票 | \n",
445 | " train | \n",
446 | " [南京, 到, 厦门, 的, 火车票] | \n",
447 | "
\n",
448 | " \n",
449 | " | 9 | \n",
450 | " 6的平方 | \n",
451 | " calc | \n",
452 | " [6, 的, 平方] | \n",
453 | "
\n",
454 | " \n",
455 | "
\n",
456 | "
"
457 | ],
458 | "text/plain": [
459 | " query label cut_query\n",
460 | "0 今天东莞天气如何 weather [今天, 东莞, 天气, 如何]\n",
461 | "1 从观音桥到重庆市图书馆怎么走 map [从, 观音桥, 到, 重庆市, 图书馆, 怎么, 走]\n",
462 | "2 鸭蛋怎么腌? cookbook [鸭蛋, 怎么, 腌, ?]\n",
463 | "3 怎么治疗牛皮癣 health [怎么, 治疗, 牛皮癣]\n",
464 | "4 唠什么 chat [唠, 什么]\n",
465 | "5 阳澄湖大闸蟹的做法。 cookbook [阳澄湖, 大闸蟹, 的, 做法, 。]\n",
466 | "6 昆山大润发在哪里 map [昆山, 大润发, 在, 哪里]\n",
467 | "7 红烧肉怎么做?嗯? cookbook [红烧肉, 怎么, 做, ?, 嗯, ?]\n",
468 | "8 南京到厦门的火车票 train [南京, 到, 厦门, 的, 火车票]\n",
469 | "9 6的平方 calc [6, 的, 平方]"
470 | ]
471 | },
472 | "execution_count": 10,
473 | "metadata": {},
474 | "output_type": "execute_result"
475 | }
476 | ],
477 | "source": [
478 | "train_data_df.head(10)"
479 | ]
480 | },
481 | {
482 | "cell_type": "markdown",
483 | "metadata": {},
484 | "source": [
485 | "## 处理特征"
486 | ]
487 | },
488 | {
489 | "cell_type": "code",
490 | "execution_count": 11,
491 | "metadata": {},
492 | "outputs": [],
493 | "source": [
494 | "tokenizer = Tokenizer()"
495 | ]
496 | },
497 | {
498 | "cell_type": "code",
499 | "execution_count": 12,
500 | "metadata": {},
501 | "outputs": [],
502 | "source": [
503 | "tokenizer.fit_on_texts(train_data_df['cut_query'])"
504 | ]
505 | },
506 | {
507 | "cell_type": "code",
508 | "execution_count": 13,
509 | "metadata": {},
510 | "outputs": [
511 | {
512 | "data": {
513 | "text/plain": [
514 | "2883"
515 | ]
516 | },
517 | "execution_count": 13,
518 | "metadata": {},
519 | "output_type": "execute_result"
520 | }
521 | ],
522 | "source": [
523 | "max_features = len(tokenizer.index_word)\n",
524 | "\n",
525 | "len(tokenizer.index_word)"
526 | ]
527 | },
528 | {
529 | "cell_type": "code",
530 | "execution_count": 14,
531 | "metadata": {},
532 | "outputs": [],
533 | "source": [
534 | "x_train = tokenizer.texts_to_sequences(train_data_df['cut_query'])\n",
535 | "\n",
536 | "x_test = tokenizer.texts_to_sequences(test_data_df['cut_query'])"
537 | ]
538 | },
539 | {
540 | "cell_type": "code",
541 | "execution_count": 15,
542 | "metadata": {},
543 | "outputs": [],
544 | "source": [
545 | "max_cut_query_lenth = 26"
546 | ]
547 | },
548 | {
549 | "cell_type": "code",
550 | "execution_count": 16,
551 | "metadata": {},
552 | "outputs": [],
553 | "source": [
554 | "x_train = pad_sequences(x_train, max_cut_query_lenth)\n",
555 | "\n",
556 | "x_test = pad_sequences(x_test, max_cut_query_lenth)"
557 | ]
558 | },
559 | {
560 | "cell_type": "code",
561 | "execution_count": 17,
562 | "metadata": {},
563 | "outputs": [
564 | {
565 | "data": {
566 | "text/plain": [
567 | "(2299, 26)"
568 | ]
569 | },
570 | "execution_count": 17,
571 | "metadata": {},
572 | "output_type": "execute_result"
573 | }
574 | ],
575 | "source": [
576 | "x_train.shape"
577 | ]
578 | },
579 | {
580 | "cell_type": "code",
581 | "execution_count": 18,
582 | "metadata": {},
583 | "outputs": [
584 | {
585 | "data": {
586 | "text/plain": [
587 | "(770, 26)"
588 | ]
589 | },
590 | "execution_count": 18,
591 | "metadata": {},
592 | "output_type": "execute_result"
593 | }
594 | ],
595 | "source": [
596 | "x_test.shape"
597 | ]
598 | },
599 | {
600 | "cell_type": "markdown",
601 | "metadata": {},
602 | "source": [
603 | "## 处理标签"
604 | ]
605 | },
606 | {
607 | "cell_type": "code",
608 | "execution_count": 19,
609 | "metadata": {},
610 | "outputs": [],
611 | "source": [
612 | "label_tokenizer = Tokenizer()"
613 | ]
614 | },
615 | {
616 | "cell_type": "code",
617 | "execution_count": 20,
618 | "metadata": {},
619 | "outputs": [],
620 | "source": [
621 | "label_tokenizer.fit_on_texts(train_data_df['label'])"
622 | ]
623 | },
624 | {
625 | "cell_type": "code",
626 | "execution_count": 21,
627 | "metadata": {},
628 | "outputs": [],
629 | "source": [
630 | "label_numbers = len(label_tokenizer.word_counts)"
631 | ]
632 | },
633 | {
634 | "cell_type": "code",
635 | "execution_count": 22,
636 | "metadata": {},
637 | "outputs": [],
638 | "source": [
639 | "NUM_CLASSES = len(label_tokenizer.word_counts)"
640 | ]
641 | },
642 | {
643 | "cell_type": "code",
644 | "execution_count": 23,
645 | "metadata": {},
646 | "outputs": [
647 | {
648 | "data": {
649 | "text/plain": [
650 | "OrderedDict([('weather', 66),\n",
651 | " ('map', 68),\n",
652 | " ('cookbook', 269),\n",
653 | " ('health', 55),\n",
654 | " ('chat', 455),\n",
655 | " ('train', 70),\n",
656 | " ('calc', 24),\n",
657 | " ('translation', 61),\n",
658 | " ('music', 66),\n",
659 | " ('tvchannel', 71),\n",
660 | " ('poetry', 102),\n",
661 | " ('telephone', 63),\n",
662 | " ('stock', 71),\n",
663 | " ('radio', 24),\n",
664 | " ('contacts', 30),\n",
665 | " ('lottery', 24),\n",
666 | " ('website', 54),\n",
667 | " ('video', 182),\n",
668 | " ('news', 58),\n",
669 | " ('bus', 24),\n",
670 | " ('app', 53),\n",
671 | " ('flight', 62),\n",
672 | " ('epg', 107),\n",
673 | " ('message', 63),\n",
674 | " ('match', 24),\n",
675 | " ('schedule', 29),\n",
676 | " ('novel', 24),\n",
677 | " ('riddle', 34),\n",
678 | " ('email', 24),\n",
679 | " ('datetime', 18),\n",
680 | " ('cinemas', 24)])"
681 | ]
682 | },
683 | "execution_count": 23,
684 | "metadata": {},
685 | "output_type": "execute_result"
686 | }
687 | ],
688 | "source": [
689 | "label_tokenizer.word_counts"
690 | ]
691 | },
692 | {
693 | "cell_type": "code",
694 | "execution_count": 24,
695 | "metadata": {},
696 | "outputs": [],
697 | "source": [
698 | "y_train = label_tokenizer.texts_to_sequences(train_data_df['label'])"
699 | ]
700 | },
701 | {
702 | "cell_type": "code",
703 | "execution_count": 25,
704 | "metadata": {},
705 | "outputs": [
706 | {
707 | "data": {
708 | "text/plain": [
709 | "[[10], [9], [2], [17], [1], [2], [9], [2], [8], [23]]"
710 | ]
711 | },
712 | "execution_count": 25,
713 | "metadata": {},
714 | "output_type": "execute_result"
715 | }
716 | ],
717 | "source": [
718 | "y_train[:10]"
719 | ]
720 | },
721 | {
722 | "cell_type": "code",
723 | "execution_count": 26,
724 | "metadata": {},
725 | "outputs": [],
726 | "source": [
727 | "y_train = [[y[0]-1] for y in y_train]"
728 | ]
729 | },
730 | {
731 | "cell_type": "code",
732 | "execution_count": 27,
733 | "metadata": {},
734 | "outputs": [
735 | {
736 | "data": {
737 | "text/plain": [
738 | "[[9], [8], [1], [16], [0], [1], [8], [1], [7], [22]]"
739 | ]
740 | },
741 | "execution_count": 27,
742 | "metadata": {},
743 | "output_type": "execute_result"
744 | }
745 | ],
746 | "source": [
747 | "y_train[:10]"
748 | ]
749 | },
750 | {
751 | "cell_type": "code",
752 | "execution_count": 28,
753 | "metadata": {},
754 | "outputs": [
755 | {
756 | "data": {
757 | "text/plain": [
758 | "(2299, 31)"
759 | ]
760 | },
761 | "execution_count": 28,
762 | "metadata": {},
763 | "output_type": "execute_result"
764 | }
765 | ],
766 | "source": [
767 | "y_train = to_categorical(y_train, label_numbers)\n",
768 | "y_train.shape"
769 | ]
770 | },
771 | {
772 | "cell_type": "code",
773 | "execution_count": 29,
774 | "metadata": {},
775 | "outputs": [
776 | {
777 | "data": {
778 | "text/plain": [
779 | "(770, 31)"
780 | ]
781 | },
782 | "execution_count": 29,
783 | "metadata": {},
784 | "output_type": "execute_result"
785 | }
786 | ],
787 | "source": [
788 | "y_test = label_tokenizer.texts_to_sequences(test_data_df['label'])\n",
789 | "y_test = [y[0]-1 for y in y_test]\n",
790 | "y_test = to_categorical(y_test, label_numbers)\n",
791 | "y_test.shape"
792 | ]
793 | },
794 | {
795 | "cell_type": "code",
796 | "execution_count": 30,
797 | "metadata": {},
798 | "outputs": [
799 | {
800 | "data": {
801 | "text/plain": [
802 | "array([0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
803 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
804 | " dtype=float32)"
805 | ]
806 | },
807 | "execution_count": 30,
808 | "metadata": {},
809 | "output_type": "execute_result"
810 | }
811 | ],
812 | "source": [
813 | "y_test[0]"
814 | ]
815 | },
816 | {
817 | "cell_type": "markdown",
818 | "metadata": {},
819 | "source": [
820 | "# 设计模型"
821 | ]
822 | },
823 | {
824 | "cell_type": "code",
825 | "execution_count": 45,
826 | "metadata": {},
827 | "outputs": [],
828 | "source": [
829 | "def create_SMP2018_lstm_model(max_features, max_cut_query_lenth, label_numbers):\n",
830 | " model = Sequential()\n",
831 | " model.add(Embedding(input_dim=max_features+1, output_dim=32, input_length=max_cut_query_lenth))\n",
832 | " model.add(LSTM(units=64, dropout=0.2, recurrent_dropout=0.2))\n",
833 | " model.add(Dense(label_numbers, activation='softmax'))\n",
834 | " # try using different optimizers and different optimizer configs\n",
835 | " model.compile(loss='categorical_crossentropy',\n",
836 | " optimizer='adam',\n",
837 | " metrics=[f1])\n",
838 | "\n",
839 | " plot_model(model, to_file='SMP2018_lstm_model.png', show_shapes=True)\n",
840 | " \n",
841 | " return model"
842 | ]
843 | },
844 | {
845 | "cell_type": "markdown",
846 | "metadata": {},
847 | "source": [
848 | "# 训练模型"
849 | ]
850 | },
851 | {
852 | "cell_type": "code",
853 | "execution_count": 46,
854 | "metadata": {},
855 | "outputs": [],
856 | "source": [
857 | "if 'max_features' not in dir():\n",
858 | " max_features = 2888\n",
859 | " print('not find max_features variable, use default max_features values:\\t{}'.format(max_features))\n",
860 | "if 'max_cut_query_lenth' not in dir():\n",
861 | " max_cut_query_lenth = 26\n",
862 | " print('not find max_cut_query_lenth, use default max_features values:\\t{}'.format(max_cut_query_lenth))\n",
863 | "if 'label_numbers' not in dir():\n",
864 | " label_numbers = 31\n",
865 | " print('not find label_numbers, use default max_features values:\\t{}'.format(label_numbers))"
866 | ]
867 | },
868 | {
869 | "cell_type": "code",
870 | "execution_count": 47,
871 | "metadata": {},
872 | "outputs": [],
873 | "source": [
874 | "model = create_SMP2018_lstm_model(max_features, max_cut_query_lenth, label_numbers)"
875 | ]
876 | },
877 | {
878 | "cell_type": "code",
879 | "execution_count": 48,
880 | "metadata": {},
881 | "outputs": [],
882 | "source": [
883 | "batch_size = 20\n",
884 | "epochs = 30"
885 | ]
886 | },
887 | {
888 | "cell_type": "code",
889 | "execution_count": 49,
890 | "metadata": {},
891 | "outputs": [
892 | {
893 | "name": "stdout",
894 | "output_type": "stream",
895 | "text": [
896 | "(2299, 26) (2299, 31)\n"
897 | ]
898 | }
899 | ],
900 | "source": [
901 | "print(x_train.shape, y_train.shape)"
902 | ]
903 | },
904 | {
905 | "cell_type": "code",
906 | "execution_count": 50,
907 | "metadata": {},
908 | "outputs": [
909 | {
910 | "name": "stdout",
911 | "output_type": "stream",
912 | "text": [
913 | "(770, 26) (770, 31)\n"
914 | ]
915 | }
916 | ],
917 | "source": [
918 | "print(x_test.shape, y_test.shape)"
919 | ]
920 | },
921 | {
922 | "cell_type": "code",
923 | "execution_count": 51,
924 | "metadata": {},
925 | "outputs": [
926 | {
927 | "name": "stdout",
928 | "output_type": "stream",
929 | "text": [
930 | "Train...\n",
931 | "Epoch 1/30\n",
932 | "2299/2299 [==============================] - 16s 7ms/step - loss: 3.0916 - f1: 0.0000e+00\n",
933 | "Epoch 2/30\n",
934 | "2299/2299 [==============================] - 14s 6ms/step - loss: 2.6594 - f1: 0.1409\n",
935 | "Epoch 3/30\n",
936 | "2299/2299 [==============================] - 13s 6ms/step - loss: 2.0817 - f1: 0.4055\n",
937 | "Epoch 4/30\n",
938 | "2299/2299 [==============================] - 14s 6ms/step - loss: 1.6032 - f1: 0.4689\n",
939 | "Epoch 5/30\n",
940 | "2299/2299 [==============================] - 14s 6ms/step - loss: 1.1318 - f1: 0.6176\n",
941 | "Epoch 6/30\n",
942 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.8090 - f1: 0.7399\n",
943 | "Epoch 7/30\n",
944 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.5704 - f1: 0.8298\n",
945 | "Epoch 8/30\n",
946 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.4051 - f1: 0.8879\n",
947 | "Epoch 9/30\n",
948 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.3002 - f1: 0.9280\n",
949 | "Epoch 10/30\n",
950 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.2317 - f1: 0.9467\n",
951 | "Epoch 11/30\n",
952 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.1755 - f1: 0.9678\n",
953 | "Epoch 12/30\n",
954 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.1391 - f1: 0.9758\n",
955 | "Epoch 13/30\n",
956 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.1131 - f1: 0.9800\n",
957 | "Epoch 14/30\n",
958 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0883 - f1: 0.9861\n",
959 | "Epoch 15/30\n",
960 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0725 - f1: 0.9894\n",
961 | "Epoch 16/30\n",
962 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0615 - f1: 0.9929\n",
963 | "Epoch 17/30\n",
964 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0507 - f1: 0.9945\n",
965 | "Epoch 18/30\n",
966 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0455 - f1: 0.9963\n",
967 | "Epoch 19/30\n",
968 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0398 - f1: 0.9960\n",
969 | "Epoch 20/30\n",
970 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0313 - f1: 0.9978\n",
971 | "Epoch 21/30\n",
972 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0266 - f1: 0.9984\n",
973 | "Epoch 22/30\n",
974 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0279 - f1: 0.9965\n",
975 | "Epoch 23/30\n",
976 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0250 - f1: 0.9976\n",
977 | "Epoch 24/30\n",
978 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0219 - f1: 0.9982\n",
979 | "Epoch 25/30\n",
980 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0195 - f1: 0.9982\n",
981 | "Epoch 26/30\n",
982 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0179 - f1: 0.9989\n",
983 | "Epoch 27/30\n",
984 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0177 - f1: 0.9974\n",
985 | "Epoch 28/30\n",
986 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0139 - f1: 0.9987\n",
987 | "Epoch 29/30\n",
988 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0139 - f1: 0.9989\n",
989 | "Epoch 30/30\n",
990 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0129 - f1: 0.9987\n"
991 | ]
992 | },
993 | {
994 | "data": {
995 | "text/plain": [
996 | ""
997 | ]
998 | },
999 | "execution_count": 51,
1000 | "metadata": {},
1001 | "output_type": "execute_result"
1002 | }
1003 | ],
1004 | "source": [
1005 | "print('Train...')\n",
1006 | "model.fit(x_train, y_train,\n",
1007 | " batch_size=batch_size,\n",
1008 | " epochs=epochs)"
1009 | ]
1010 | },
1011 | {
1012 | "cell_type": "markdown",
1013 | "metadata": {},
1014 | "source": [
1015 | "# 评估模型"
1016 | ]
1017 | },
1018 | {
1019 | "cell_type": "code",
1020 | "execution_count": 52,
1021 | "metadata": {},
1022 | "outputs": [
1023 | {
1024 | "name": "stdout",
1025 | "output_type": "stream",
1026 | "text": [
1027 | "770/770 [==============================] - 1s 1ms/step\n",
1028 | "Test score: 0.6803552009068526\n",
1029 | "Test f1: 0.8464262740952628\n"
1030 | ]
1031 | }
1032 | ],
1033 | "source": [
1034 | "score = model.evaluate(x_test, y_test,\n",
1035 | " batch_size=batch_size, verbose=1)\n",
1036 | "\n",
1037 | "print('Test score:', score[0])\n",
1038 | "print('Test f1:', score[1])"
1039 | ]
1040 | },
1041 | {
1042 | "cell_type": "code",
1043 | "execution_count": 53,
1044 | "metadata": {},
1045 | "outputs": [],
1046 | "source": [
1047 | "y_hat_test = model.predict(x_test)"
1048 | ]
1049 | },
1050 | {
1051 | "cell_type": "code",
1052 | "execution_count": 55,
1053 | "metadata": {},
1054 | "outputs": [
1055 | {
1056 | "name": "stdout",
1057 | "output_type": "stream",
1058 | "text": [
1059 | "(770, 31)\n"
1060 | ]
1061 | }
1062 | ],
1063 | "source": [
1064 | "print(y_hat_test.shape)"
1065 | ]
1066 | },
1067 | {
1068 | "cell_type": "markdown",
1069 | "metadata": {},
1070 | "source": [
1071 | "## 将 one-hot 张量转换成对应的整数"
1072 | ]
1073 | },
1074 | {
1075 | "cell_type": "code",
1076 | "execution_count": 54,
1077 | "metadata": {},
1078 | "outputs": [],
1079 | "source": [
1080 | "y_pred = np.argmax(y_hat_test, axis=1).tolist()"
1081 | ]
1082 | },
1083 | {
1084 | "cell_type": "code",
1085 | "execution_count": 55,
1086 | "metadata": {},
1087 | "outputs": [],
1088 | "source": [
1089 | "y_true = np.argmax(y_test, axis=1).tolist()"
1090 | ]
1091 | },
1092 | {
1093 | "cell_type": "markdown",
1094 | "metadata": {},
1095 | "source": [
1096 | "## 查看多分类的 准确率、召回率、F1 值"
1097 | ]
1098 | },
1099 | {
1100 | "cell_type": "code",
1101 | "execution_count": 56,
1102 | "metadata": {},
1103 | "outputs": [
1104 | {
1105 | "name": "stdout",
1106 | "output_type": "stream",
1107 | "text": [
1108 | " precision recall f1-score support\n",
1109 | "\n",
1110 | " 0 0.78 0.93 0.85 154\n",
1111 | " 1 0.92 0.97 0.95 89\n",
1112 | " 2 0.67 0.62 0.64 60\n",
1113 | " 3 0.83 0.83 0.83 36\n",
1114 | " 4 0.79 1.00 0.88 34\n",
1115 | " 5 0.83 0.65 0.73 23\n",
1116 | " 6 1.00 0.83 0.91 24\n",
1117 | " 7 1.00 1.00 1.00 24\n",
1118 | " 8 0.68 0.65 0.67 23\n",
1119 | " 9 0.90 0.86 0.88 22\n",
1120 | " 10 0.85 0.50 0.63 22\n",
1121 | " 11 0.88 1.00 0.93 21\n",
1122 | " 12 1.00 0.90 0.95 21\n",
1123 | " 13 0.91 0.95 0.93 21\n",
1124 | " 14 1.00 0.95 0.98 21\n",
1125 | " 15 0.79 0.95 0.86 20\n",
1126 | " 16 0.90 0.47 0.62 19\n",
1127 | " 17 0.79 0.61 0.69 18\n",
1128 | " 18 0.63 0.67 0.65 18\n",
1129 | " 19 0.90 0.82 0.86 11\n",
1130 | " 20 1.00 0.70 0.82 10\n",
1131 | " 21 1.00 0.67 0.80 9\n",
1132 | " 22 1.00 0.88 0.93 8\n",
1133 | " 23 1.00 0.62 0.77 8\n",
1134 | " 24 1.00 1.00 1.00 8\n",
1135 | " 25 1.00 0.88 0.93 8\n",
1136 | " 26 0.88 0.88 0.88 8\n",
1137 | " 27 0.86 0.75 0.80 8\n",
1138 | " 28 1.00 1.00 1.00 8\n",
1139 | " 29 0.75 0.75 0.75 8\n",
1140 | " 30 0.75 1.00 0.86 6\n",
1141 | "\n",
1142 | " micro avg 0.84 0.84 0.84 770\n",
1143 | " macro avg 0.88 0.82 0.84 770\n",
1144 | "weighted avg 0.85 0.84 0.84 770\n",
1145 | "\n"
1146 | ]
1147 | }
1148 | ],
1149 | "source": [
1150 | "print(classification_report(y_true, y_pred))"
1151 | ]
1152 | }
1153 | ],
1154 | "metadata": {
1155 | "kernelspec": {
1156 | "display_name": "Python 3",
1157 | "language": "python",
1158 | "name": "python3"
1159 | },
1160 | "language_info": {
1161 | "codemirror_mode": {
1162 | "name": "ipython",
1163 | "version": 3
1164 | },
1165 | "file_extension": ".py",
1166 | "mimetype": "text/x-python",
1167 | "name": "python",
1168 | "nbconvert_exporter": "python",
1169 | "pygments_lexer": "ipython3",
1170 | "version": "3.6.5"
1171 | }
1172 | },
1173 | "nbformat": 4,
1174 | "nbformat_minor": 2
1175 | }
1176 |
--------------------------------------------------------------------------------
/SMP2018_EDA_and_Baseline_Model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# $$SMP2018中文人机对话技术评测(ECDT)$$"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "# **不要直接修改此文件,可把该文件拷贝至自己的文件夹下再进行操作**"
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "1. 下面是一个完整的针对 [SMP2018中文人机对话技术评测(ECDT)](http://smp2018.cips-smp.org/ecdt_index.html) 的实验,由该实验训练的基线模型能达到评测排行榜的前三的水平。\n",
22 | "2. 通过本实验,可以掌握处理自然语言文本数据的一般方法。\n",
23 | "3. 推荐自己修改此文件,达到更好的实验效果,比如改变以下几个超参数 "
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {},
29 | "source": [
30 | "```python\n",
31 | "# 词嵌入的维度\n",
32 | "embedding_word_dims = 32\n",
33 | "# 批次大小\n",
34 | "batch_size = 30\n",
35 | "# 周期\n",
36 | "epochs = 20\n",
37 | "```"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "# 本实验还可以改进的地方举例 "
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "metadata": {},
50 | "source": [
51 | "1. 预处理阶段使用其它的分词工具\n",
52 | "2. 采用字符向量和词向量结合的方式\n",
53 | "3. 使用预先训练好的词向量\n",
54 | "4. 改变模型结构\n",
55 | "5. 改变模型超参数"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": [
62 | "# 导入依赖库"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 1,
68 | "metadata": {},
69 | "outputs": [
70 | {
71 | "name": "stderr",
72 | "output_type": "stream",
73 | "text": [
74 | "Using TensorFlow backend.\n"
75 | ]
76 | }
77 | ],
78 | "source": [
79 | "import numpy as np\n",
80 | "import pandas as pd\n",
81 | "import collections\n",
82 | "import jieba\n",
83 | "from keras.preprocessing.sequence import pad_sequences\n",
84 | "from keras.models import Sequential\n",
85 | "from keras.layers import Embedding, LSTM, Dense\n",
86 | "from keras.utils import to_categorical,plot_model\n",
87 | "from keras.callbacks import TensorBoard, Callback\n",
88 | "\n",
89 | "from sklearn.metrics import classification_report\n",
90 | "\n",
91 | "import requests \n",
92 | "\n",
93 | "import time\n",
94 | "\n",
95 | "import os"
96 | ]
97 | },
98 | {
99 | "cell_type": "markdown",
100 | "metadata": {},
101 | "source": [
102 | "# 辅助函数"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 2,
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "from keras import backend as K\n",
112 | "\n",
113 | "# 计算 F1 值的函数\n",
114 | "def f1(y_true, y_pred):\n",
115 | " def recall(y_true, y_pred):\n",
116 | " \"\"\"Recall metric.\n",
117 | "\n",
118 | " Only computes a batch-wise average of recall.\n",
119 | "\n",
120 | " Computes the recall, a metric for multi-label classification of\n",
121 | " how many relevant items are selected.\n",
122 | " \"\"\"\n",
123 | " true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))\n",
124 | " possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))\n",
125 | " recall = true_positives / (possible_positives + K.epsilon())\n",
126 | " return recall\n",
127 | "\n",
128 | " def precision(y_true, y_pred):\n",
129 | " \"\"\"Precision metric.\n",
130 | "\n",
131 | " Only computes a batch-wise average of precision.\n",
132 | "\n",
133 | " Computes the precision, a metric for multi-label classification of\n",
134 | " how many selected items are relevant.\n",
135 | " \"\"\"\n",
136 | " true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))\n",
137 | " predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))\n",
138 | " precision = true_positives / (predicted_positives + K.epsilon())\n",
139 | " return precision\n",
140 | " precision = precision(y_true, y_pred)\n",
141 | " recall = recall(y_true, y_pred)\n",
142 | " return 2*((precision*recall)/(precision+recall+K.epsilon()))"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": 3,
148 | "metadata": {},
149 | "outputs": [],
150 | "source": [
151 | "# 获取自定义时间格式的字符串\n",
152 | "def get_customization_time():\n",
153 | " # return '2018_10_10_18_11_45' 年月日时分秒\n",
154 | " time_tuple = time.localtime(time.time())\n",
155 | " customization_time = \"{}_{}_{}_{}_{}_{}\".format(time_tuple[0], time_tuple[1], time_tuple[2], time_tuple[3], time_tuple[4], time_tuple[5])\n",
156 | " return customization_time"
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "metadata": {},
162 | "source": [
163 | "# 准备数据"
164 | ]
165 | },
166 | {
167 | "cell_type": "markdown",
168 | "metadata": {},
169 | "source": [
170 | "## [下载SMP2018官方数据](https://worksheets.codalab.org/worksheets/0x27203f932f8341b79841d50ce0fd684f/)"
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": 3,
176 | "metadata": {},
177 | "outputs": [],
178 | "source": [
179 | "raw_train_data_url = \"https://worksheets.codalab.org/rest/bundles/0x0161fd2fb40d4dd48541c2643d04b0b8/contents/blob/\"\n",
180 | "raw_test_data_url = \"https://worksheets.codalab.org/rest/bundles/0x1f96bc12222641209ad057e762910252/contents/blob/\"\n",
181 | "\n",
182 | "# 如果不存在 SMP2018 数据,则下载\n",
183 | "if (not os.path.exists('./data/train.json')) or (not os.path.exists('./data/dev.json')):\n",
184 | " raw_train = requests.get(raw_train_data_url) \n",
185 | " raw_test = requests.get(raw_test_data_url) \n",
186 | " if not os.path.exists('./data'):\n",
187 | " os.makedirs('./data')\n",
188 | " with open(\"./data/train.json\", \"wb\") as code:\n",
189 | " code.write(raw_train.content)\n",
190 | " with open(\"./data/dev.json\", \"wb\") as code:\n",
191 | " code.write(raw_test.content)"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": 4,
197 | "metadata": {},
198 | "outputs": [],
199 | "source": [
200 | "def get_json_data(path):\n",
201 | " # read data\n",
202 | " data_df = pd.read_json(path)\n",
203 | " # change row and colunm\n",
204 | " data_df = data_df.transpose()\n",
205 | " # change colunm order\n",
206 | " data_df = data_df[['query', 'label']]\n",
207 | " return data_df"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": 5,
213 | "metadata": {},
214 | "outputs": [],
215 | "source": [
216 | "train_data_df = get_json_data(path=\"data/train.json\")\n",
217 | "\n",
218 | "test_data_df = get_json_data(path=\"data/dev.json\")"
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": 6,
224 | "metadata": {},
225 | "outputs": [
226 | {
227 | "data": {
228 | "text/html": [
229 | "\n",
230 | "\n",
243 | "
\n",
244 | " \n",
245 | " \n",
246 | " | \n",
247 | " query | \n",
248 | " label | \n",
249 | "
\n",
250 | " \n",
251 | " \n",
252 | " \n",
253 | " | 0 | \n",
254 | " 今天东莞天气如何 | \n",
255 | " weather | \n",
256 | "
\n",
257 | " \n",
258 | " | 1 | \n",
259 | " 从观音桥到重庆市图书馆怎么走 | \n",
260 | " map | \n",
261 | "
\n",
262 | " \n",
263 | " | 2 | \n",
264 | " 鸭蛋怎么腌? | \n",
265 | " cookbook | \n",
266 | "
\n",
267 | " \n",
268 | " | 3 | \n",
269 | " 怎么治疗牛皮癣 | \n",
270 | " health | \n",
271 | "
\n",
272 | " \n",
273 | " | 4 | \n",
274 | " 唠什么 | \n",
275 | " chat | \n",
276 | "
\n",
277 | " \n",
278 | "
\n",
279 | "
"
280 | ],
281 | "text/plain": [
282 | " query label\n",
283 | "0 今天东莞天气如何 weather\n",
284 | "1 从观音桥到重庆市图书馆怎么走 map\n",
285 | "2 鸭蛋怎么腌? cookbook\n",
286 | "3 怎么治疗牛皮癣 health\n",
287 | "4 唠什么 chat"
288 | ]
289 | },
290 | "execution_count": 6,
291 | "metadata": {},
292 | "output_type": "execute_result"
293 | }
294 | ],
295 | "source": [
296 | "train_data_df.head()"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": 7,
302 | "metadata": {},
303 | "outputs": [
304 | {
305 | "data": {
306 | "text/html": [
307 | "\n",
308 | "\n",
321 | "
\n",
322 | " \n",
323 | " \n",
324 | " | \n",
325 | " query | \n",
326 | " label | \n",
327 | "
\n",
328 | " \n",
329 | " \n",
330 | " \n",
331 | " | 0 | \n",
332 | " 毛泽东的诗哦。 | \n",
333 | " poetry | \n",
334 | "
\n",
335 | " \n",
336 | " | 1 | \n",
337 | " 有房有车吗微笑 | \n",
338 | " chat | \n",
339 | "
\n",
340 | " \n",
341 | " | 2 | \n",
342 | " 2013年亚洲冠军联赛恒广州恒大比赛时间。 | \n",
343 | " match | \n",
344 | "
\n",
345 | " \n",
346 | " | 3 | \n",
347 | " 若相惜不弃下一句是什么? | \n",
348 | " poetry | \n",
349 | "
\n",
350 | " \n",
351 | " | 4 | \n",
352 | " 苹果翻译成英语 | \n",
353 | " translation | \n",
354 | "
\n",
355 | " \n",
356 | "
\n",
357 | "
"
358 | ],
359 | "text/plain": [
360 | " query label\n",
361 | "0 毛泽东的诗哦。 poetry\n",
362 | "1 有房有车吗微笑 chat\n",
363 | "2 2013年亚洲冠军联赛恒广州恒大比赛时间。 match\n",
364 | "3 若相惜不弃下一句是什么? poetry\n",
365 | "4 苹果翻译成英语 translation"
366 | ]
367 | },
368 | "execution_count": 7,
369 | "metadata": {},
370 | "output_type": "execute_result"
371 | }
372 | ],
373 | "source": [
374 | "test_data_df.head()"
375 | ]
376 | },
377 | {
378 | "cell_type": "code",
379 | "execution_count": 8,
380 | "metadata": {},
381 | "outputs": [
382 | {
383 | "data": {
384 | "text/html": [
385 | "\n",
386 | "\n",
399 | "
\n",
400 | " \n",
401 | " \n",
402 | " | \n",
403 | " query | \n",
404 | " label | \n",
405 | "
\n",
406 | " \n",
407 | " \n",
408 | " \n",
409 | " | count | \n",
410 | " 2299 | \n",
411 | " 2299 | \n",
412 | "
\n",
413 | " \n",
414 | " | unique | \n",
415 | " 2299 | \n",
416 | " 31 | \n",
417 | "
\n",
418 | " \n",
419 | " | top | \n",
420 | " 中国新闻网网站 | \n",
421 | " chat | \n",
422 | "
\n",
423 | " \n",
424 | " | freq | \n",
425 | " 1 | \n",
426 | " 455 | \n",
427 | "
\n",
428 | " \n",
429 | "
\n",
430 | "
"
431 | ],
432 | "text/plain": [
433 | " query label\n",
434 | "count 2299 2299\n",
435 | "unique 2299 31\n",
436 | "top 中国新闻网网站 chat\n",
437 | "freq 1 455"
438 | ]
439 | },
440 | "execution_count": 8,
441 | "metadata": {},
442 | "output_type": "execute_result"
443 | }
444 | ],
445 | "source": [
446 | "train_data_df.describe()"
447 | ]
448 | },
449 | {
450 | "cell_type": "code",
451 | "execution_count": 9,
452 | "metadata": {},
453 | "outputs": [
454 | {
455 | "data": {
456 | "text/html": [
457 | "\n",
458 | "\n",
471 | "
\n",
472 | " \n",
473 | " \n",
474 | " | \n",
475 | " query | \n",
476 | " label | \n",
477 | "
\n",
478 | " \n",
479 | " \n",
480 | " \n",
481 | " | count | \n",
482 | " 770 | \n",
483 | " 770 | \n",
484 | "
\n",
485 | " \n",
486 | " | unique | \n",
487 | " 770 | \n",
488 | " 31 | \n",
489 | "
\n",
490 | " \n",
491 | " | top | \n",
492 | " 查下安徽电视台今天节目单 | \n",
493 | " chat | \n",
494 | "
\n",
495 | " \n",
496 | " | freq | \n",
497 | " 1 | \n",
498 | " 154 | \n",
499 | "
\n",
500 | " \n",
501 | "
\n",
502 | "
"
503 | ],
504 | "text/plain": [
505 | " query label\n",
506 | "count 770 770\n",
507 | "unique 770 31\n",
508 | "top 查下安徽电视台今天节目单 chat\n",
509 | "freq 1 154"
510 | ]
511 | },
512 | "execution_count": 9,
513 | "metadata": {},
514 | "output_type": "execute_result"
515 | }
516 | ],
517 | "source": [
518 | "test_data_df.describe()"
519 | ]
520 | },
521 | {
522 | "cell_type": "code",
523 | "execution_count": 10,
524 | "metadata": {},
525 | "outputs": [],
526 | "source": [
527 | "# 获取所以标签,也就是分类的类别\n",
528 | "labels = list(set(train_data_df['label'].tolist()))"
529 | ]
530 | },
531 | {
532 | "cell_type": "raw",
533 | "metadata": {},
534 | "source": [
535 | "# All labels\n",
536 | "labels = ['website', 'tvchannel', 'lottery', 'chat', 'match',\n",
537 | " 'datetime', 'weather', 'bus', 'novel', 'video', 'riddle',\n",
538 | " 'calc', 'telephone', 'health', 'contacts', 'epg', 'app', 'music',\n",
539 | " 'cookbook', 'stock', 'map', 'message', 'poetry', 'cinemas', 'news',\n",
540 | " 'flight', 'translation', 'train', 'schedule', 'radio', 'email']"
541 | ]
542 | },
543 | {
544 | "cell_type": "code",
545 | "execution_count": 11,
546 | "metadata": {},
547 | "outputs": [
548 | {
549 | "name": "stdout",
550 | "output_type": "stream",
551 | "text": [
552 | "label_numbers:\t 31\n"
553 | ]
554 | }
555 | ],
556 | "source": [
557 | "label_numbers = len(labels)\n",
558 | "print('label_numbers:\\t', label_numbers)"
559 | ]
560 | },
561 | {
562 | "cell_type": "markdown",
563 | "metadata": {},
564 | "source": [
565 | "## 标签和对应ID的映射字典"
566 | ]
567 | },
568 | {
569 | "cell_type": "code",
570 | "execution_count": 12,
571 | "metadata": {},
572 | "outputs": [],
573 | "source": [
574 | "label_2_index_dict = dict([(label, index) for index, label in enumerate(labels)])\n",
575 | "index_2_label_dict = dict([(index, label) for index, label in enumerate(labels)])"
576 | ]
577 | },
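{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick round-trip sanity check of the two dictionaries. Illustrative only: Python `set` ordering is not deterministic, so the exact index behind a given label can change between runs, but mapping a label to its index and back must always return the original label."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Round-trip check: label -> index -> label must be the identity\n",
"example_label = 'weather'\n",
"example_index = label_2_index_dict[example_label]\n",
"assert index_2_label_dict[example_index] == example_label\n",
"print(example_label, '->', example_index)"
]
},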
578 | {
579 | "cell_type": "markdown",
580 | "metadata": {},
581 | "source": [
582 | "---"
583 | ]
584 | },
585 | {
586 | "cell_type": "markdown",
587 | "metadata": {},
588 | "source": [
589 | "## [结巴分词](https://github.com/fxsjy/jieba)示例,下面将使用结巴分词对原数据进行处理"
590 | ]
591 | },
592 | {
593 | "cell_type": "code",
594 | "execution_count": 13,
595 | "metadata": {},
596 | "outputs": [
597 | {
598 | "name": "stderr",
599 | "output_type": "stream",
600 | "text": [
601 | "Building prefix dict from the default dictionary ...\n",
602 | "Loading model from cache /tmp/jieba.cache\n",
603 | "Loading model cost 0.903 seconds.\n",
604 | "Prefix dict has been built succesfully.\n"
605 | ]
606 | },
607 | {
608 | "name": "stdout",
609 | "output_type": "stream",
610 | "text": [
611 | "['他', '来到', '了', '网易', '杭研', '大厦']\n"
612 | ]
613 | }
614 | ],
615 | "source": [
616 | "seg_list = jieba.cut(\"他来到了网易杭研大厦\") # 默认是精确模式\n",
617 | "print(list(seg_list))"
618 | ]
619 | },
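{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, jieba's other segmentation modes split the same sentence differently; a quick illustration using jieba's standard API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Full mode returns every word the dictionary can find\n",
"print(list(jieba.cut(\"他来到了网易杭研大厦\", cut_all=True)))\n",
"# Search-engine mode further splits long words for higher recall\n",
"print(list(jieba.cut_for_search(\"他来到了网易杭研大厦\")))"
]
},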
620 | {
621 | "cell_type": "markdown",
622 | "metadata": {},
623 | "source": [
624 | "---"
625 | ]
626 | },
627 | {
628 | "cell_type": "markdown",
629 | "metadata": {},
630 | "source": [
631 | "# 序列化"
632 | ]
633 | },
634 | {
635 | "cell_type": "code",
636 | "execution_count": 14,
637 | "metadata": {},
638 | "outputs": [],
639 | "source": [
640 | "def use_jieba_cut(a_sentence):\n",
641 | " return list(jieba.cut(a_sentence))\n",
642 | "\n",
643 | "train_data_df['cut_query'] = train_data_df['query'].apply(use_jieba_cut)\n",
644 | "test_data_df['cut_query'] = test_data_df['query'].apply(use_jieba_cut)"
645 | ]
646 | },
647 | {
648 | "cell_type": "code",
649 | "execution_count": 15,
650 | "metadata": {},
651 | "outputs": [
652 | {
653 | "data": {
654 | "text/html": [
655 | "\n",
656 | "\n",
669 | "
\n",
670 | " \n",
671 | " \n",
672 | " | \n",
673 | " query | \n",
674 | " label | \n",
675 | " cut_query | \n",
676 | "
\n",
677 | " \n",
678 | " \n",
679 | " \n",
680 | " | 0 | \n",
681 | " 今天东莞天气如何 | \n",
682 | " weather | \n",
683 | " [今天, 东莞, 天气, 如何] | \n",
684 | "
\n",
685 | " \n",
686 | " | 1 | \n",
687 | " 从观音桥到重庆市图书馆怎么走 | \n",
688 | " map | \n",
689 | " [从, 观音桥, 到, 重庆市, 图书馆, 怎么, 走] | \n",
690 | "
\n",
691 | " \n",
692 | " | 2 | \n",
693 | " 鸭蛋怎么腌? | \n",
694 | " cookbook | \n",
695 | " [鸭蛋, 怎么, 腌, ?] | \n",
696 | "
\n",
697 | " \n",
698 | " | 3 | \n",
699 | " 怎么治疗牛皮癣 | \n",
700 | " health | \n",
701 | " [怎么, 治疗, 牛皮癣] | \n",
702 | "
\n",
703 | " \n",
704 | " | 4 | \n",
705 | " 唠什么 | \n",
706 | " chat | \n",
707 | " [唠, 什么] | \n",
708 | "
\n",
709 | " \n",
710 | " | 5 | \n",
711 | " 阳澄湖大闸蟹的做法。 | \n",
712 | " cookbook | \n",
713 | " [阳澄湖, 大闸蟹, 的, 做法, 。] | \n",
714 | "
\n",
715 | " \n",
716 | " | 6 | \n",
717 | " 昆山大润发在哪里 | \n",
718 | " map | \n",
719 | " [昆山, 大润发, 在, 哪里] | \n",
720 | "
\n",
721 | " \n",
722 | " | 7 | \n",
723 | " 红烧肉怎么做?嗯? | \n",
724 | " cookbook | \n",
725 | " [红烧肉, 怎么, 做, ?, 嗯, ?] | \n",
726 | "
\n",
727 | " \n",
728 | " | 8 | \n",
729 | " 南京到厦门的火车票 | \n",
730 | " train | \n",
731 | " [南京, 到, 厦门, 的, 火车票] | \n",
732 | "
\n",
733 | " \n",
734 | " | 9 | \n",
735 | " 6的平方 | \n",
736 | " calc | \n",
737 | " [6, 的, 平方] | \n",
738 | "
\n",
739 | " \n",
740 | "
\n",
741 | "
"
742 | ],
743 | "text/plain": [
744 | " query label cut_query\n",
745 | "0 今天东莞天气如何 weather [今天, 东莞, 天气, 如何]\n",
746 | "1 从观音桥到重庆市图书馆怎么走 map [从, 观音桥, 到, 重庆市, 图书馆, 怎么, 走]\n",
747 | "2 鸭蛋怎么腌? cookbook [鸭蛋, 怎么, 腌, ?]\n",
748 | "3 怎么治疗牛皮癣 health [怎么, 治疗, 牛皮癣]\n",
749 | "4 唠什么 chat [唠, 什么]\n",
750 | "5 阳澄湖大闸蟹的做法。 cookbook [阳澄湖, 大闸蟹, 的, 做法, 。]\n",
751 | "6 昆山大润发在哪里 map [昆山, 大润发, 在, 哪里]\n",
752 | "7 红烧肉怎么做?嗯? cookbook [红烧肉, 怎么, 做, ?, 嗯, ?]\n",
753 | "8 南京到厦门的火车票 train [南京, 到, 厦门, 的, 火车票]\n",
754 | "9 6的平方 calc [6, 的, 平方]"
755 | ]
756 | },
757 | "execution_count": 15,
758 | "metadata": {},
759 | "output_type": "execute_result"
760 | }
761 | ],
762 | "source": [
763 | "train_data_df.head(10)"
764 | ]
765 | },
766 | {
767 | "cell_type": "code",
768 | "execution_count": 16,
769 | "metadata": {},
770 | "outputs": [],
771 | "source": [
772 | "# 获取数据的所有词汇\n",
773 | "def get_all_vocab_from_data(data, colunm_name):\n",
774 | " train_vocab_list = []\n",
775 | " max_cut_query_lenth = 0\n",
776 | " for cut_query in data[colunm_name]:\n",
777 | " if len(cut_query) > max_cut_query_lenth:\n",
778 | " max_cut_query_lenth = len(cut_query)\n",
779 | " train_vocab_list += cut_query\n",
780 | " return train_vocab_list, max_cut_query_lenth "
781 | ]
782 | },
783 | {
784 | "cell_type": "code",
785 | "execution_count": 17,
786 | "metadata": {},
787 | "outputs": [],
788 | "source": [
789 | "train_vocab_list, max_cut_query_lenth = get_all_vocab_from_data(train_data_df, 'cut_query')"
790 | ]
791 | },
792 | {
793 | "cell_type": "code",
794 | "execution_count": 18,
795 | "metadata": {},
796 | "outputs": [
797 | {
798 | "name": "stdout",
799 | "output_type": "stream",
800 | "text": [
801 | "Number of words:\t 11498\n"
802 | ]
803 | }
804 | ],
805 | "source": [
806 | "print('Number of words:\\t', len(train_vocab_list))"
807 | ]
808 | },
809 | {
810 | "cell_type": "code",
811 | "execution_count": 19,
812 | "metadata": {},
813 | "outputs": [
814 | {
815 | "name": "stdout",
816 | "output_type": "stream",
817 | "text": [
818 | "max_cut_query_lenth:\t 26\n"
819 | ]
820 | }
821 | ],
822 | "source": [
823 | "print('max_cut_query_lenth:\\t', max_cut_query_lenth)"
824 | ]
825 | },
826 | {
827 | "cell_type": "code",
828 | "execution_count": 20,
829 | "metadata": {},
830 | "outputs": [],
831 | "source": [
832 | "test_vocab_list, test_max_cut_query_lenth = get_all_vocab_from_data(train_data_df, 'cut_query')"
833 | ]
834 | },
835 | {
836 | "cell_type": "code",
837 | "execution_count": 21,
838 | "metadata": {},
839 | "outputs": [
840 | {
841 | "name": "stdout",
842 | "output_type": "stream",
843 | "text": [
844 | "test_max_cut_query_lenth:\t 26\n"
845 | ]
846 | }
847 | ],
848 | "source": [
849 | "print('test_max_cut_query_lenth:\\t', test_max_cut_query_lenth)"
850 | ]
851 | },
852 | {
853 | "cell_type": "code",
854 | "execution_count": 22,
855 | "metadata": {},
856 | "outputs": [
857 | {
858 | "data": {
859 | "text/plain": [
860 | "['今天', '东莞', '天气', '如何', '从', '观音桥', '到', '重庆市', '图书馆', '怎么']"
861 | ]
862 | },
863 | "execution_count": 22,
864 | "metadata": {},
865 | "output_type": "execute_result"
866 | }
867 | ],
868 | "source": [
869 | "train_vocab_list[:10]"
870 | ]
871 | },
872 | {
873 | "cell_type": "code",
874 | "execution_count": 23,
875 | "metadata": {},
876 | "outputs": [],
877 | "source": [
878 | "train_vocab_counter = collections.Counter(train_vocab_list)"
879 | ]
880 | },
881 | {
882 | "cell_type": "code",
883 | "execution_count": 24,
884 | "metadata": {},
885 | "outputs": [
886 | {
887 | "name": "stdout",
888 | "output_type": "stream",
889 | "text": [
890 | "Number of different words:\t 2887\n"
891 | ]
892 | }
893 | ],
894 | "source": [
895 | "print('Number of different words:\\t', len(train_vocab_counter.keys()))"
896 | ]
897 | },
898 | {
899 | "cell_type": "markdown",
900 | "metadata": {},
901 | "source": [
902 | "## 不同种类的词汇个数,预留一个位置给不存在的词汇(不存在的词汇标记为0) "
903 | ]
904 | },
905 | {
906 | "cell_type": "code",
907 | "execution_count": 26,
908 | "metadata": {},
909 | "outputs": [],
910 | "source": [
911 | "max_features = len(train_vocab_counter.keys()) + 1"
912 | ]
913 | },
914 | {
915 | "cell_type": "code",
916 | "execution_count": 27,
917 | "metadata": {},
918 | "outputs": [
919 | {
920 | "name": "stdout",
921 | "output_type": "stream",
922 | "text": [
923 | "2888\n"
924 | ]
925 | }
926 | ],
927 | "source": [
928 | "print(max_features)"
929 | ]
930 | },
931 | {
932 | "cell_type": "code",
933 | "execution_count": 28,
934 | "metadata": {},
935 | "outputs": [
936 | {
937 | "data": {
938 | "text/plain": [
939 | "[('的', 605),\n",
940 | " ('。', 341),\n",
941 | " ('我', 320),\n",
942 | " ('你', 297),\n",
943 | " ('怎么', 273),\n",
944 | " ('?', 251),\n",
945 | " ('什么', 210),\n",
946 | " ('到', 165),\n",
947 | " ('给', 154),\n",
948 | " ('做', 148)]"
949 | ]
950 | },
951 | "execution_count": 28,
952 | "metadata": {},
953 | "output_type": "execute_result"
954 | }
955 | ],
956 | "source": [
957 | "# 10 words with the highest frequency\n",
958 | "train_vocab_counter.most_common(10)"
959 | ]
960 | },
961 | {
962 | "cell_type": "markdown",
963 | "metadata": {},
964 | "source": [
965 | "## 统计低频词语"
966 | ]
967 | },
968 | {
969 | "cell_type": "code",
970 | "execution_count": 29,
971 | "metadata": {},
972 | "outputs": [
973 | {
974 | "name": "stdout",
975 | "output_type": "stream",
976 | "text": [
977 | "word_times_zero:\t 1978\n",
978 | "word_times_zero/all:\t 0.685140284031867\n"
979 | ]
980 | }
981 | ],
982 | "source": [
983 | "word_times_zero = 0\n",
984 | "for word, word_times in train_vocab_counter.items():\n",
985 | " if word_times <=1:\n",
986 | " word_times_zero+=1\n",
987 | "print('word_times_zero:\\t', word_times_zero)\n",
988 | "print('word_times_zero/all:\\t', word_times_zero/len(train_vocab_counter))"
989 | ]
990 | },
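{
"cell_type": "markdown",
"metadata": {},
"source": [
"About 68.5% of the distinct words occur only once in the training data. This notebook keeps them all, but a common alternative is to prune rare words before building the dictionaries. A minimal sketch of that option (`min_count` and `frequent_vocab` are illustrative names, not used elsewhere in this notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only -- the notebook itself keeps every word\n",
"min_count = 2\n",
"frequent_vocab = {w for w, c in train_vocab_counter.items() if c >= min_count}\n",
"print('vocabulary size after pruning:', len(frequent_vocab))"
]
},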
991 | {
992 | "cell_type": "markdown",
993 | "metadata": {},
994 | "source": [
995 | "## 制作词汇字典"
996 | ]
997 | },
998 | {
999 | "cell_type": "code",
1000 | "execution_count": 30,
1001 | "metadata": {},
1002 | "outputs": [],
1003 | "source": [
1004 | "def create_train_vocab_dict(train_vocab_counter):\n",
1005 | " word_2_index, index_2_word = {}, {}\n",
1006 | " # Reserve 0 for masking via pad_sequences\n",
1007 | " index_number = 1\n",
1008 | " for word, word_times in train_vocab_counter.most_common():\n",
1009 | " word_2_index[word] = index_number\n",
1010 | " index_2_word[index_number] = word\n",
1011 | " index_number += 1\n",
1012 | " return word_2_index, index_2_word "
1013 | ]
1014 | },
1015 | {
1016 | "cell_type": "code",
1017 | "execution_count": 31,
1018 | "metadata": {},
1019 | "outputs": [],
1020 | "source": [
1021 | "word_2_index_dict, index_2_word_dict = create_train_vocab_dict(train_vocab_counter)"
1022 | ]
1023 | },
1024 | {
1025 | "cell_type": "code",
1026 | "execution_count": 32,
1027 | "metadata": {},
1028 | "outputs": [
1029 | {
1030 | "name": "stdout",
1031 | "output_type": "stream",
1032 | "text": [
1033 | "1 2\n"
1034 | ]
1035 | }
1036 | ],
1037 | "source": [
1038 | "print(word_2_index_dict['的'], word_2_index_dict['。'])"
1039 | ]
1040 | },
1041 | {
1042 | "cell_type": "code",
1043 | "execution_count": 33,
1044 | "metadata": {},
1045 | "outputs": [
1046 | {
1047 | "name": "stdout",
1048 | "output_type": "stream",
1049 | "text": [
1050 | "的 。\n"
1051 | ]
1052 | }
1053 | ],
1054 | "source": [
1055 | "print(index_2_word_dict[1], index_2_word_dict[2])"
1056 | ]
1057 | },
1058 | {
1059 | "cell_type": "code",
1060 | "execution_count": 34,
1061 | "metadata": {},
1062 | "outputs": [
1063 | {
1064 | "name": "stdout",
1065 | "output_type": "stream",
1066 | "text": [
1067 | "今天东莞天气如何 weather ['今天', '东莞', '天气', '如何']\n",
1068 | "从观音桥到重庆市图书馆怎么走 map ['从', '观音桥', '到', '重庆市', '图书馆', '怎么', '走']\n",
1069 | "鸭蛋怎么腌? cookbook ['鸭蛋', '怎么', '腌', '?']\n",
1070 | "怎么治疗牛皮癣 health ['怎么', '治疗', '牛皮癣']\n",
1071 | "唠什么 chat ['唠', '什么']\n",
1072 | "阳澄湖大闸蟹的做法。 cookbook ['阳澄湖', '大闸蟹', '的', '做法', '。']\n",
1073 | "昆山大润发在哪里 map ['昆山', '大润发', '在', '哪里']\n",
1074 | "红烧肉怎么做?嗯? cookbook ['红烧肉', '怎么', '做', '?', '嗯', '?']\n",
1075 | "南京到厦门的火车票 train ['南京', '到', '厦门', '的', '火车票']\n",
1076 | "6的平方 calc ['6', '的', '平方']\n"
1077 | ]
1078 | }
1079 | ],
1080 | "source": [
1081 | "pq= 0\n",
1082 | "for index, row in train_data_df.iterrows():\n",
1083 | " print(row[0], row[1], row[2])\n",
1084 | " pq+=1\n",
1085 | " if pq==10:\n",
1086 | " break"
1087 | ]
1088 | },
1089 | {
1090 | "cell_type": "code",
1091 | "execution_count": 35,
1092 | "metadata": {},
1093 | "outputs": [
1094 | {
1095 | "data": {
1096 | "text/plain": [
1097 | "0"
1098 | ]
1099 | },
1100 | "execution_count": 35,
1101 | "metadata": {},
1102 | "output_type": "execute_result"
1103 | }
1104 | ],
1105 | "source": [
1106 | "word_2_index_dict.get('的2', 0)"
1107 | ]
1108 | },
1109 | {
1110 | "cell_type": "code",
1111 | "execution_count": 36,
1112 | "metadata": {},
1113 | "outputs": [],
1114 | "source": [
1115 | "def vectorize_data(data, label_2_index_dict, word_2_index_dict, max_cut_query_lenth):\n",
1116 | " x_train = []\n",
1117 | " y_train = []\n",
1118 | " for index, row in data.iterrows():\n",
1119 | " query_sentence = row[2]\n",
1120 | " label = row[1]\n",
1121 | " # 字典找不到的情况下用 0 填充\n",
1122 | " x = [word_2_index_dict.get(w, 0) for w in query_sentence]\n",
1123 | " y = [label_2_index_dict[label]]\n",
1124 | " x_train.append(x)\n",
1125 | " y_train.append(y)\n",
1126 | " return (pad_sequences(x_train, maxlen=max_cut_query_lenth),\n",
1127 | " pad_sequences(y_train, maxlen=1))"
1128 | ]
1129 | },
1130 | {
1131 | "cell_type": "code",
1132 | "execution_count": 37,
1133 | "metadata": {},
1134 | "outputs": [],
1135 | "source": [
1136 | "x_train, y_train = vectorize_data(train_data_df, label_2_index_dict, word_2_index_dict, max_cut_query_lenth)"
1137 | ]
1138 | },
1139 | {
1140 | "cell_type": "code",
1141 | "execution_count": 38,
1142 | "metadata": {},
1143 | "outputs": [],
1144 | "source": [
1145 | "x_test, y_test = vectorize_data(test_data_df, label_2_index_dict, word_2_index_dict, test_max_cut_query_lenth)"
1146 | ]
1147 | },
1148 | {
1149 | "cell_type": "code",
1150 | "execution_count": 39,
1151 | "metadata": {},
1152 | "outputs": [
1153 | {
1154 | "name": "stdout",
1155 | "output_type": "stream",
1156 | "text": [
1157 | "[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
1158 | " 0 0 0 0 33 318 27 90] [7]\n"
1159 | ]
1160 | }
1161 | ],
1162 | "source": [
1163 | "print(x_train[0], y_train[0])"
1164 | ]
1165 | },
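{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a check, the padded vector can be decoded back into words with `index_2_word_dict`. A minimal sketch; note that index 0 does double duty here as both the padding value and the out-of-vocabulary marker, and `'<UNK>'` is just an illustrative placeholder:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Decode the first training example, skipping the padding zeros\n",
"decoded = [index_2_word_dict.get(int(i), '<UNK>') for i in x_train[0] if i != 0]\n",
"print(decoded, '->', index_2_label_dict[int(y_train[0][0])])"
]
},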
1166 | {
1167 | "cell_type": "code",
1168 | "execution_count": 40,
1169 | "metadata": {},
1170 | "outputs": [],
1171 | "source": [
1172 | "y_train = to_categorical(y_train, label_numbers)\n",
1173 | "y_test = to_categorical(y_test, label_numbers)"
1174 | ]
1175 | },
1176 | {
1177 | "cell_type": "code",
1178 | "execution_count": 41,
1179 | "metadata": {},
1180 | "outputs": [
1181 | {
1182 | "name": "stdout",
1183 | "output_type": "stream",
1184 | "text": [
1185 | "(2299, 26) (2299, 31)\n"
1186 | ]
1187 | }
1188 | ],
1189 | "source": [
1190 | "print(x_train.shape, y_train.shape)"
1191 | ]
1192 | },
1193 | {
1194 | "cell_type": "code",
1195 | "execution_count": 42,
1196 | "metadata": {},
1197 | "outputs": [
1198 | {
1199 | "name": "stdout",
1200 | "output_type": "stream",
1201 | "text": [
1202 | "(770, 26) (770, 31)\n"
1203 | ]
1204 | }
1205 | ],
1206 | "source": [
1207 | "print(x_test.shape, y_test.shape)"
1208 | ]
1209 | },
1210 | {
1211 | "cell_type": "markdown",
1212 | "metadata": {},
1213 | "source": [
1214 | "# 存储预处理过的数据"
1215 | ]
1216 | },
1217 | {
1218 | "cell_type": "code",
1219 | "execution_count": 43,
1220 | "metadata": {},
1221 | "outputs": [
1222 | {
1223 | "name": "stdout",
1224 | "output_type": "stream",
1225 | "text": [
1226 | "\n"
1227 | ]
1228 | }
1229 | ],
1230 | "source": [
1231 | "print(type(x_test))"
1232 | ]
1233 | },
1234 | {
1235 | "cell_type": "code",
1236 | "execution_count": 44,
1237 | "metadata": {},
1238 | "outputs": [],
1239 | "source": [
1240 | "np.savez(\"preprocessed_data\", x_train, y_train, x_test, y_test)"
1241 | ]
1242 | },
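{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `np.savez` is called with positional arguments, the arrays are stored under the generic keys `arr_0` ... `arr_3`, which is exactly how the loading cell below refers to them. A quick way to confirm the keys:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Positional savez yields generic keys; keyword arguments\n",
"# (e.g. np.savez(..., x_train=x_train)) would store named keys instead\n",
"with np.load('preprocessed_data.npz') as f:\n",
"    print(f.files)  # ['arr_0', 'arr_1', 'arr_2', 'arr_3']"
]
},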
1243 | {
1244 | "cell_type": "markdown",
1245 | "metadata": {},
1246 | "source": [
1247 | "## 直接加载预处理的数据"
1248 | ]
1249 | },
1250 | {
1251 | "cell_type": "code",
1252 | "execution_count": 4,
1253 | "metadata": {},
1254 | "outputs": [],
1255 | "source": [
1256 | "# 使用已经经过预处理的数据,默认不使用\n",
1257 | "use_preprocessed_data = True\n",
1258 | "\n",
1259 | "if use_preprocessed_data == True:\n",
1260 | " preprocessed_data = np.load('preprocessed_data.npz')\n",
1261 | " x_train, y_train, x_test, y_test = preprocessed_data['arr_0'], preprocessed_data['arr_1'], preprocessed_data['arr_2'], preprocessed_data['arr_3'],"
1262 | ]
1263 | },
1264 | {
1265 | "cell_type": "code",
1266 | "execution_count": 5,
1267 | "metadata": {},
1268 | "outputs": [
1269 | {
1270 | "name": "stdout",
1271 | "output_type": "stream",
1272 | "text": [
1273 | "(2299, 26) (2299, 31)\n"
1274 | ]
1275 | }
1276 | ],
1277 | "source": [
1278 | "print(x_train.shape, y_train.shape)"
1279 | ]
1280 | },
1281 | {
1282 | "cell_type": "markdown",
1283 | "metadata": {},
1284 | "source": [
1285 | "# 设计模型"
1286 | ]
1287 | },
1288 | {
1289 | "cell_type": "code",
1290 | "execution_count": 6,
1291 | "metadata": {},
1292 | "outputs": [],
1293 | "source": [
1294 | "def create_SMP2018_lstm_model(max_features, max_cut_query_lenth, label_numbers):\n",
1295 | " model = Sequential()\n",
1296 | " model.add(Embedding(input_dim=max_features, output_dim=32, input_length=max_cut_query_lenth))\n",
1297 | " model.add(LSTM(units=64, dropout=0.2, recurrent_dropout=0.2))\n",
1298 | " model.add(Dense(label_numbers, activation='softmax'))\n",
1299 | " # try using different optimizers and different optimizer configs\n",
1300 | " model.compile(loss='categorical_crossentropy',\n",
1301 | " optimizer='adam',\n",
1302 | " metrics=[f1])\n",
1303 | "\n",
1304 | " plot_model(model, to_file='SMP2018_lstm_model.png', show_shapes=True)\n",
1305 | " \n",
1306 | " return model"
1307 | ]
1308 | },
1309 | {
1310 | "cell_type": "markdown",
1311 | "metadata": {},
1312 | "source": [
1313 | "# 训练模型"
1314 | ]
1315 | },
1316 | {
1317 | "cell_type": "code",
1318 | "execution_count": 7,
1319 | "metadata": {},
1320 | "outputs": [
1321 | {
1322 | "name": "stdout",
1323 | "output_type": "stream",
1324 | "text": [
1325 | "not find max_features variable, use default max_features values:\t2888\n",
1326 | "not find max_cut_query_lenth, use default max_features values:\t26\n",
1327 | "not find label_numbers, use default max_features values:\t31\n"
1328 | ]
1329 | }
1330 | ],
1331 | "source": [
1332 | "if 'max_features' not in dir():\n",
1333 | " max_features = 2888\n",
1334 | " print('not find max_features variable, use default max_features values:\\t{}'.format(max_features))\n",
1335 | "if 'max_cut_query_lenth' not in dir():\n",
1336 | " max_cut_query_lenth = 26\n",
1337 | " print('not find max_cut_query_lenth, use default max_features values:\\t{}'.format(max_cut_query_lenth))\n",
1338 | "if 'label_numbers' not in dir():\n",
1339 | " label_numbers = 31\n",
1340 | " print('not find label_numbers, use default max_features values:\\t{}'.format(label_numbers))"
1341 | ]
1342 | },
1343 | {
1344 | "cell_type": "code",
1345 | "execution_count": 8,
1346 | "metadata": {},
1347 | "outputs": [],
1348 | "source": [
1349 | "model = create_SMP2018_lstm_model(max_features, max_cut_query_lenth, label_numbers)"
1350 | ]
1351 | },
1352 | {
1353 | "cell_type": "code",
1354 | "execution_count": 9,
1355 | "metadata": {},
1356 | "outputs": [],
1357 | "source": [
1358 | "batch_size = 20\n",
1359 | "epochs = 300"
1360 | ]
1361 | },
1362 | {
1363 | "cell_type": "code",
1364 | "execution_count": null,
1365 | "metadata": {},
1366 | "outputs": [
1367 | {
1368 | "name": "stdout",
1369 | "output_type": "stream",
1370 | "text": [
1371 | "Train...\n",
1372 | "Train on 1839 samples, validate on 460 samples\n",
1373 | "Epoch 1/300\n",
1374 | "1839/1839 [==============================] - 6s 3ms/step - loss: 3.1404 - f1: 0.0000e+00 - val_loss: 2.9658 - val_f1: 0.0000e+00\n",
1375 | "Epoch 2/300\n",
1376 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.8634 - f1: 0.0404 - val_loss: 2.5949 - val_f1: 0.1618\n",
1377 | "Epoch 3/300\n",
1378 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.3841 - f1: 0.3354 - val_loss: 2.1469 - val_f1: 0.4080\n",
1379 | "Epoch 4/300\n",
1380 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.9530 - f1: 0.4240 - val_loss: 1.8311 - val_f1: 0.4429\n",
1381 | "Epoch 5/300\n",
1382 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.5153 - f1: 0.5092 - val_loss: 1.4660 - val_f1: 0.5133\n",
1383 | "Epoch 6/300\n",
1384 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.1055 - f1: 0.6257 - val_loss: 1.2311 - val_f1: 0.6446\n",
1385 | "Epoch 7/300\n",
1386 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.7985 - f1: 0.7558 - val_loss: 1.0519 - val_f1: 0.6857\n",
1387 | "Epoch 8/300\n",
1388 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.5887 - f1: 0.8245 - val_loss: 0.9113 - val_f1: 0.7443\n",
1389 | "Epoch 9/300\n",
1390 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.4365 - f1: 0.8729 - val_loss: 0.8589 - val_f1: 0.7589\n",
1391 | "Epoch 10/300\n",
1392 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.3196 - f1: 0.9178 - val_loss: 0.8198 - val_f1: 0.7948\n",
1393 | "Epoch 11/300\n",
1394 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.2584 - f1: 0.9379 - val_loss: 0.7777 - val_f1: 0.8046\n",
1395 | "Epoch 12/300\n",
1396 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.1941 - f1: 0.9593 - val_loss: 0.7518 - val_f1: 0.8343\n",
1397 | "Epoch 13/300\n",
1398 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.1540 - f1: 0.9726 - val_loss: 0.7506 - val_f1: 0.8322\n",
1399 | "Epoch 14/300\n",
1400 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.1220 - f1: 0.9808 - val_loss: 0.7529 - val_f1: 0.8195\n",
1401 | "Epoch 15/300\n",
1402 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.1044 - f1: 0.9837 - val_loss: 0.7723 - val_f1: 0.8226\n",
1403 | "Epoch 16/300\n",
1404 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0884 - f1: 0.9868 - val_loss: 0.7465 - val_f1: 0.8326\n",
1405 | "Epoch 17/300\n",
1406 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0722 - f1: 0.9900 - val_loss: 0.7687 - val_f1: 0.8240\n",
1407 | "Epoch 18/300\n",
1408 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0594 - f1: 0.9941 - val_loss: 0.7584 - val_f1: 0.8252\n",
1409 | "Epoch 19/300\n",
1410 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0497 - f1: 0.9947 - val_loss: 0.7572 - val_f1: 0.8302\n",
1411 | "Epoch 20/300\n",
1412 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0437 - f1: 0.9958 - val_loss: 0.7714 - val_f1: 0.8260\n",
1413 | "Epoch 21/300\n",
1414 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0398 - f1: 0.9972 - val_loss: 0.7631 - val_f1: 0.8246\n",
1415 | "Epoch 22/300\n",
1416 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0321 - f1: 0.9973 - val_loss: 0.7698 - val_f1: 0.8276\n",
1417 | "Epoch 23/300\n",
1418 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0313 - f1: 0.9967 - val_loss: 0.7809 - val_f1: 0.8288\n",
1419 | "Epoch 24/300\n",
1420 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0272 - f1: 0.9989 - val_loss: 0.7797 - val_f1: 0.8218\n",
1421 | "Epoch 25/300\n",
1422 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0240 - f1: 0.9986 - val_loss: 0.7531 - val_f1: 0.8345\n",
1423 | "Epoch 26/300\n",
1424 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0236 - f1: 0.9978 - val_loss: 0.7988 - val_f1: 0.8201\n",
1425 | "Epoch 27/300\n",
1426 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0207 - f1: 0.9986 - val_loss: 0.8156 - val_f1: 0.8259\n",
1427 | "Epoch 28/300\n",
1428 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0185 - f1: 0.9995 - val_loss: 0.7938 - val_f1: 0.8185\n",
1429 | "Epoch 29/300\n",
1430 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0189 - f1: 0.9981 - val_loss: 0.7839 - val_f1: 0.8169\n",
1431 | "Epoch 30/300\n",
1432 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0145 - f1: 1.0000 - val_loss: 0.8001 - val_f1: 0.8296\n",
1433 | "Epoch 31/300\n",
1434 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0163 - f1: 0.9984 - val_loss: 0.8265 - val_f1: 0.8116\n",
1435 | "Epoch 32/300\n",
1436 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0167 - f1: 0.9970 - val_loss: 0.8117 - val_f1: 0.8320\n",
1437 | "Epoch 33/300\n",
1438 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0131 - f1: 0.9997 - val_loss: 0.8121 - val_f1: 0.8224\n",
1439 | "Epoch 34/300\n",
1440 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0098 - f1: 1.0000 - val_loss: 0.8158 - val_f1: 0.8277\n",
1441 | "Epoch 35/300\n",
1442 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0133 - f1: 0.9986 - val_loss: 0.8314 - val_f1: 0.8242\n",
1443 | "Epoch 36/300\n",
1444 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0099 - f1: 1.0000 - val_loss: 0.8447 - val_f1: 0.8231\n",
1445 | "Epoch 37/300\n",
1446 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0081 - f1: 1.0000 - val_loss: 0.8237 - val_f1: 0.8312\n",
1447 | "Epoch 38/300\n",
1448 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0117 - f1: 0.9986 - val_loss: 0.8239 - val_f1: 0.8155\n",
1449 | "Epoch 39/300\n",
1450 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0082 - f1: 1.0000 - val_loss: 0.8470 - val_f1: 0.8204\n",
1451 | "Epoch 40/300\n",
1452 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0072 - f1: 1.0000 - val_loss: 0.8471 - val_f1: 0.8262\n",
1453 | "Epoch 41/300\n",
1454 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0098 - f1: 0.9992 - val_loss: 0.8262 - val_f1: 0.8323\n",
1455 | "Epoch 42/300\n",
1456 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0110 - f1: 0.9978 - val_loss: 0.8577 - val_f1: 0.8205\n",
1457 | "Epoch 43/300\n",
1458 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0069 - f1: 0.9997 - val_loss: 0.8587 - val_f1: 0.8226\n",
1459 | "Epoch 44/300\n",
1460 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0059 - f1: 0.9995 - val_loss: 0.8217 - val_f1: 0.8253\n",
1461 | "Epoch 45/300\n",
1462 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0054 - f1: 0.9995 - val_loss: 0.8342 - val_f1: 0.8269\n",
1463 | "Epoch 46/300\n",
1464 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0044 - f1: 1.0000 - val_loss: 0.8494 - val_f1: 0.8310\n",
1465 | "Epoch 47/300\n",
1466 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0040 - f1: 1.0000 - val_loss: 0.8496 - val_f1: 0.8341\n",
1467 | "Epoch 48/300\n",
1468 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0044 - f1: 1.0000 - val_loss: 0.8640 - val_f1: 0.8269\n",
1469 | "Epoch 49/300\n",
1470 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0049 - f1: 1.0000 - val_loss: 0.8453 - val_f1: 0.8321\n",
1471 | "Epoch 50/300\n",
1472 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0033 - f1: 1.0000 - val_loss: 0.9022 - val_f1: 0.8279\n",
1473 | "Epoch 51/300\n",
1474 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0074 - f1: 0.9989 - val_loss: 0.9145 - val_f1: 0.8171\n",
1475 | "Epoch 52/300\n",
1476 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0050 - f1: 0.9995 - val_loss: 0.9031 - val_f1: 0.8159\n",
1477 | "Epoch 53/300\n",
1478 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0083 - f1: 0.9981 - val_loss: 0.9456 - val_f1: 0.8084\n",
1479 | "Epoch 54/300\n",
1480 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0057 - f1: 0.9995 - val_loss: 0.8972 - val_f1: 0.8265\n",
1481 | "Epoch 55/300\n",
1482 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0027 - f1: 1.0000 - val_loss: 0.8921 - val_f1: 0.8284\n",
1483 | "Epoch 56/300\n",
1484 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0027 - f1: 1.0000 - val_loss: 0.8885 - val_f1: 0.8318\n",
1485 | "Epoch 57/300\n",
1486 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0024 - f1: 1.0000 - val_loss: 0.9059 - val_f1: 0.8283\n",
1487 | "Epoch 58/300\n",
1488 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0028 - f1: 1.0000 - val_loss: 0.9045 - val_f1: 0.8233\n",
1489 | "Epoch 59/300\n",
1490 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0022 - f1: 1.0000 - val_loss: 0.9238 - val_f1: 0.8302\n",
1491 | "Epoch 60/300\n",
1492 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0024 - f1: 0.9997 - val_loss: 0.9383 - val_f1: 0.8209\n",
1493 | "Epoch 61/300\n",
1494 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0030 - f1: 1.0000 - val_loss: 0.9409 - val_f1: 0.8157\n",
1495 | "Epoch 62/300\n",
1496 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0023 - f1: 1.0000 - val_loss: 0.9529 - val_f1: 0.8255\n",
1497 | "Epoch 63/300\n",
1498 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0046 - f1: 0.9997 - val_loss: 0.9899 - val_f1: 0.8158\n",
1499 | "Epoch 64/300\n",
1500 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0040 - f1: 0.9995 - val_loss: 0.9625 - val_f1: 0.8138\n",
1501 | "Epoch 65/300\n",
1502 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0039 - f1: 0.9997 - val_loss: 0.9493 - val_f1: 0.8135\n",
1503 | "Epoch 66/300\n",
1504 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0024 - f1: 1.0000 - val_loss: 0.9872 - val_f1: 0.8151\n",
1505 | "Epoch 67/300\n",
1506 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0058 - f1: 0.9989 - val_loss: 0.9106 - val_f1: 0.8146\n",
1507 | "Epoch 68/300\n",
1508 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0045 - f1: 0.9997 - val_loss: 0.9383 - val_f1: 0.8191\n",
1509 | "Epoch 69/300\n",
1510 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0018 - f1: 1.0000 - val_loss: 0.9366 - val_f1: 0.8184\n",
1511 | "Epoch 70/300\n",
1512 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0022 - f1: 1.0000 - val_loss: 1.0150 - val_f1: 0.8079\n",
1513 | "Epoch 71/300\n",
1514 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0018 - f1: 1.0000 - val_loss: 0.9735 - val_f1: 0.8136\n",
1515 | "Epoch 72/300\n",
1516 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0028 - f1: 1.0000 - val_loss: 0.9194 - val_f1: 0.8333\n",
1517 | "Epoch 73/300\n",
1518 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0016 - f1: 1.0000 - val_loss: 0.9224 - val_f1: 0.8321\n",
1519 | "Epoch 74/300\n",
1520 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 0.9445 - val_f1: 0.8249\n",
1521 | "Epoch 75/300\n",
1522 | "1839/1839 [==============================] - 6s 3ms/step - loss: 9.3293e-04 - f1: 1.0000 - val_loss: 0.9333 - val_f1: 0.8304\n",
1523 | "Epoch 76/300\n",
1524 | "1839/1839 [==============================] - 6s 3ms/step - loss: 9.6054e-04 - f1: 1.0000 - val_loss: 0.9360 - val_f1: 0.8264\n",
1525 | "Epoch 77/300\n",
1526 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0034 - f1: 0.9992 - val_loss: 0.9103 - val_f1: 0.8299\n",
1527 | "Epoch 78/300\n",
1528 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0018 - f1: 1.0000 - val_loss: 0.9219 - val_f1: 0.8332\n",
1529 | "Epoch 79/300\n",
1530 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 0.9166 - val_f1: 0.8372\n",
1531 | "Epoch 80/300\n",
1532 | "1839/1839 [==============================] - 5s 3ms/step - loss: 8.4605e-04 - f1: 1.0000 - val_loss: 0.9176 - val_f1: 0.8362\n",
1533 | "Epoch 81/300\n",
1534 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0015 - f1: 0.9995 - val_loss: 0.9786 - val_f1: 0.8179\n",
1535 | "Epoch 82/300\n",
1536 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 0.9525 - val_f1: 0.8336\n",
1537 | "Epoch 83/300\n",
1538 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 0.9509 - val_f1: 0.8278\n",
1539 | "Epoch 84/300\n",
1540 | "1839/1839 [==============================] - 5s 3ms/step - loss: 6.4968e-04 - f1: 1.0000 - val_loss: 0.9370 - val_f1: 0.8280\n",
1541 | "Epoch 85/300\n",
1542 | "1839/1839 [==============================] - 5s 3ms/step - loss: 8.7447e-04 - f1: 1.0000 - val_loss: 0.9853 - val_f1: 0.8197\n",
1543 | "Epoch 86/300\n",
1544 | "1839/1839 [==============================] - 5s 3ms/step - loss: 9.7422e-04 - f1: 1.0000 - val_loss: 0.9614 - val_f1: 0.8287\n",
1545 | "Epoch 87/300\n",
1546 | "1839/1839 [==============================] - 5s 3ms/step - loss: 6.7887e-04 - f1: 1.0000 - val_loss: 0.9664 - val_f1: 0.8255\n",
1547 | "Epoch 88/300\n",
1548 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.4960e-04 - f1: 1.0000 - val_loss: 0.9559 - val_f1: 0.8309\n",
1549 | "Epoch 89/300\n",
1550 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.1952e-04 - f1: 1.0000 - val_loss: 0.9667 - val_f1: 0.8296\n",
1551 | "Epoch 90/300\n",
1552 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.7391e-04 - f1: 1.0000 - val_loss: 0.9579 - val_f1: 0.8348\n",
1553 | "Epoch 91/300\n",
1554 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.7227e-04 - f1: 1.0000 - val_loss: 0.9605 - val_f1: 0.8326\n",
1555 | "Epoch 92/300\n",
1556 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.7311e-04 - f1: 1.0000 - val_loss: 0.9678 - val_f1: 0.8298\n",
1557 | "Epoch 93/300\n",
1558 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.2191e-04 - f1: 1.0000 - val_loss: 0.9664 - val_f1: 0.8352\n",
1559 | "Epoch 94/300\n",
1560 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.5526e-04 - f1: 1.0000 - val_loss: 0.9601 - val_f1: 0.8335\n",
1561 | "Epoch 95/300\n",
1562 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.0580e-04 - f1: 1.0000 - val_loss: 0.9861 - val_f1: 0.8351\n",
1563 | "Epoch 96/300\n",
1564 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 1.0249 - val_f1: 0.8131\n",
1565 | "Epoch 97/300\n",
1566 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.6099e-04 - f1: 1.0000 - val_loss: 1.0059 - val_f1: 0.8183\n",
1567 | "Epoch 98/300\n",
1568 | "1839/1839 [==============================] - 6s 3ms/step - loss: 8.8925e-04 - f1: 1.0000 - val_loss: 1.0204 - val_f1: 0.8181\n",
1569 | "Epoch 99/300\n",
1570 | "1839/1839 [==============================] - 6s 3ms/step - loss: 7.3515e-04 - f1: 1.0000 - val_loss: 1.0013 - val_f1: 0.8184\n",
1571 | "Epoch 100/300\n",
1572 | "1839/1839 [==============================] - 6s 3ms/step - loss: 8.3419e-04 - f1: 1.0000 - val_loss: 1.0341 - val_f1: 0.8164\n",
1573 | "Epoch 101/300\n",
1574 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0035 - f1: 0.9989 - val_loss: 0.9958 - val_f1: 0.8191\n",
1575 | "Epoch 102/300\n",
1576 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0053 - f1: 0.9989 - val_loss: 1.0122 - val_f1: 0.8265\n",
1577 | "Epoch 103/300\n",
1578 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0013 - f1: 1.0000 - val_loss: 1.0103 - val_f1: 0.8209\n",
1579 | "Epoch 104/300\n",
1580 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0053 - f1: 0.9989 - val_loss: 1.1036 - val_f1: 0.8036\n",
1581 | "Epoch 105/300\n",
1582 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0020 - f1: 1.0000 - val_loss: 1.0385 - val_f1: 0.8145\n",
1583 | "Epoch 106/300\n",
1584 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 1.0522 - val_f1: 0.8209\n",
1585 | "Epoch 107/300\n",
1586 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0010 - f1: 1.0000 - val_loss: 1.1168 - val_f1: 0.8123\n",
1587 | "Epoch 108/300\n",
1588 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0040 - f1: 0.9989 - val_loss: 1.1633 - val_f1: 0.8004\n",
1589 | "Epoch 109/300\n",
1590 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0034 - f1: 0.9995 - val_loss: 1.0392 - val_f1: 0.8165\n",
1591 | "Epoch 110/300\n",
1592 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0053 - f1: 0.9995 - val_loss: 0.9886 - val_f1: 0.8305\n",
1593 | "Epoch 111/300\n",
1594 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 1.0026 - val_f1: 0.8277\n",
1595 | "Epoch 112/300\n",
1596 | "1839/1839 [==============================] - 5s 3ms/step - loss: 9.9556e-04 - f1: 1.0000 - val_loss: 0.9939 - val_f1: 0.8292\n",
1597 | "Epoch 113/300\n",
1598 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0021 - f1: 0.9995 - val_loss: 1.0387 - val_f1: 0.8113\n",
1599 | "Epoch 114/300\n",
1600 | "1839/1839 [==============================] - 5s 3ms/step - loss: 8.0578e-04 - f1: 1.0000 - val_loss: 1.0518 - val_f1: 0.8205\n",
1601 | "Epoch 115/300\n",
1602 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.6815e-04 - f1: 1.0000 - val_loss: 1.0576 - val_f1: 0.8170\n",
1603 | "Epoch 116/300\n",
1604 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.7931e-04 - f1: 1.0000 - val_loss: 1.0544 - val_f1: 0.8152\n",
1605 | "Epoch 117/300\n",
1606 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.7737e-04 - f1: 1.0000 - val_loss: 1.0665 - val_f1: 0.8173\n",
1607 | "Epoch 118/300\n",
1608 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.3346e-04 - f1: 1.0000 - val_loss: 1.0667 - val_f1: 0.8168\n",
1609 | "Epoch 119/300\n",
1610 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.2447e-04 - f1: 1.0000 - val_loss: 1.0798 - val_f1: 0.8110\n",
1611 | "Epoch 120/300\n",
1612 | "1839/1839 [==============================] - 5s 3ms/step - loss: 6.6263e-04 - f1: 1.0000 - val_loss: 1.0843 - val_f1: 0.8114\n",
1613 | "Epoch 121/300\n",
1614 | "1839/1839 [==============================] - 5s 3ms/step - loss: 7.0956e-04 - f1: 1.0000 - val_loss: 1.0713 - val_f1: 0.8091\n",
1615 | "Epoch 122/300\n",
1616 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.3038e-04 - f1: 1.0000 - val_loss: 1.0521 - val_f1: 0.8218\n",
1617 | "Epoch 123/300\n",
1618 | "1839/1839 [==============================] - 5s 3ms/step - loss: 7.8382e-04 - f1: 1.0000 - val_loss: 1.0743 - val_f1: 0.8143\n",
1619 | "Epoch 124/300\n",
1620 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.2436e-04 - f1: 1.0000 - val_loss: 1.0851 - val_f1: 0.8174\n",
1621 | "Epoch 125/300\n",
1622 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.3625e-04 - f1: 1.0000 - val_loss: 1.1127 - val_f1: 0.8153\n",
1623 | "Epoch 126/300\n",
1624 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.4003e-04 - f1: 1.0000 - val_loss: 1.0980 - val_f1: 0.8134\n",
1625 | "Epoch 127/300\n",
1626 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.1112e-04 - f1: 1.0000 - val_loss: 1.0975 - val_f1: 0.8158\n",
1627 | "Epoch 128/300\n",
1628 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.7792e-04 - f1: 1.0000 - val_loss: 1.0916 - val_f1: 0.8216\n",
1629 | "Epoch 129/300\n",
1630 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.5881e-04 - f1: 1.0000 - val_loss: 1.0802 - val_f1: 0.8241\n",
1631 | "Epoch 130/300\n",
1632 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.0872e-04 - f1: 1.0000 - val_loss: 1.0879 - val_f1: 0.8220\n",
1633 | "Epoch 131/300\n",
1634 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.8239e-04 - f1: 1.0000 - val_loss: 1.0904 - val_f1: 0.8256\n",
1635 | "Epoch 132/300\n",
1636 | "1839/1839 [==============================] - 6s 3ms/step - loss: 2.1858e-04 - f1: 1.0000 - val_loss: 1.0941 - val_f1: 0.8200\n",
1637 | "Epoch 133/300\n",
1638 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.3054e-04 - f1: 1.0000 - val_loss: 1.0942 - val_f1: 0.8181\n",
1639 | "Epoch 134/300\n",
1640 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.2964e-04 - f1: 1.0000 - val_loss: 1.0947 - val_f1: 0.8209\n",
1641 | "Epoch 135/300\n",
1642 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.8267e-04 - f1: 1.0000 - val_loss: 1.0943 - val_f1: 0.8270\n",
1643 | "Epoch 136/300\n",
1644 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.2101e-04 - f1: 1.0000 - val_loss: 1.0975 - val_f1: 0.8256\n",
1645 | "Epoch 137/300\n",
1646 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.3244e-04 - f1: 1.0000 - val_loss: 1.0967 - val_f1: 0.8294\n",
1647 | "Epoch 138/300\n",
1648 | "1839/1839 [==============================] - 6s 3ms/step - loss: 9.5624e-05 - f1: 1.0000 - val_loss: 1.1052 - val_f1: 0.8288\n",
1649 | "Epoch 139/300\n",
1650 | "1839/1839 [==============================] - 6s 3ms/step - loss: 3.0982e-04 - f1: 1.0000 - val_loss: 1.1630 - val_f1: 0.8185\n",
1651 | "Epoch 140/300\n",
1652 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.4933e-04 - f1: 1.0000 - val_loss: 1.1253 - val_f1: 0.8189\n",
1653 | "Epoch 141/300\n",
1654 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.3635e-04 - f1: 1.0000 - val_loss: 1.1227 - val_f1: 0.8225\n",
1655 | "Epoch 142/300\n",
1656 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.1006e-04 - f1: 1.0000 - val_loss: 1.1184 - val_f1: 0.8251\n",
1657 | "Epoch 143/300\n",
1658 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.2128e-04 - f1: 1.0000 - val_loss: 1.1187 - val_f1: 0.8309\n",
1659 | "Epoch 144/300\n",
1660 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.1148e-04 - f1: 1.0000 - val_loss: 1.1111 - val_f1: 0.8275\n",
1661 | "Epoch 145/300\n",
1662 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.2230e-04 - f1: 1.0000 - val_loss: 1.1228 - val_f1: 0.8313\n",
1663 | "Epoch 146/300\n",
1664 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.0086e-04 - f1: 1.0000 - val_loss: 1.1273 - val_f1: 0.8293\n",
1665 | "Epoch 147/300\n",
1666 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.5006e-04 - f1: 1.0000 - val_loss: 1.1418 - val_f1: 0.8300\n",
1667 | "Epoch 148/300\n",
1668 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0024 - f1: 0.9995 - val_loss: 1.1586 - val_f1: 0.8162\n",
1669 | "Epoch 149/300\n",
1670 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.8215e-04 - f1: 1.0000 - val_loss: 1.0818 - val_f1: 0.8347\n",
1671 | "Epoch 150/300\n",
1672 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.4316e-04 - f1: 1.0000 - val_loss: 1.0817 - val_f1: 0.8312\n",
1673 | "Epoch 151/300\n",
1674 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.3920e-04 - f1: 1.0000 - val_loss: 1.0768 - val_f1: 0.8339\n",
1675 | "Epoch 152/300\n",
1676 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.1098e-04 - f1: 1.0000 - val_loss: 1.0783 - val_f1: 0.8360\n",
1677 | "Epoch 153/300\n",
1678 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.2897e-04 - f1: 1.0000 - val_loss: 1.0821 - val_f1: 0.8310\n",
1679 | "Epoch 154/300\n",
1680 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.6505e-04 - f1: 1.0000 - val_loss: 1.0777 - val_f1: 0.8308\n",
1681 | "Epoch 155/300\n",
1682 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0015 - f1: 0.9989 - val_loss: 1.1645 - val_f1: 0.8087\n",
1683 | "Epoch 156/300\n",
1684 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 1.0931 - val_f1: 0.8377\n",
1685 | "Epoch 157/300\n",
1686 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.3027e-04 - f1: 1.0000 - val_loss: 1.1130 - val_f1: 0.8289\n",
1687 | "Epoch 158/300\n",
1688 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.2346e-04 - f1: 1.0000 - val_loss: 1.1253 - val_f1: 0.8290\n",
1689 | "Epoch 159/300\n",
1690 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0015 - f1: 0.9995 - val_loss: 1.0921 - val_f1: 0.8284\n",
1691 | "Epoch 160/300\n",
1692 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0042 - f1: 0.9989 - val_loss: 1.1114 - val_f1: 0.8177\n",
1693 | "Epoch 161/300\n",
1694 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0021 - f1: 0.9995 - val_loss: 1.0611 - val_f1: 0.8253\n",
1695 | "Epoch 162/300\n",
1696 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.5093e-04 - f1: 1.0000 - val_loss: 1.0799 - val_f1: 0.8289\n",
1697 | "Epoch 163/300\n",
1698 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.7466e-04 - f1: 1.0000 - val_loss: 1.0400 - val_f1: 0.8241\n",
1699 | "Epoch 164/300\n",
1700 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.2356e-04 - f1: 1.0000 - val_loss: 1.0486 - val_f1: 0.8242\n",
1701 | "Epoch 165/300\n",
1702 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.2108e-04 - f1: 1.0000 - val_loss: 1.0388 - val_f1: 0.8244\n",
1703 | "Epoch 166/300\n",
1704 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.6586e-04 - f1: 1.0000 - val_loss: 1.0512 - val_f1: 0.8222\n",
1705 | "Epoch 167/300\n",
1706 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.1654e-04 - f1: 1.0000 - val_loss: 1.0569 - val_f1: 0.8237\n",
1707 | "Epoch 168/300\n",
1708 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.1511e-04 - f1: 1.0000 - val_loss: 1.0831 - val_f1: 0.8220\n",
1709 | "Epoch 169/300\n",
1710 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.7949e-04 - f1: 1.0000 - val_loss: 1.0916 - val_f1: 0.8185\n",
1711 | "Epoch 170/300\n",
1712 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.3409e-04 - f1: 1.0000 - val_loss: 1.0991 - val_f1: 0.8198\n",
1713 | "Epoch 171/300\n",
1714 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.0768e-04 - f1: 1.0000 - val_loss: 1.1065 - val_f1: 0.8176\n",
1715 | "Epoch 172/300\n",
1716 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.2211e-04 - f1: 1.0000 - val_loss: 1.1115 - val_f1: 0.8155\n",
1717 | "Epoch 173/300\n",
1718 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.1365e-04 - f1: 1.0000 - val_loss: 1.1072 - val_f1: 0.8128\n",
1719 | "Epoch 174/300\n",
1720 | "1839/1839 [==============================] - 5s 3ms/step - loss: 8.8997e-05 - f1: 1.0000 - val_loss: 1.1027 - val_f1: 0.8172\n",
1721 | "Epoch 175/300\n",
1722 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.0339e-04 - f1: 1.0000 - val_loss: 1.0983 - val_f1: 0.8200\n",
1723 | "Epoch 176/300\n",
1724 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.2617e-04 - f1: 1.0000 - val_loss: 1.1052 - val_f1: 0.8191\n",
1725 | "Epoch 177/300\n",
1726 | "1839/1839 [==============================] - 5s 3ms/step - loss: 6.2080e-05 - f1: 1.0000 - val_loss: 1.1063 - val_f1: 0.8178\n",
1727 | "Epoch 178/300\n",
1728 | "1839/1839 [==============================] - 5s 3ms/step - loss: 8.9384e-05 - f1: 1.0000 - val_loss: 1.1094 - val_f1: 0.8225\n",
1729 | "Epoch 179/300\n",
1730 | "1839/1839 [==============================] - 5s 3ms/step - loss: 7.3541e-05 - f1: 1.0000 - val_loss: 1.1141 - val_f1: 0.8212\n",
1731 | "Epoch 180/300\n",
1732 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.3658e-04 - f1: 1.0000 - val_loss: 1.1153 - val_f1: 0.8277\n",
1733 | "Epoch 181/300\n",
1734 | "1839/1839 [==============================] - 5s 3ms/step - loss: 6.2540e-05 - f1: 1.0000 - val_loss: 1.1166 - val_f1: 0.8268\n",
1735 | "Epoch 182/300\n",
1736 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.7952e-04 - f1: 1.0000 - val_loss: 1.1366 - val_f1: 0.8256\n",
1737 | "Epoch 183/300\n",
1738 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.6589e-05 - f1: 1.0000 - val_loss: 1.1376 - val_f1: 0.8260\n",
1739 | "Epoch 184/300\n",
1740 | "1839/1839 [==============================] - 5s 3ms/step - loss: 7.5782e-05 - f1: 1.0000 - val_loss: 1.1378 - val_f1: 0.8295\n",
1741 | "Epoch 185/300\n",
1742 | "1839/1839 [==============================] - 5s 3ms/step - loss: 8.7658e-05 - f1: 1.0000 - val_loss: 1.1392 - val_f1: 0.8247\n",
1743 | "Epoch 186/300\n",
1744 | "1839/1839 [==============================] - 5s 3ms/step - loss: 6.0105e-05 - f1: 1.0000 - val_loss: 1.1360 - val_f1: 0.8307\n",
1745 | "Epoch 187/300\n",
1746 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.7825e-05 - f1: 1.0000 - val_loss: 1.1352 - val_f1: 0.8317\n",
1747 | "Epoch 188/300\n",
1748 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.0959e-05 - f1: 1.0000 - val_loss: 1.1367 - val_f1: 0.8331\n",
1749 | "Epoch 189/300\n",
1750 | "1839/1839 [==============================] - 5s 3ms/step - loss: 7.8447e-05 - f1: 1.0000 - val_loss: 1.1430 - val_f1: 0.8286\n",
1751 | "Epoch 190/300\n",
1752 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.0137e-05 - f1: 1.0000 - val_loss: 1.1394 - val_f1: 0.8309\n",
1753 | "Epoch 191/300\n",
1754 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.1537e-05 - f1: 1.0000 - val_loss: 1.1423 - val_f1: 0.8339\n",
1755 | "Epoch 192/300\n",
1756 | "1839/1839 [==============================] - 5s 3ms/step - loss: 8.0265e-05 - f1: 1.0000 - val_loss: 1.1264 - val_f1: 0.8333\n",
1757 | "Epoch 193/300\n",
1758 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.7630e-05 - f1: 1.0000 - val_loss: 1.1351 - val_f1: 0.8315\n",
1759 | "Epoch 194/300\n",
1760 | "1839/1839 [==============================] - 5s 3ms/step - loss: 9.6157e-05 - f1: 1.0000 - val_loss: 1.1569 - val_f1: 0.8323\n",
1761 | "Epoch 195/300\n",
1762 | "1839/1839 [==============================] - 6s 3ms/step - loss: 4.0880e-05 - f1: 1.0000 - val_loss: 1.1556 - val_f1: 0.8376\n",
1763 | "Epoch 196/300\n",
1764 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.0660e-04 - f1: 1.0000 - val_loss: 1.1486 - val_f1: 0.8388\n",
1765 | "Epoch 197/300\n",
1766 | "1839/1839 [==============================] - 6s 3ms/step - loss: 4.7847e-05 - f1: 1.0000 - val_loss: 1.1383 - val_f1: 0.8360\n",
1767 | "Epoch 198/300\n",
1768 | "1839/1839 [==============================] - 6s 3ms/step - loss: 4.5485e-05 - f1: 1.0000 - val_loss: 1.1384 - val_f1: 0.8355\n",
1769 | "Epoch 199/300\n",
1770 | "1839/1839 [==============================] - 7s 4ms/step - loss: 6.6492e-05 - f1: 1.0000 - val_loss: 1.1453 - val_f1: 0.8361\n",
1771 | "Epoch 200/300\n",
1772 | "1839/1839 [==============================] - 7s 4ms/step - loss: 3.0428e-05 - f1: 1.0000 - val_loss: 1.1482 - val_f1: 0.8382\n",
1773 | "Epoch 201/300\n",
1774 | "1839/1839 [==============================] - 7s 4ms/step - loss: 3.2437e-05 - f1: 1.0000 - val_loss: 1.1498 - val_f1: 0.8383\n",
1775 | "Epoch 202/300\n",
1776 | "1839/1839 [==============================] - 7s 4ms/step - loss: 3.4816e-05 - f1: 1.0000 - val_loss: 1.1570 - val_f1: 0.8361\n",
1777 | "Epoch 203/300\n",
1778 | "1839/1839 [==============================] - 6s 4ms/step - loss: 3.9865e-05 - f1: 1.0000 - val_loss: 1.1833 - val_f1: 0.8354\n",
1779 | "Epoch 204/300\n",
1780 | "1839/1839 [==============================] - 6s 3ms/step - loss: 6.4216e-05 - f1: 1.0000 - val_loss: 1.1805 - val_f1: 0.8315\n",
1781 | "Epoch 205/300\n",
1782 | "1839/1839 [==============================] - 6s 3ms/step - loss: 4.5160e-05 - f1: 1.0000 - val_loss: 1.1617 - val_f1: 0.8412\n",
1783 | "Epoch 206/300\n",
1784 | "1839/1839 [==============================] - 6s 3ms/step - loss: 4.1668e-05 - f1: 1.0000 - val_loss: 1.1689 - val_f1: 0.8394\n",
1785 | "Epoch 207/300\n",
1786 | "1839/1839 [==============================] - 6s 3ms/step - loss: 2.4791e-05 - f1: 1.0000 - val_loss: 1.1736 - val_f1: 0.8394\n",
1787 | "Epoch 208/300\n",
1788 | "1839/1839 [==============================] - 6s 3ms/step - loss: 2.3336e-05 - f1: 1.0000 - val_loss: 1.1760 - val_f1: 0.8379\n",
1789 | "Epoch 209/300\n",
1790 | "1839/1839 [==============================] - 6s 3ms/step - loss: 2.2662e-05 - f1: 1.0000 - val_loss: 1.1767 - val_f1: 0.8362\n",
1791 | "Epoch 210/300\n",
1792 | "1839/1839 [==============================] - 6s 3ms/step - loss: 2.6002e-05 - f1: 1.0000 - val_loss: 1.1850 - val_f1: 0.8359\n",
1793 | "Epoch 211/300\n",
1794 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.9812e-05 - f1: 1.0000 - val_loss: 1.1848 - val_f1: 0.8351\n",
1795 | "Epoch 212/300\n",
1796 | "1839/1839 [==============================] - 6s 3ms/step - loss: 3.1796e-05 - f1: 1.0000 - val_loss: 1.1862 - val_f1: 0.8332\n",
1797 | "Epoch 213/300\n",
1798 | "1839/1839 [==============================] - 6s 3ms/step - loss: 2.5655e-05 - f1: 1.0000 - val_loss: 1.1919 - val_f1: 0.8299\n",
1799 | "Epoch 214/300\n",
1800 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.0894e-05 - f1: 1.0000 - val_loss: 1.1916 - val_f1: 0.8326\n",
1801 | "Epoch 215/300\n",
1802 | "1839/1839 [==============================] - 6s 3ms/step - loss: 2.0926e-05 - f1: 1.0000 - val_loss: 1.1944 - val_f1: 0.8318\n",
1803 | "Epoch 216/300\n",
1804 | " 60/1839 [..............................] - ETA: 5s - loss: 9.7486e-06 - f1: 1.0000"
1805 | ]
1806 | }
1807 | ],
1808 | "source": [
1809 | "print('Train...')\n",
1810 | "model.fit(x_train, y_train,\n",
1811 | " batch_size=batch_size,\n",
1812 | " epochs=epochs,\n",
1813 | " callbacks=[TensorBoard(log_dir='../logs/{}'.format(\"SMP2018_lstm_{}\".format(get_customization_time())))],\n",
1814 | " validation_split=0.2\n",
1815 | " )"
1816 | ]
1817 | },
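{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "The run above was interrupted partway through epoch 216 of 300, and the log ends with a training f1 of 1.0000 against a validation f1 of roughly 0.83, i.e. the model has essentially memorized the training set. As a minimal sketch (not part of the original run), a Keras `EarlyStopping` callback could stop training automatically once `val_loss` stops improving:"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Hypothetical alternative run: stop when val_loss stalls and keep the best weights.\n",
  "from keras.callbacks import EarlyStopping\n",
  "\n",
  "early_stopping = EarlyStopping(monitor='val_loss', patience=10,\n",
  "                               restore_best_weights=True)\n",
  "# model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,\n",
  "#           callbacks=[early_stopping], validation_split=0.2)"
 ]
},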
1818 | {
1819 | "cell_type": "markdown",
1820 | "metadata": {},
1821 | "source": [
1822 | "# 评估模型"
1823 | ]
1824 | },
1825 | {
1826 | "cell_type": "code",
1827 | "execution_count": 53,
1828 | "metadata": {},
1829 | "outputs": [
1830 | {
1831 | "name": "stdout",
1832 | "output_type": "stream",
1833 | "text": [
1834 | "770/770 [==============================] - 0s 240us/step\n",
1835 | "Test score: 0.7415416103291821\n",
1836 | "Test f1: 0.8223602949798882\n"
1837 | ]
1838 | }
1839 | ],
1840 | "source": [
1841 | "score = model.evaluate(x_test, y_test,\n",
1842 | " batch_size=batch_size, verbose=1)\n",
1843 | "\n",
1844 | "print('Test score:', score[0])\n",
1845 | "print('Test f1:', score[1])"
1846 | ]
1847 | },
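{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "`model.evaluate` returns the loss followed by each metric passed to `model.compile`, in that order, so `score` unpacks as `[loss, f1]` here. A small illustration (not part of the original run) that pairs each value with its metric name:"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Illustration only: label each value returned by evaluate() with its metric name.\n",
  "print(dict(zip(model.metrics_names, score)))"
 ]
},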
1848 | {
1849 | "cell_type": "code",
1850 | "execution_count": 54,
1851 | "metadata": {},
1852 | "outputs": [],
1853 | "source": [
1854 | "y_hat_test = model.predict(x_test)"
1855 | ]
1856 | },
1857 | {
1858 | "cell_type": "code",
1859 | "execution_count": 55,
1860 | "metadata": {},
1861 | "outputs": [
1862 | {
1863 | "name": "stdout",
1864 | "output_type": "stream",
1865 | "text": [
1866 | "(770, 31)\n"
1867 | ]
1868 | }
1869 | ],
1870 | "source": [
1871 | "print(y_hat_test.shape)"
1872 | ]
1873 | },
1874 | {
1875 | "cell_type": "markdown",
1876 | "metadata": {},
1877 | "source": [
1878 | "## 将 one-hot 张量转换成对应的整数"
1879 | ]
1880 | },
1881 | {
1882 | "cell_type": "code",
1883 | "execution_count": 56,
1884 | "metadata": {},
1885 | "outputs": [],
1886 | "source": [
1887 | "y_pred = np.argmax(y_hat_test, axis=1).tolist()"
1888 | ]
1889 | },
1890 | {
1891 | "cell_type": "code",
1892 | "execution_count": 57,
1893 | "metadata": {},
1894 | "outputs": [],
1895 | "source": [
1896 | "y_true = np.argmax(y_test, axis=1).tolist()"
1897 | ]
1898 | },
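{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "A quick toy illustration (made-up probabilities, not from the dataset) of what `np.argmax(..., axis=1)` does here: each row collapses to the index of its largest entry, turning a probability vector into a class index."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Toy example: two fake probability rows -> class indices [1, 0]\n",
  "toy = np.array([[0.10, 0.70, 0.20],\n",
  "                [0.90, 0.05, 0.05]])\n",
  "print(np.argmax(toy, axis=1).tolist())"
 ]
},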
1899 | {
1900 | "cell_type": "markdown",
1901 | "metadata": {},
1902 | "source": [
1903 | "## 查看多分类的 准确率、召回率、F1 值"
1904 | ]
1905 | },
1906 | {
1907 | "cell_type": "code",
1908 | "execution_count": 58,
1909 | "metadata": {},
1910 | "outputs": [
1911 | {
1912 | "name": "stdout",
1913 | "output_type": "stream",
1914 | "text": [
1915 | " precision recall f1-score support\n",
1916 | "\n",
1917 | " 0 1.00 0.90 0.95 21\n",
1918 | " 1 0.86 0.75 0.80 8\n",
1919 | " 2 1.00 0.95 0.98 21\n",
1920 | " 3 0.52 0.57 0.54 23\n",
1921 | " 4 0.91 0.91 0.91 11\n",
1922 | " 5 0.82 0.97 0.89 34\n",
1923 | " 6 0.25 0.17 0.20 6\n",
1924 | " 7 0.86 0.86 0.86 22\n",
1925 | " 8 1.00 0.88 0.93 8\n",
1926 | " 9 0.89 1.00 0.94 8\n",
1927 | " 10 0.95 0.95 0.95 21\n",
1928 | " 11 1.00 0.62 0.77 8\n",
1929 | " 12 0.62 0.70 0.66 60\n",
1930 | " 13 0.86 0.90 0.88 20\n",
1931 | " 14 0.55 0.58 0.56 19\n",
1932 | " 15 0.76 0.78 0.77 36\n",
1933 | " 16 0.87 0.90 0.89 154\n",
1934 | " 17 0.57 0.50 0.53 8\n",
1935 | " 18 0.86 0.75 0.80 8\n",
1936 | " 19 0.74 0.95 0.83 21\n",
1937 | " 20 0.87 0.83 0.85 24\n",
1938 | " 21 1.00 0.75 0.86 8\n",
1939 | " 22 0.67 0.67 0.67 9\n",
1940 | " 23 1.00 1.00 1.00 8\n",
1941 | " 24 0.62 0.56 0.59 18\n",
1942 | " 25 0.92 1.00 0.96 24\n",
1943 | " 26 0.75 0.30 0.43 10\n",
1944 | " 27 0.71 0.55 0.62 22\n",
1945 | " 28 0.71 0.65 0.68 23\n",
1946 | " 29 0.73 0.61 0.67 18\n",
1947 | " 30 0.94 0.94 0.94 89\n",
1948 | "\n",
1949 | " micro avg 0.82 0.82 0.82 770\n",
1950 | " macro avg 0.80 0.76 0.77 770\n",
1951 | "weighted avg 0.82 0.82 0.81 770\n",
1952 | "\n"
1953 | ]
1954 | }
1955 | ],
1956 | "source": [
1957 | "print(classification_report(y_true, y_pred))"
1958 | ]
1959 |   },
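{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "As a possible follow-up (not part of the original run), a confusion matrix shows *which* classes get mistaken for each other, e.g. the weak classes 6 and 26 above. A minimal sketch using `sklearn.metrics.confusion_matrix`:"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Hypothetical follow-up: rows are true classes, columns are predicted classes.\n",
  "from sklearn.metrics import confusion_matrix\n",
  "\n",
  "cm = confusion_matrix(y_true, y_pred)\n",
  "print(cm.shape)  # expected: (31, 31)"
 ]
}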
1960 | ],
1961 | "metadata": {
1962 | "kernelspec": {
1963 | "display_name": "Python 3",
1964 | "language": "python",
1965 | "name": "python3"
1966 | },
1967 | "language_info": {
1968 | "codemirror_mode": {
1969 | "name": "ipython",
1970 | "version": 3
1971 | },
1972 | "file_extension": ".py",
1973 | "mimetype": "text/x-python",
1974 | "name": "python",
1975 | "nbconvert_exporter": "python",
1976 | "pygments_lexer": "ipython3",
1977 | "version": "3.6.5"
1978 | }
1979 | },
1980 | "nbformat": 4,
1981 | "nbformat_minor": 2
1982 | }
1983 |
--------------------------------------------------------------------------------