├── preprocessed_data.npz
├── 用户意图分类APP
│   ├── index_2_label_dict.pkl
│   ├── word_2_index_dict.pkl
│   ├── model_structure_json.pkl
│   ├── SMP2018_GlobalAveragePooling1D_model(F1_86).h5
│   ├── app.py
│   ├── APP说明和使用APP.ipynb
│   └── SMP应用详细代码.ipynb
├── README.md
├── SMP2018_EDA_and_Baseline_Model(Keras).ipynb
└── SMP2018_EDA_and_Baseline_Model.ipynb

/preprocessed_data.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/SMP2018/HEAD/preprocessed_data.npz
--------------------------------------------------------------------------------
/用户意图分类APP/index_2_label_dict.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/SMP2018/HEAD/用户意图分类APP/index_2_label_dict.pkl
--------------------------------------------------------------------------------
/用户意图分类APP/word_2_index_dict.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/SMP2018/HEAD/用户意图分类APP/word_2_index_dict.pkl
--------------------------------------------------------------------------------
/用户意图分类APP/model_structure_json.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/SMP2018/HEAD/用户意图分类APP/model_structure_json.pkl
--------------------------------------------------------------------------------
/用户意图分类APP/SMP2018_GlobalAveragePooling1D_model(F1_86).h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuanxiaosc/SMP2018/HEAD/用户意图分类APP/SMP2018_GlobalAveragePooling1D_model(F1_86).h5
--------------------------------------------------------------------------------
/用户意图分类APP/app.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from keras.models import model_from_json
4 | from keras.preprocessing.sequence import pad_sequences
5 | import jieba
6 | import pickle
7 | 
8 | # Load a pickled object from disk
9 | def load_obj(name):
10 |     with open(name + '.pkl', 'rb') as f:
11 |         return pickle.load(f)
12 | 
13 | # Final padded length of a single query fed to the model
14 | max_cut_query_lenth = 26
15 | 
16 | # Dictionary mapping query words to their IDs
17 | word_2_index_dict = load_obj('word_2_index_dict')
18 | # Dictionary mapping model output IDs to their labels (classes)
19 | index_2_label_dict = load_obj('index_2_label_dict')
20 | # Load the model architecture
21 | model_structure_json = load_obj('model_structure_json')
22 | model = model_from_json(model_structure_json)
23 | # Load the model weights
24 | model.load_weights('SMP2018_GlobalAveragePooling1D_model(F1_86).h5')
25 | 
26 | def query_2_label(query_sentence):
27 |     '''
28 |     input query: "从中山到西安的汽车。"
29 |     return label: "bus"
30 |     '''
31 |     x_input = []
32 |     # Segment the query with jieba, e.g. ['从', '中山', '到', '西安', '的', '汽车', '。']
33 |     query_sentence_list = list(jieba.cut(query_sentence))
34 |     # Map each jieba token to its vocabulary index (0 for out-of-vocabulary tokens)
35 |     x = [word_2_index_dict.get(w, 0) for w in query_sentence_list]
36 |     # Left-pad with zeros to the fixed input length, giving a
37 |     # (1, max_cut_query_lenth) int32 array
38 |     x_input.append(x)
39 |     x_input = pad_sequences(x_input, maxlen=max_cut_query_lenth)
40 |     # Predict class probabilities
41 |     y_hat = model.predict(x_input)
42 |     # Index of the highest-probability class
43 |     pred_y_index = np.argmax(y_hat)
44 |     # Look up the label (class) for that index
45 |     label = index_2_label_dict[pred_y_index]
46 |     return label
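47 | 
48 | # Illustrative extension, not part of the original script: the same pipeline
49 | # can report the top-k candidate domains with their probabilities, which is
50 | # useful for inspecting borderline queries. The name query_2_top_k is hypothetical.
51 | def query_2_top_k(query_sentence, k=3):
52 |     tokens = list(jieba.cut(query_sentence))
53 |     x = pad_sequences([[word_2_index_dict.get(w, 0) for w in tokens]],
54 |                       maxlen=max_cut_query_lenth)
55 |     y_hat = model.predict(x)[0]
56 |     top_indexes = np.argsort(y_hat)[::-1][:k]
57 |     return [(index_2_label_dict[int(i)], float(y_hat[i])) for i in top_indexes]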
58 | 
59 | if __name__ == "__main__":
60 |     query_sentence = '狐臭怎么治?'
61 |     print(query_2_label(query_sentence))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SMP2018
2 | > SMP2018 used to demonstrate the general approach to Chinese text classification, in particular [the combined use of Keras and the Chinese word-segmentation tool jieba](SMP2018_EDA_and_Baseline_Model(Keras).ipynb)
3 | 
4 | The SMP2018 Chinese Human-Computer Dialogue Technology Evaluation is organized by the Social Media Processing Technical Committee of the Chinese Information Processing Society of China and co-hosted by Harbin Institute of Technology and iFLYTEK Co., Ltd., with data provided by iFLYTEK and prize money by Huawei. It aims to promote research on Chinese human-computer dialogue systems and to offer a platform for exchange between academic researchers and industry practitioners. All teams are warmly invited to take part in the evaluation!
5 | 
6 | # User Intent Domain Classification
7 | 
8 | When a human-computer dialogue system is in use, users may express many kinds of intent, each of which triggers a different domain in the system: task-oriented vertical domains (such as querying flights, hotels, or buses), knowledge-based question answering, chitchat, and so on. A key task of a dialogue system is therefore to classify the user's input into the correct domain, since only then can the right kind of response be returned.
9 | 
10 | **Examples**
11 | 
12 | 1) 你好啊,很高兴见到你! — chitchat
13 | 
14 | 2) 我想订一张去北京的机票。 — task-oriented vertical domain (flight booking)
15 | 
16 | 3) 我想找一家五道口附近便宜干净的快捷酒店 — task-oriented vertical domain (hotel booking)
17 | 
18 | ## Related Resources
19 | 
20 | |Title|Notes|
21 | |-|-|
22 | |[CodaLab evaluation homepage](https://worksheets.codalab.org/worksheets/0x27203f932f8341b79841d50ce0fd684f/)|[Data download](https://worksheets.codalab.org/worksheets/0x27203f932f8341b79841d50ce0fd684f/#)|
23 | |[CodaLab evaluation tutorial](https://worksheets.codalab.org/worksheets/0x1a7d7d33243c476984ff3d151c4977d4/)|20181010|
24 | |[Evaluation leaderboard](https://smp2018ecdt.github.io/Leader-board/)||
25 | |[SMP2018-ECDT evaluation homepage](http://smp2018.cips-smp.org/ecdt_index.html)||
26 | |[SMP2018-ECDT results announcement](https://mp.weixin.qq.com/s/_VHEuXzR7oXRTo5loqJp8A)||
27 | 
28 | 
29 | # [SMP2018 Chinese Human-Computer Dialogue Technology Evaluation (ECDT)](http://smp2018.cips-smp.org/ecdt_index.html)
30 | 
31 | 1. This repository is a complete experiment for the [SMP2018 Chinese Human-Computer Dialogue Technology Evaluation (ECDT)](http://smp2018.cips-smp.org/ecdt_index.html); the baseline model it trains reaches top-three level on the evaluation leaderboard.
32 | 2. The experiment demonstrates the general methods for processing natural-language text data.
33 | 3. You are encouraged to modify the notebook for better results, for example by changing the following hyperparameters
34 | 
35 | ```python
36 | # Dimension of the word embeddings
37 | embedding_word_dims = 32
38 | # Batch size
39 | batch_size = 30
40 | # Number of epochs
41 | epochs = 20
42 | ```
43 | 
44 | ## Examples of possible improvements
45 | 
46 | 1. Use a different word-segmentation tool during preprocessing
47 | 2. Combine character vectors with word vectors
48 | 3. Use pretrained word vectors (see the sketch after this list)
49 | 4. Change the model architecture
50 | 5. Tune the model hyperparameters
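51 | 
52 | As a minimal sketch of idea 3, the `Embedding` layer can be initialised from pretrained vectors. This assumes a hypothetical file `zh_word_vectors.txt` in word2vec text format (one word per line followed by its values); `tokenizer`, `max_features`, `embedding_word_dims`, and `max_cut_query_lenth` are the variables defined in the notebook — nothing below is part of the original code:
53 | 
54 | ```python
55 | import numpy as np
56 | 
57 | # Parse the pretrained vectors (hypothetical file name).
58 | pretrained_vectors = {}
59 | with open('zh_word_vectors.txt', encoding='utf-8') as f:
60 |     for line in f:
61 |         parts = line.rstrip().split(' ')
62 |         if len(parts) == embedding_word_dims + 1:
63 |             pretrained_vectors[parts[0]] = np.asarray(parts[1:], dtype='float32')
64 | 
65 | # Build the embedding matrix in the order of the Keras tokenizer's word index;
66 | # words without a pretrained vector keep small random values.
67 | embedding_matrix = np.random.uniform(-0.05, 0.05,
68 |                                      (max_features + 1, embedding_word_dims))
69 | for word, index in tokenizer.word_index.items():
70 |     if word in pretrained_vectors:
71 |         embedding_matrix[index] = pretrained_vectors[word]
72 | 
73 | # Initialise the model's Embedding layer with the matrix.
74 | from keras.layers import Embedding
75 | embedding_layer = Embedding(input_dim=max_features + 1,
76 |                             output_dim=embedding_word_dims,
77 |                             weights=[embedding_matrix],
78 |                             input_length=max_cut_query_lenth)
79 | ```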
80 | 
81 | ## Resources
82 | 
83 | + [SMP2018_EDA_and_Baseline_Model.ipynb](SMP2018_EDA_and_Baseline_Model.ipynb) — the complete data-analysis and model-building code
84 | + [app.py](用户意图分类APP/app.py) — the user-intent classification app built on the trained model
85 | + The remaining files are named after their contents
86 | 
--------------------------------------------------------------------------------
/用户意图分类APP/APP说明和使用APP.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# User Intent Domain Classification\n",
8 |     "  When a human-computer dialogue system is in use, users may express many kinds of intent, each of which triggers a different domain in the system: task-oriented vertical domains (such as querying flights, hotels, or buses), knowledge-based question answering, chitchat, and so on. A key task of a dialogue system is therefore to classify the user's input into the correct domain, since only then can the right kind of response be returned."
9 |    ]
10 |   },
11 |   {
12 |    "cell_type": "markdown",
13 |    "metadata": {},
14 |    "source": [
15 |     "## The classification categories"
16 |    ]
17 |   },
18 |   {
19 |    "cell_type": "markdown",
20 |    "metadata": {},
21 |    "source": [
22 |     "+ There are two top-level groups, chitchat and vertical domains, with the vertical group further divided into 30 vertical domains.\n",
23 |     "+ Task 1 of the evaluation only considers domain classification of user intent in single-turn dialogues; classifying the overall intent of multi-turn dialogues is out of scope."
24 |    ]
25 |   },
26 |   {
27 |    "cell_type": "raw",
28 |    "metadata": {},
29 |    "source": [
30 |     "categories = ['website', 'tvchannel', 'lottery', 'chat', 'match',\n",
31 |     "              'datetime', 'weather', 'bus', 'novel', 'video', 'riddle',\n",
32 |     "              'calc', 'telephone', 'health', 'contacts', 'epg', 'app', 'music',\n",
33 |     "              'cookbook', 'stock', 'map', 'message', 'poetry', 'cinemas', 'news',\n",
34 |     "              'flight', 'translation', 'train', 'schedule', 'radio', 'email']"
35 |    ]
36 |   },
37 |   {
38 |    "cell_type": "markdown",
39 |    "metadata": {},
40 |    "source": [
41 |     "# Getting started"
42 |    ]
43 |   },
44 |   {
45 |    "cell_type": "code",
46 |    "execution_count": 1,
47 |    "metadata": {},
48 |    "outputs": [
49 |     {
50 |      "name": "stderr",
51 |      "output_type": "stream",
52 |      "text": [
53 |       "Using TensorFlow backend.\n"
54 |      ]
55 |     }
56 |    ],
57 |    "source": [
58 |     "from app import query_2_label"
59 |    ]
60 |   },
61 |   {
62 |    "cell_type": "code",
63 |    "execution_count": 2,
64 |    "metadata": {},
65 |    "outputs": [
66 |     {
67 |      "name": "stderr",
68 |      "output_type": "stream",
69 |      "text": [
70 |       "Building prefix dict from the default dictionary ...\n",
71 |       "Loading model from cache /tmp/jieba.cache\n",
72 |       "Loading model cost 0.945 seconds.\n",
73 |       "Prefix dict has been built succesfully.\n"
74 |      ]
75 |     },
76 |     {
77 |      "data": {
78 |       "text/plain": [
79 |        "'chat'"
80 |       ]
81 |      },
82 |      "execution_count": 2,
83 |      "metadata": {},
84 |      "output_type": "execute_result"
85 |     }
86 |    ],
87 |    "source": [
88 |     "query_2_label('我喜欢你')"
89 |    ]
90 |   },
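91 |   {
92 |    "cell_type": "markdown",
93 |    "metadata": {},
94 |    "source": [
95 |     "# Batch queries\n",
96 |     "An illustrative cell added to this walkthrough, not part of the original notebook: `query_2_label` can simply be mapped over a list of queries. The example sentences below come from the training data."
97 |    ]
98 |   },
99 |   {
100 |    "cell_type": "code",
101 |    "execution_count": null,
102 |    "metadata": {},
103 |    "outputs": [],
104 |    "source": [
105 |     "queries = ['今天东莞天气如何', '鸭蛋怎么腌?', '南京到厦门的火车票']\n",
106 |     "for q in queries:\n",
107 |     "    print(q, '->', query_2_label(q))"
108 |    ]
109 |   },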
110 |   {
111 |    "cell_type": "markdown",
112 |    "metadata": {},
113 |    "source": [
114 |     "# Run the cell below to issue queries; enter 0 to stop"
115 |    ]
116 |   },
117 |   {
118 |    "cell_type": "code",
119 |    "execution_count": null,
120 |    "metadata": {},
121 |    "outputs": [
122 |     {
123 |      "name": "stdin",
124 |      "output_type": "stream",
125 |      "text": [
126 |       " 今天东莞天气如何\n"
127 |      ]
128 |     },
129 |     {
130 |      "name": "stdout",
131 |      "output_type": "stream",
132 |      "text": [
133 |       "----------\n",
134 |       "predict label:\t datetime\n",
135 |       "----------\n"
136 |      ]
137 |     },
138 |     {
139 |      "name": "stdin",
140 |      "output_type": "stream",
141 |      "text": [
142 |       " 怎么治疗感冒?\n"
143 |      ]
144 |     },
145 |     {
146 |      "name": "stdout",
147 |      "output_type": "stream",
148 |      "text": [
149 |       "----------\n",
150 |       "predict label:\t health\n",
151 |       "----------\n"
152 |      ]
153 |     },
154 |     {
155 |      "name": "stdin",
156 |      "output_type": "stream",
157 |      "text": [
158 |       " 你好?\n"
159 |      ]
160 |     },
161 |     {
162 |      "name": "stdout",
163 |      "output_type": "stream",
164 |      "text": [
165 |       "----------\n",
166 |       "predict label:\t chat\n",
167 |       "----------\n"
168 |      ]
169 |     }
170 |    ],
171 |    "source": [
172 |     "while True:\n",
173 |     "    your_query_sentence = input()\n",
174 |     "    if your_query_sentence == '0':\n",
175 |     "        break\n",
176 |     "    print('-'*10)\n",
177 |     "    label = query_2_label(your_query_sentence)\n",
178 |     "    print('predict label:\\t', label)\n",
179 |     "    print('-'*10)"
180 |    ]
181 |   }
182 |  ],
183 |  "metadata": {
184 |   "kernelspec": {
185 |    "display_name": "Python 3",
186 |    "language": "python",
187 |    "name": "python3"
188 |   },
189 |   "language_info": {
190 |    "codemirror_mode": {
191 |     "name": "ipython",
192 |     "version": 3
193 |    },
194 |    "file_extension": ".py",
195 |    "mimetype": "text/x-python",
196 |    "name": "python",
197 |    "nbconvert_exporter": "python",
198 |    "pygments_lexer": "ipython3",
199 |    "version": "3.6.5"
200 |   }
201 |  },
202 |  "nbformat": 4,
203 |  "nbformat_minor": 2
204 | }
--------------------------------------------------------------------------------
/用户意图分类APP/SMP应用详细代码.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Import the required libraries"
8 |    ]
9 |   },
10 |   {
11 |    "cell_type": "code",
12 |    "execution_count": 37,
13 |    "metadata": {},
14 |    "outputs": [],
15 |    "source": [
16 |     "import numpy as np\n",
17 |     "import pandas as pd\n",
18 |     "from keras.models import model_from_json\n",
19 |     "from keras.preprocessing.sequence import pad_sequences\n",
20 |     "import jieba\n",
21 |     "import pickle"
22 |    ]
23 |   },
24 |   {
25 |    "cell_type": "markdown",
26 |    "metadata": {},
27 |    "source": [
28 |     "# Load the model SMP2018_model(F1_86)"
29 |    ]
30 |   },
31 |   {
32 |    "cell_type": "code",
33 |    "execution_count": 38,
34 |    "metadata": {},
35 |    "outputs": [],
36 |    "source": [
37 |     "# Load a pickled object from disk\n",
38 |     "def load_obj(name):\n",
39 |     "    with open(name + '.pkl', 'rb') as f:\n",
40 |     "        return pickle.load(f)\n",
41 |     "\n",
42 |     "# Final padded length of a single query fed to the model\n",
43 |     "max_cut_query_lenth = 26\n",
44 |     "\n",
45 |     "# Dictionary mapping query words to their IDs\n",
46 |     "word_2_index_dict = load_obj('word_2_index_dict')\n",
47 |     "# Dictionary mapping model output IDs to their labels (classes)\n",
48 |     "index_2_label_dict = load_obj('index_2_label_dict')\n",
49 |     "# Load the model architecture\n",
50 |     "model_structure_json = load_obj('model_structure_json')\n",
51 |     "model = model_from_json(model_structure_json)\n",
52 |     "# Load the model weights\n",
53 |     "model.load_weights('SMP2018_GlobalAveragePooling1D_model(F1_86).h5')"
54 |    ]
55 |   },
56 |   {
57 |    "cell_type": "markdown",
58 |    "metadata": {},
59 |    "source": [
60 |     "# The prediction function"
61 |    ]
62 |   },
63 |   {
64 |    "cell_type": "code",
65 |    "execution_count": 39,
66 |    "metadata": {},
67 |    "outputs": [],
68 |    "source": [
69 |     "def query_2_label(query_sentence):\n",
70 |     "    '''\n",
71 |     "    input query: \"从中山到西安的汽车。\"\n",
72 |     "    return label: \"bus\"\n",
73 |     "    '''\n",
74 |     "    x_input = []\n",
75 |     "    # Segment the query with jieba, e.g. ['从', '中山', '到', '西安', '的', '汽车', '。']\n",
76 |     "    query_sentence_list = list(jieba.cut(query_sentence))\n",
77 |     "    # Map each jieba token to its vocabulary index (0 for out-of-vocabulary tokens)\n",
78 |     "    x = [word_2_index_dict.get(w, 0) for w in query_sentence_list]\n",
79 |     "    # Left-pad with zeros to the fixed input length, giving a\n",
80 |     "    # (1, max_cut_query_lenth) int32 array\n",
81 |     "    x_input.append(x)\n",
82 |     "    x_input = pad_sequences(x_input, maxlen=max_cut_query_lenth)\n",
83 |     "    # Predict class probabilities\n",
84 |     "    y_hat = model.predict(x_input)\n",
85 |     "    # Index of the highest-probability class\n",
86 |     "    pred_y_index = np.argmax(y_hat)\n",
87 |     "    # Look up the label (class) for that index\n",
88 |     "    label = index_2_label_dict[pred_y_index]\n",
89 |     "    return label"
90 |    ]
91 |   },
92 |   {
"markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "# 使用例子" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 49, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "query_sentence = '狐臭怎么治?'\n", 106 | "\n", 107 | "print(query_2_label(query_sentence))" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "# 对 2299 条数据进行预测演示" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "## 获取数据" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 51, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "def get_json_data(path):\n", 131 | " # read data\n", 132 | " data_df = pd.read_json(path)\n", 133 | " # change row and colunm\n", 134 | " data_df = data_df.transpose()\n", 135 | " # change colunm order\n", 136 | " data_df = data_df[['query', 'label']]\n", 137 | " return data_df" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 52, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "data_df = get_json_data(path=\"../data/train.json\")" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 53, 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "data": { 156 | "text/html": [ 157 | "
\n", 158 | "\n", 171 | "\n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | "
querylabel
count22992299
unique229931
top还是想知道你能做些什么chat
freq1455
\n", 202 | "
" 203 | ], 204 | "text/plain": [ 205 | " query label\n", 206 | "count 2299 2299\n", 207 | "unique 2299 31\n", 208 | "top 还是想知道你能做些什么 chat\n", 209 | "freq 1 455" 210 | ] 211 | }, 212 | "execution_count": 53, 213 | "metadata": {}, 214 | "output_type": "execute_result" 215 | } 216 | ], 217 | "source": [ 218 | "data_df.describe()" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | "## 查看前 10 条数据 " 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 54, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "data": { 235 | "text/html": [ 236 | "
\n", 237 | "\n", 250 | "\n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | "
querylabel
0今天东莞天气如何weather
1从观音桥到重庆市图书馆怎么走map
2鸭蛋怎么腌?cookbook
3怎么治疗牛皮癣health
4唠什么chat
5阳澄湖大闸蟹的做法。cookbook
6昆山大润发在哪里map
7红烧肉怎么做?嗯?cookbook
8南京到厦门的火车票train
96的平方calc
\n", 311 | "
" 312 | ], 313 | "text/plain": [ 314 | " query label\n", 315 | "0 今天东莞天气如何 weather\n", 316 | "1 从观音桥到重庆市图书馆怎么走 map\n", 317 | "2 鸭蛋怎么腌? cookbook\n", 318 | "3 怎么治疗牛皮癣 health\n", 319 | "4 唠什么 chat\n", 320 | "5 阳澄湖大闸蟹的做法。 cookbook\n", 321 | "6 昆山大润发在哪里 map\n", 322 | "7 红烧肉怎么做?嗯? cookbook\n", 323 | "8 南京到厦门的火车票 train\n", 324 | "9 6的平方 calc" 325 | ] 326 | }, 327 | "execution_count": 54, 328 | "metadata": {}, 329 | "output_type": "execute_result" 330 | } 331 | ], 332 | "source": [ 333 | "data_df.head(10)" 334 | ] 335 | }, 336 | { 337 | "cell_type": "markdown", 338 | "metadata": {}, 339 | "source": [ 340 | "## 模型预测,并查看前 10 条数据" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 55, 346 | "metadata": {}, 347 | "outputs": [ 348 | { 349 | "data": { 350 | "text/html": [ 351 | "
\n", 352 | "\n", 365 | "\n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | "
querylabelmodel_prediction_label
0今天东莞天气如何weatherdatetime
1从观音桥到重庆市图书馆怎么走mapmap
2鸭蛋怎么腌?cookbookcookbook
3怎么治疗牛皮癣healthchat
4唠什么chatchat
5阳澄湖大闸蟹的做法。cookbookcookbook
6昆山大润发在哪里mapchat
7红烧肉怎么做?嗯?cookbookcookbook
8南京到厦门的火车票trainbus
96的平方calccalc
\n", 437 | "
" 438 | ], 439 | "text/plain": [ 440 | " query label model_prediction_label\n", 441 | "0 今天东莞天气如何 weather datetime\n", 442 | "1 从观音桥到重庆市图书馆怎么走 map map\n", 443 | "2 鸭蛋怎么腌? cookbook cookbook\n", 444 | "3 怎么治疗牛皮癣 health chat\n", 445 | "4 唠什么 chat chat\n", 446 | "5 阳澄湖大闸蟹的做法。 cookbook cookbook\n", 447 | "6 昆山大润发在哪里 map chat\n", 448 | "7 红烧肉怎么做?嗯? cookbook cookbook\n", 449 | "8 南京到厦门的火车票 train bus\n", 450 | "9 6的平方 calc calc" 451 | ] 452 | }, 453 | "execution_count": 55, 454 | "metadata": {}, 455 | "output_type": "execute_result" 456 | } 457 | ], 458 | "source": [ 459 | "data_df['model_prediction_label'] = data_df['query'].apply(query_2_label)\n", 460 | "\n", 461 | "data_df.head(10)" 462 | ] 463 | } 464 | ], 465 | "metadata": { 466 | "kernelspec": { 467 | "display_name": "Python 3", 468 | "language": "python", 469 | "name": "python3" 470 | }, 471 | "language_info": { 472 | "codemirror_mode": { 473 | "name": "ipython", 474 | "version": 3 475 | }, 476 | "file_extension": ".py", 477 | "mimetype": "text/x-python", 478 | "name": "python", 479 | "nbconvert_exporter": "python", 480 | "pygments_lexer": "ipython3", 481 | "version": "3.6.5" 482 | } 483 | }, 484 | "nbformat": 4, 485 | "nbformat_minor": 2 486 | } 487 | -------------------------------------------------------------------------------- /SMP2018_EDA_and_Baseline_Model(Keras).ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# $$SMP2018中文人机对话技术评测(ECDT)$$" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "1. 下面是一个完整的针对 [SMP2018中文人机对话技术评测(ECDT)](http://smp2018.cips-smp.org/ecdt_index.html) 的实验,由该实验训练的基线模型能达到评测排行榜的前三的水平。\n", 15 | "2. 通过本实验,可以掌握处理自然语言文本数据的一般方法。\n", 16 | "3. 推荐自己修改此文件,达到更好的实验效果,比如改变以下几个超参数 " 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "```python\n", 24 | "# 词嵌入的维度\n", 25 | "embedding_word_dims = 32\n", 26 | "# 批次大小\n", 27 | "batch_size = 30\n", 28 | "# 周期\n", 29 | "epochs = 20\n", 30 | "```" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "# 本实验还可以改进的地方举例 " 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "1. 预处理阶段使用其它的分词工具\n", 45 | "2. 采用字符向量和词向量结合的方式\n", 46 | "3. 使用预先训练好的词向量\n", 47 | "4. 改变模型结构\n", 48 | "5. 
改变模型超参数" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "# 导入依赖库" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 1, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stderr", 65 | "output_type": "stream", 66 | "text": [ 67 | "Using TensorFlow backend.\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "import numpy as np\n", 73 | "import pandas as pd\n", 74 | "import collections\n", 75 | "import jieba\n", 76 | "from keras.preprocessing.text import Tokenizer\n", 77 | "from keras.preprocessing.sequence import pad_sequences\n", 78 | "from keras.models import Sequential\n", 79 | "from keras.layers import Embedding, LSTM, Dense\n", 80 | "from keras.utils import to_categorical,plot_model\n", 81 | "from keras.callbacks import TensorBoard, Callback\n", 82 | "\n", 83 | "from sklearn.metrics import classification_report\n", 84 | "\n", 85 | "import requests \n", 86 | "\n", 87 | "import time\n", 88 | "\n", 89 | "import os" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "# 辅助函数" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 2, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "from keras import backend as K\n", 106 | "\n", 107 | "# 计算 F1 值的函数\n", 108 | "def f1(y_true, y_pred):\n", 109 | " def recall(y_true, y_pred):\n", 110 | " \"\"\"Recall metric.\n", 111 | "\n", 112 | " Only computes a batch-wise average of recall.\n", 113 | "\n", 114 | " Computes the recall, a metric for multi-label classification of\n", 115 | " how many relevant items are selected.\n", 116 | " \"\"\"\n", 117 | " true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))\n", 118 | " possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))\n", 119 | " recall = true_positives / (possible_positives + K.epsilon())\n", 120 | " return recall\n", 121 | "\n", 122 | " def precision(y_true, y_pred):\n", 123 | " \"\"\"Precision metric.\n", 124 | "\n", 125 | " Only computes a batch-wise average of precision.\n", 126 | "\n", 127 | " Computes the precision, a metric for multi-label classification of\n", 128 | " how many selected items are relevant.\n", 129 | " \"\"\"\n", 130 | " true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))\n", 131 | " predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))\n", 132 | " precision = true_positives / (predicted_positives + K.epsilon())\n", 133 | " return precision\n", 134 | " precision = precision(y_true, y_pred)\n", 135 | " recall = recall(y_true, y_pred)\n", 136 | " return 2*((precision*recall)/(precision+recall+K.epsilon()))" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 3, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "# 获取自定义时间格式的字符串\n", 146 | "def get_customization_time():\n", 147 | " # return '2018_10_10_18_11_45' 年月日时分秒\n", 148 | " time_tuple = time.localtime(time.time())\n", 149 | " customization_time = \"{}_{}_{}_{}_{}_{}\".format(time_tuple[0], time_tuple[1], time_tuple[2], time_tuple[3], time_tuple[4], time_tuple[5])\n", 150 | " return customization_time" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "# 准备数据" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "## [下载SMP2018官方数据](https://worksheets.codalab.org/worksheets/0x27203f932f8341b79841d50ce0fd684f/)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 4, 170 | "metadata": 
170 |    "metadata": {},
171 |    "outputs": [],
172 |    "source": [
173 |     "raw_train_data_url = \"https://worksheets.codalab.org/rest/bundles/0x0161fd2fb40d4dd48541c2643d04b0b8/contents/blob/\"\n",
174 |     "raw_test_data_url = \"https://worksheets.codalab.org/rest/bundles/0x1f96bc12222641209ad057e762910252/contents/blob/\"\n",
175 |     "\n",
176 |     "# Download the SMP2018 data if it is not already present\n",
177 |     "if (not os.path.exists('./data/train.json')) or (not os.path.exists('./data/dev.json')):\n",
178 |     "    raw_train = requests.get(raw_train_data_url)\n",
179 |     "    raw_test = requests.get(raw_test_data_url)\n",
180 |     "    if not os.path.exists('./data'):\n",
181 |     "        os.makedirs('./data')\n",
182 |     "    with open(\"./data/train.json\", \"wb\") as code:\n",
183 |     "        code.write(raw_train.content)\n",
184 |     "    with open(\"./data/dev.json\", \"wb\") as code:\n",
185 |     "        code.write(raw_test.content)"
186 |    ]
187 |   },
188 |   {
189 |    "cell_type": "code",
190 |    "execution_count": 5,
191 |    "metadata": {},
192 |    "outputs": [],
193 |    "source": [
194 |     "def get_json_data(path):\n",
195 |     "    # read data\n",
196 |     "    data_df = pd.read_json(path)\n",
197 |     "    # swap rows and columns\n",
198 |     "    data_df = data_df.transpose()\n",
199 |     "    # reorder the columns\n",
200 |     "    data_df = data_df[['query', 'label']]\n",
201 |     "    return data_df"
202 |    ]
203 |   },
204 |   {
205 |    "cell_type": "code",
206 |    "execution_count": 6,
207 |    "metadata": {},
208 |    "outputs": [],
209 |    "source": [
210 |     "train_data_df = get_json_data(path=\"data/train.json\")\n",
211 |     "\n",
212 |     "test_data_df = get_json_data(path=\"data/dev.json\")"
213 |    ]
214 |   },
215 |   {
216 |    "cell_type": "code",
217 |    "execution_count": 7,
218 |    "metadata": {},
219 |    "outputs": [
220 |     {
221 |      "data": {
222 |       "text/html": [
223 |        "
\n", 224 | "\n", 237 | "\n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | "
querylabel
0今天东莞天气如何weather
1从观音桥到重庆市图书馆怎么走map
2鸭蛋怎么腌?cookbook
3怎么治疗牛皮癣health
4唠什么chat
\n", 273 | "
" 274 | ], 275 | "text/plain": [ 276 | " query label\n", 277 | "0 今天东莞天气如何 weather\n", 278 | "1 从观音桥到重庆市图书馆怎么走 map\n", 279 | "2 鸭蛋怎么腌? cookbook\n", 280 | "3 怎么治疗牛皮癣 health\n", 281 | "4 唠什么 chat" 282 | ] 283 | }, 284 | "execution_count": 7, 285 | "metadata": {}, 286 | "output_type": "execute_result" 287 | } 288 | ], 289 | "source": [ 290 | "train_data_df.head()" 291 | ] 292 | }, 293 | { 294 | "cell_type": "markdown", 295 | "metadata": {}, 296 | "source": [ 297 | "---" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": {}, 303 | "source": [ 304 | "## [结巴分词](https://github.com/fxsjy/jieba)示例,下面将使用结巴分词对原数据进行处理" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 8, 310 | "metadata": {}, 311 | "outputs": [ 312 | { 313 | "name": "stderr", 314 | "output_type": "stream", 315 | "text": [ 316 | "Building prefix dict from the default dictionary ...\n", 317 | "Loading model from cache /tmp/jieba.cache\n", 318 | "Loading model cost 1.022 seconds.\n", 319 | "Prefix dict has been built succesfully.\n" 320 | ] 321 | }, 322 | { 323 | "name": "stdout", 324 | "output_type": "stream", 325 | "text": [ 326 | "['他', '来到', '了', '网易', '杭研', '大厦']\n" 327 | ] 328 | } 329 | ], 330 | "source": [ 331 | "seg_list = jieba.cut(\"他来到了网易杭研大厦\") # 默认是精确模式\n", 332 | "print(list(seg_list))" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "metadata": {}, 338 | "source": [ 339 | "---" 340 | ] 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "metadata": {}, 345 | "source": [ 346 | "# 序列化" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 9, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "def use_jieba_cut(a_sentence):\n", 356 | " return list(jieba.cut(a_sentence))\n", 357 | "\n", 358 | "train_data_df['cut_query'] = train_data_df['query'].apply(use_jieba_cut)\n", 359 | "test_data_df['cut_query'] = test_data_df['query'].apply(use_jieba_cut)" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 10, 365 | "metadata": {}, 366 | "outputs": [ 367 | { 368 | "data": { 369 | "text/html": [ 370 | "
\n", 371 | "\n", 384 | "\n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | "
querylabelcut_query
0今天东莞天气如何weather[今天, 东莞, 天气, 如何]
1从观音桥到重庆市图书馆怎么走map[从, 观音桥, 到, 重庆市, 图书馆, 怎么, 走]
2鸭蛋怎么腌?cookbook[鸭蛋, 怎么, 腌, ?]
3怎么治疗牛皮癣health[怎么, 治疗, 牛皮癣]
4唠什么chat[唠, 什么]
5阳澄湖大闸蟹的做法。cookbook[阳澄湖, 大闸蟹, 的, 做法, 。]
6昆山大润发在哪里map[昆山, 大润发, 在, 哪里]
7红烧肉怎么做?嗯?cookbook[红烧肉, 怎么, 做, ?, 嗯, ?]
8南京到厦门的火车票train[南京, 到, 厦门, 的, 火车票]
96的平方calc[6, 的, 平方]
\n", 456 | "
" 457 | ], 458 | "text/plain": [ 459 | " query label cut_query\n", 460 | "0 今天东莞天气如何 weather [今天, 东莞, 天气, 如何]\n", 461 | "1 从观音桥到重庆市图书馆怎么走 map [从, 观音桥, 到, 重庆市, 图书馆, 怎么, 走]\n", 462 | "2 鸭蛋怎么腌? cookbook [鸭蛋, 怎么, 腌, ?]\n", 463 | "3 怎么治疗牛皮癣 health [怎么, 治疗, 牛皮癣]\n", 464 | "4 唠什么 chat [唠, 什么]\n", 465 | "5 阳澄湖大闸蟹的做法。 cookbook [阳澄湖, 大闸蟹, 的, 做法, 。]\n", 466 | "6 昆山大润发在哪里 map [昆山, 大润发, 在, 哪里]\n", 467 | "7 红烧肉怎么做?嗯? cookbook [红烧肉, 怎么, 做, ?, 嗯, ?]\n", 468 | "8 南京到厦门的火车票 train [南京, 到, 厦门, 的, 火车票]\n", 469 | "9 6的平方 calc [6, 的, 平方]" 470 | ] 471 | }, 472 | "execution_count": 10, 473 | "metadata": {}, 474 | "output_type": "execute_result" 475 | } 476 | ], 477 | "source": [ 478 | "train_data_df.head(10)" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": [ 485 | "## 处理特征" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 11, 491 | "metadata": {}, 492 | "outputs": [], 493 | "source": [ 494 | "tokenizer = Tokenizer()" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": 12, 500 | "metadata": {}, 501 | "outputs": [], 502 | "source": [ 503 | "tokenizer.fit_on_texts(train_data_df['cut_query'])" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": 13, 509 | "metadata": {}, 510 | "outputs": [ 511 | { 512 | "data": { 513 | "text/plain": [ 514 | "2883" 515 | ] 516 | }, 517 | "execution_count": 13, 518 | "metadata": {}, 519 | "output_type": "execute_result" 520 | } 521 | ], 522 | "source": [ 523 | "max_features = len(tokenizer.index_word)\n", 524 | "\n", 525 | "len(tokenizer.index_word)" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": 14, 531 | "metadata": {}, 532 | "outputs": [], 533 | "source": [ 534 | "x_train = tokenizer.texts_to_sequences(train_data_df['cut_query'])\n", 535 | "\n", 536 | "x_test = tokenizer.texts_to_sequences(test_data_df['cut_query'])" 537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": 15, 542 | "metadata": {}, 543 | "outputs": [], 544 | "source": [ 545 | "max_cut_query_lenth = 26" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": 16, 551 | "metadata": {}, 552 | "outputs": [], 553 | "source": [ 554 | "x_train = pad_sequences(x_train, max_cut_query_lenth)\n", 555 | "\n", 556 | "x_test = pad_sequences(x_test, max_cut_query_lenth)" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": 17, 562 | "metadata": {}, 563 | "outputs": [ 564 | { 565 | "data": { 566 | "text/plain": [ 567 | "(2299, 26)" 568 | ] 569 | }, 570 | "execution_count": 17, 571 | "metadata": {}, 572 | "output_type": "execute_result" 573 | } 574 | ], 575 | "source": [ 576 | "x_train.shape" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": 18, 582 | "metadata": {}, 583 | "outputs": [ 584 | { 585 | "data": { 586 | "text/plain": [ 587 | "(770, 26)" 588 | ] 589 | }, 590 | "execution_count": 18, 591 | "metadata": {}, 592 | "output_type": "execute_result" 593 | } 594 | ], 595 | "source": [ 596 | "x_test.shape" 597 | ] 598 | }, 599 | { 600 | "cell_type": "markdown", 601 | "metadata": {}, 602 | "source": [ 603 | "## 处理标签" 604 | ] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "execution_count": 19, 609 | "metadata": {}, 610 | "outputs": [], 611 | "source": [ 612 | "label_tokenizer = Tokenizer()" 613 | ] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "execution_count": 20, 618 | "metadata": {}, 619 | "outputs": [], 620 | "source": [ 621 | "label_tokenizer.fit_on_texts(train_data_df['label'])" 
622 | ] 623 | }, 624 | { 625 | "cell_type": "code", 626 | "execution_count": 21, 627 | "metadata": {}, 628 | "outputs": [], 629 | "source": [ 630 | "label_numbers = len(label_tokenizer.word_counts)" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": 22, 636 | "metadata": {}, 637 | "outputs": [], 638 | "source": [ 639 | "NUM_CLASSES = len(label_tokenizer.word_counts)" 640 | ] 641 | }, 642 | { 643 | "cell_type": "code", 644 | "execution_count": 23, 645 | "metadata": {}, 646 | "outputs": [ 647 | { 648 | "data": { 649 | "text/plain": [ 650 | "OrderedDict([('weather', 66),\n", 651 | " ('map', 68),\n", 652 | " ('cookbook', 269),\n", 653 | " ('health', 55),\n", 654 | " ('chat', 455),\n", 655 | " ('train', 70),\n", 656 | " ('calc', 24),\n", 657 | " ('translation', 61),\n", 658 | " ('music', 66),\n", 659 | " ('tvchannel', 71),\n", 660 | " ('poetry', 102),\n", 661 | " ('telephone', 63),\n", 662 | " ('stock', 71),\n", 663 | " ('radio', 24),\n", 664 | " ('contacts', 30),\n", 665 | " ('lottery', 24),\n", 666 | " ('website', 54),\n", 667 | " ('video', 182),\n", 668 | " ('news', 58),\n", 669 | " ('bus', 24),\n", 670 | " ('app', 53),\n", 671 | " ('flight', 62),\n", 672 | " ('epg', 107),\n", 673 | " ('message', 63),\n", 674 | " ('match', 24),\n", 675 | " ('schedule', 29),\n", 676 | " ('novel', 24),\n", 677 | " ('riddle', 34),\n", 678 | " ('email', 24),\n", 679 | " ('datetime', 18),\n", 680 | " ('cinemas', 24)])" 681 | ] 682 | }, 683 | "execution_count": 23, 684 | "metadata": {}, 685 | "output_type": "execute_result" 686 | } 687 | ], 688 | "source": [ 689 | "label_tokenizer.word_counts" 690 | ] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "execution_count": 24, 695 | "metadata": {}, 696 | "outputs": [], 697 | "source": [ 698 | "y_train = label_tokenizer.texts_to_sequences(train_data_df['label'])" 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": 25, 704 | "metadata": {}, 705 | "outputs": [ 706 | { 707 | "data": { 708 | "text/plain": [ 709 | "[[10], [9], [2], [17], [1], [2], [9], [2], [8], [23]]" 710 | ] 711 | }, 712 | "execution_count": 25, 713 | "metadata": {}, 714 | "output_type": "execute_result" 715 | } 716 | ], 717 | "source": [ 718 | "y_train[:10]" 719 | ] 720 | }, 721 | { 722 | "cell_type": "code", 723 | "execution_count": 26, 724 | "metadata": {}, 725 | "outputs": [], 726 | "source": [ 727 | "y_train = [[y[0]-1] for y in y_train]" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": 27, 733 | "metadata": {}, 734 | "outputs": [ 735 | { 736 | "data": { 737 | "text/plain": [ 738 | "[[9], [8], [1], [16], [0], [1], [8], [1], [7], [22]]" 739 | ] 740 | }, 741 | "execution_count": 27, 742 | "metadata": {}, 743 | "output_type": "execute_result" 744 | } 745 | ], 746 | "source": [ 747 | "y_train[:10]" 748 | ] 749 | }, 750 | { 751 | "cell_type": "code", 752 | "execution_count": 28, 753 | "metadata": {}, 754 | "outputs": [ 755 | { 756 | "data": { 757 | "text/plain": [ 758 | "(2299, 31)" 759 | ] 760 | }, 761 | "execution_count": 28, 762 | "metadata": {}, 763 | "output_type": "execute_result" 764 | } 765 | ], 766 | "source": [ 767 | "y_train = to_categorical(y_train, label_numbers)\n", 768 | "y_train.shape" 769 | ] 770 | }, 771 | { 772 | "cell_type": "code", 773 | "execution_count": 29, 774 | "metadata": {}, 775 | "outputs": [ 776 | { 777 | "data": { 778 | "text/plain": [ 779 | "(770, 31)" 780 | ] 781 | }, 782 | "execution_count": 29, 783 | "metadata": {}, 784 | "output_type": "execute_result" 785 | } 786 | ], 787 | "source": [ 788 | 
"y_test = label_tokenizer.texts_to_sequences(test_data_df['label'])\n", 789 | "y_test = [y[0]-1 for y in y_test]\n", 790 | "y_test = to_categorical(y_test, label_numbers)\n", 791 | "y_test.shape" 792 | ] 793 | }, 794 | { 795 | "cell_type": "code", 796 | "execution_count": 30, 797 | "metadata": {}, 798 | "outputs": [ 799 | { 800 | "data": { 801 | "text/plain": [ 802 | "array([0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 803 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 804 | " dtype=float32)" 805 | ] 806 | }, 807 | "execution_count": 30, 808 | "metadata": {}, 809 | "output_type": "execute_result" 810 | } 811 | ], 812 | "source": [ 813 | "y_test[0]" 814 | ] 815 | }, 816 | { 817 | "cell_type": "markdown", 818 | "metadata": {}, 819 | "source": [ 820 | "# 设计模型" 821 | ] 822 | }, 823 | { 824 | "cell_type": "code", 825 | "execution_count": 45, 826 | "metadata": {}, 827 | "outputs": [], 828 | "source": [ 829 | "def create_SMP2018_lstm_model(max_features, max_cut_query_lenth, label_numbers):\n", 830 | " model = Sequential()\n", 831 | " model.add(Embedding(input_dim=max_features+1, output_dim=32, input_length=max_cut_query_lenth))\n", 832 | " model.add(LSTM(units=64, dropout=0.2, recurrent_dropout=0.2))\n", 833 | " model.add(Dense(label_numbers, activation='softmax'))\n", 834 | " # try using different optimizers and different optimizer configs\n", 835 | " model.compile(loss='categorical_crossentropy',\n", 836 | " optimizer='adam',\n", 837 | " metrics=[f1])\n", 838 | "\n", 839 | " plot_model(model, to_file='SMP2018_lstm_model.png', show_shapes=True)\n", 840 | " \n", 841 | " return model" 842 | ] 843 | }, 844 | { 845 | "cell_type": "markdown", 846 | "metadata": {}, 847 | "source": [ 848 | "# 训练模型" 849 | ] 850 | }, 851 | { 852 | "cell_type": "code", 853 | "execution_count": 46, 854 | "metadata": {}, 855 | "outputs": [], 856 | "source": [ 857 | "if 'max_features' not in dir():\n", 858 | " max_features = 2888\n", 859 | " print('not find max_features variable, use default max_features values:\\t{}'.format(max_features))\n", 860 | "if 'max_cut_query_lenth' not in dir():\n", 861 | " max_cut_query_lenth = 26\n", 862 | " print('not find max_cut_query_lenth, use default max_features values:\\t{}'.format(max_cut_query_lenth))\n", 863 | "if 'label_numbers' not in dir():\n", 864 | " label_numbers = 31\n", 865 | " print('not find label_numbers, use default max_features values:\\t{}'.format(label_numbers))" 866 | ] 867 | }, 868 | { 869 | "cell_type": "code", 870 | "execution_count": 47, 871 | "metadata": {}, 872 | "outputs": [], 873 | "source": [ 874 | "model = create_SMP2018_lstm_model(max_features, max_cut_query_lenth, label_numbers)" 875 | ] 876 | }, 877 | { 878 | "cell_type": "code", 879 | "execution_count": 48, 880 | "metadata": {}, 881 | "outputs": [], 882 | "source": [ 883 | "batch_size = 20\n", 884 | "epochs = 30" 885 | ] 886 | }, 887 | { 888 | "cell_type": "code", 889 | "execution_count": 49, 890 | "metadata": {}, 891 | "outputs": [ 892 | { 893 | "name": "stdout", 894 | "output_type": "stream", 895 | "text": [ 896 | "(2299, 26) (2299, 31)\n" 897 | ] 898 | } 899 | ], 900 | "source": [ 901 | "print(x_train.shape, y_train.shape)" 902 | ] 903 | }, 904 | { 905 | "cell_type": "code", 906 | "execution_count": 50, 907 | "metadata": {}, 908 | "outputs": [ 909 | { 910 | "name": "stdout", 911 | "output_type": "stream", 912 | "text": [ 913 | "(770, 26) (770, 31)\n" 914 | ] 915 | } 916 | ], 917 | "source": [ 918 | "print(x_test.shape, y_test.shape)" 919 | ] 920 | }, 921 | { 
922 | "cell_type": "code", 923 | "execution_count": 51, 924 | "metadata": {}, 925 | "outputs": [ 926 | { 927 | "name": "stdout", 928 | "output_type": "stream", 929 | "text": [ 930 | "Train...\n", 931 | "Epoch 1/30\n", 932 | "2299/2299 [==============================] - 16s 7ms/step - loss: 3.0916 - f1: 0.0000e+00\n", 933 | "Epoch 2/30\n", 934 | "2299/2299 [==============================] - 14s 6ms/step - loss: 2.6594 - f1: 0.1409\n", 935 | "Epoch 3/30\n", 936 | "2299/2299 [==============================] - 13s 6ms/step - loss: 2.0817 - f1: 0.4055\n", 937 | "Epoch 4/30\n", 938 | "2299/2299 [==============================] - 14s 6ms/step - loss: 1.6032 - f1: 0.4689\n", 939 | "Epoch 5/30\n", 940 | "2299/2299 [==============================] - 14s 6ms/step - loss: 1.1318 - f1: 0.6176\n", 941 | "Epoch 6/30\n", 942 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.8090 - f1: 0.7399\n", 943 | "Epoch 7/30\n", 944 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.5704 - f1: 0.8298\n", 945 | "Epoch 8/30\n", 946 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.4051 - f1: 0.8879\n", 947 | "Epoch 9/30\n", 948 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.3002 - f1: 0.9280\n", 949 | "Epoch 10/30\n", 950 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.2317 - f1: 0.9467\n", 951 | "Epoch 11/30\n", 952 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.1755 - f1: 0.9678\n", 953 | "Epoch 12/30\n", 954 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.1391 - f1: 0.9758\n", 955 | "Epoch 13/30\n", 956 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.1131 - f1: 0.9800\n", 957 | "Epoch 14/30\n", 958 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0883 - f1: 0.9861\n", 959 | "Epoch 15/30\n", 960 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0725 - f1: 0.9894\n", 961 | "Epoch 16/30\n", 962 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0615 - f1: 0.9929\n", 963 | "Epoch 17/30\n", 964 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0507 - f1: 0.9945\n", 965 | "Epoch 18/30\n", 966 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0455 - f1: 0.9963\n", 967 | "Epoch 19/30\n", 968 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0398 - f1: 0.9960\n", 969 | "Epoch 20/30\n", 970 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0313 - f1: 0.9978\n", 971 | "Epoch 21/30\n", 972 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0266 - f1: 0.9984\n", 973 | "Epoch 22/30\n", 974 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0279 - f1: 0.9965\n", 975 | "Epoch 23/30\n", 976 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0250 - f1: 0.9976\n", 977 | "Epoch 24/30\n", 978 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0219 - f1: 0.9982\n", 979 | "Epoch 25/30\n", 980 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0195 - f1: 0.9982\n", 981 | "Epoch 26/30\n", 982 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0179 - f1: 0.9989\n", 983 | "Epoch 27/30\n", 984 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0177 - f1: 0.9974\n", 985 | "Epoch 28/30\n", 986 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0139 - f1: 0.9987\n", 987 
| "Epoch 29/30\n", 988 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0139 - f1: 0.9989\n", 989 | "Epoch 30/30\n", 990 | "2299/2299 [==============================] - 14s 6ms/step - loss: 0.0129 - f1: 0.9987\n" 991 | ] 992 | }, 993 | { 994 | "data": { 995 | "text/plain": [ 996 | "" 997 | ] 998 | }, 999 | "execution_count": 51, 1000 | "metadata": {}, 1001 | "output_type": "execute_result" 1002 | } 1003 | ], 1004 | "source": [ 1005 | "print('Train...')\n", 1006 | "model.fit(x_train, y_train,\n", 1007 | " batch_size=batch_size,\n", 1008 | " epochs=epochs)" 1009 | ] 1010 | }, 1011 | { 1012 | "cell_type": "markdown", 1013 | "metadata": {}, 1014 | "source": [ 1015 | "# 评估模型" 1016 | ] 1017 | }, 1018 | { 1019 | "cell_type": "code", 1020 | "execution_count": 52, 1021 | "metadata": {}, 1022 | "outputs": [ 1023 | { 1024 | "name": "stdout", 1025 | "output_type": "stream", 1026 | "text": [ 1027 | "770/770 [==============================] - 1s 1ms/step\n", 1028 | "Test score: 0.6803552009068526\n", 1029 | "Test f1: 0.8464262740952628\n" 1030 | ] 1031 | } 1032 | ], 1033 | "source": [ 1034 | "score = model.evaluate(x_test, y_test,\n", 1035 | " batch_size=batch_size, verbose=1)\n", 1036 | "\n", 1037 | "print('Test score:', score[0])\n", 1038 | "print('Test f1:', score[1])" 1039 | ] 1040 | }, 1041 | { 1042 | "cell_type": "code", 1043 | "execution_count": 53, 1044 | "metadata": {}, 1045 | "outputs": [], 1046 | "source": [ 1047 | "y_hat_test = model.predict(x_test)" 1048 | ] 1049 | }, 1050 | { 1051 | "cell_type": "code", 1052 | "execution_count": 55, 1053 | "metadata": {}, 1054 | "outputs": [ 1055 | { 1056 | "name": "stdout", 1057 | "output_type": "stream", 1058 | "text": [ 1059 | "(770, 31)\n" 1060 | ] 1061 | } 1062 | ], 1063 | "source": [ 1064 | "print(y_hat_test.shape)" 1065 | ] 1066 | }, 1067 | { 1068 | "cell_type": "markdown", 1069 | "metadata": {}, 1070 | "source": [ 1071 | "## 将 one-hot 张量转换成对应的整数" 1072 | ] 1073 | }, 1074 | { 1075 | "cell_type": "code", 1076 | "execution_count": 54, 1077 | "metadata": {}, 1078 | "outputs": [], 1079 | "source": [ 1080 | "y_pred = np.argmax(y_hat_test, axis=1).tolist()" 1081 | ] 1082 | }, 1083 | { 1084 | "cell_type": "code", 1085 | "execution_count": 55, 1086 | "metadata": {}, 1087 | "outputs": [], 1088 | "source": [ 1089 | "y_true = np.argmax(y_test, axis=1).tolist()" 1090 | ] 1091 | }, 1092 | { 1093 | "cell_type": "markdown", 1094 | "metadata": {}, 1095 | "source": [ 1096 | "## 查看多分类的 准确率、召回率、F1 值" 1097 | ] 1098 | }, 1099 | { 1100 | "cell_type": "code", 1101 | "execution_count": 56, 1102 | "metadata": {}, 1103 | "outputs": [ 1104 | { 1105 | "name": "stdout", 1106 | "output_type": "stream", 1107 | "text": [ 1108 | " precision recall f1-score support\n", 1109 | "\n", 1110 | " 0 0.78 0.93 0.85 154\n", 1111 | " 1 0.92 0.97 0.95 89\n", 1112 | " 2 0.67 0.62 0.64 60\n", 1113 | " 3 0.83 0.83 0.83 36\n", 1114 | " 4 0.79 1.00 0.88 34\n", 1115 | " 5 0.83 0.65 0.73 23\n", 1116 | " 6 1.00 0.83 0.91 24\n", 1117 | " 7 1.00 1.00 1.00 24\n", 1118 | " 8 0.68 0.65 0.67 23\n", 1119 | " 9 0.90 0.86 0.88 22\n", 1120 | " 10 0.85 0.50 0.63 22\n", 1121 | " 11 0.88 1.00 0.93 21\n", 1122 | " 12 1.00 0.90 0.95 21\n", 1123 | " 13 0.91 0.95 0.93 21\n", 1124 | " 14 1.00 0.95 0.98 21\n", 1125 | " 15 0.79 0.95 0.86 20\n", 1126 | " 16 0.90 0.47 0.62 19\n", 1127 | " 17 0.79 0.61 0.69 18\n", 1128 | " 18 0.63 0.67 0.65 18\n", 1129 | " 19 0.90 0.82 0.86 11\n", 1130 | " 20 1.00 0.70 0.82 10\n", 1131 | " 21 1.00 0.67 0.80 9\n", 1132 | " 22 1.00 0.88 0.93 8\n", 1133 | " 23 1.00 0.62 
1182 |  ],
1183 |  "metadata": {
1184 |   "kernelspec": {
1185 |    "display_name": "Python 3",
1186 |    "language": "python",
1187 |    "name": "python3"
1188 |   },
1189 |   "language_info": {
1190 |    "codemirror_mode": {
1191 |     "name": "ipython",
1192 |     "version": 3
1193 |    },
1194 |    "file_extension": ".py",
1195 |    "mimetype": "text/x-python",
1196 |    "name": "python",
1197 |    "nbconvert_exporter": "python",
1198 |    "pygments_lexer": "ipython3",
1199 |    "version": "3.6.5"
1200 |   }
1201 |  },
1202 |  "nbformat": 4,
1203 |  "nbformat_minor": 2
1204 | }
1205 | 
--------------------------------------------------------------------------------
/SMP2018_EDA_and_Baseline_Model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# $$SMP2018 Chinese Human-Computer Dialogue Technology Evaluation (ECDT)$$"
8 |    ]
9 |   },
10 |   {
11 |    "cell_type": "markdown",
12 |    "metadata": {},
13 |    "source": [
14 |     "# **Do not edit this file directly; copy it to your own folder before making changes**"
15 |    ]
16 |   },
17 |   {
18 |    "cell_type": "markdown",
19 |    "metadata": {},
20 |    "source": [
21 |     "1. Below is a complete experiment for the [SMP2018 Chinese Human-Computer Dialogue Technology Evaluation (ECDT)](http://smp2018.cips-smp.org/ecdt_index.html); the baseline model trained here reaches top-three level on the evaluation leaderboard.\n",
22 |     "2. The experiment demonstrates the general methods for processing natural-language text data.\n",
23 |     "3. You are encouraged to modify this file for better results, for example by changing the following hyperparameters"
24 |    ]
25 |   },
26 |   {
27 |    "cell_type": "markdown",
28 |    "metadata": {},
29 |    "source": [
30 |     "```python\n",
31 |     "# Dimension of the word embeddings\n",
32 |     "embedding_word_dims = 32\n",
33 |     "# Batch size\n",
34 |     "batch_size = 30\n",
35 |     "# Number of epochs\n",
36 |     "epochs = 20\n",
37 |     "```"
38 |    ]
39 |   },
40 |   {
41 |    "cell_type": "markdown",
42 |    "metadata": {},
43 |    "source": [
44 |     "# Examples of possible improvements"
45 |    ]
46 |   },
47 |   {
48 |    "cell_type": "markdown",
49 |    "metadata": {},
50 |    "source": [
51 |     "1. Use a different word-segmentation tool during preprocessing\n",
52 |     "2. Combine character vectors with word vectors\n",
53 |     "3. Use pretrained word vectors\n",
54 |     "4. Change the model architecture\n",
55 |     "5. Tune the model hyperparameters"
改变模型超参数" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "# 导入依赖库" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 1, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "name": "stderr", 72 | "output_type": "stream", 73 | "text": [ 74 | "Using TensorFlow backend.\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "import numpy as np\n", 80 | "import pandas as pd\n", 81 | "import collections\n", 82 | "import jieba\n", 83 | "from keras.preprocessing.sequence import pad_sequences\n", 84 | "from keras.models import Sequential\n", 85 | "from keras.layers import Embedding, LSTM, Dense\n", 86 | "from keras.utils import to_categorical,plot_model\n", 87 | "from keras.callbacks import TensorBoard, Callback\n", 88 | "\n", 89 | "from sklearn.metrics import classification_report\n", 90 | "\n", 91 | "import requests \n", 92 | "\n", 93 | "import time\n", 94 | "\n", 95 | "import os" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "# 辅助函数" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 2, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "from keras import backend as K\n", 112 | "\n", 113 | "# 计算 F1 值的函数\n", 114 | "def f1(y_true, y_pred):\n", 115 | " def recall(y_true, y_pred):\n", 116 | " \"\"\"Recall metric.\n", 117 | "\n", 118 | " Only computes a batch-wise average of recall.\n", 119 | "\n", 120 | " Computes the recall, a metric for multi-label classification of\n", 121 | " how many relevant items are selected.\n", 122 | " \"\"\"\n", 123 | " true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))\n", 124 | " possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))\n", 125 | " recall = true_positives / (possible_positives + K.epsilon())\n", 126 | " return recall\n", 127 | "\n", 128 | " def precision(y_true, y_pred):\n", 129 | " \"\"\"Precision metric.\n", 130 | "\n", 131 | " Only computes a batch-wise average of precision.\n", 132 | "\n", 133 | " Computes the precision, a metric for multi-label classification of\n", 134 | " how many selected items are relevant.\n", 135 | " \"\"\"\n", 136 | " true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))\n", 137 | " predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))\n", 138 | " precision = true_positives / (predicted_positives + K.epsilon())\n", 139 | " return precision\n", 140 | " precision = precision(y_true, y_pred)\n", 141 | " recall = recall(y_true, y_pred)\n", 142 | " return 2*((precision*recall)/(precision+recall+K.epsilon()))" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 3, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "# 获取自定义时间格式的字符串\n", 152 | "def get_customization_time():\n", 153 | " # return '2018_10_10_18_11_45' 年月日时分秒\n", 154 | " time_tuple = time.localtime(time.time())\n", 155 | " customization_time = \"{}_{}_{}_{}_{}_{}\".format(time_tuple[0], time_tuple[1], time_tuple[2], time_tuple[3], time_tuple[4], time_tuple[5])\n", 156 | " return customization_time" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "# 准备数据" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": {}, 169 | "source": [ 170 | "## [下载SMP2018官方数据](https://worksheets.codalab.org/worksheets/0x27203f932f8341b79841d50ce0fd684f/)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 3, 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | 
"raw_train_data_url = \"https://worksheets.codalab.org/rest/bundles/0x0161fd2fb40d4dd48541c2643d04b0b8/contents/blob/\"\n", 180 | "raw_test_data_url = \"https://worksheets.codalab.org/rest/bundles/0x1f96bc12222641209ad057e762910252/contents/blob/\"\n", 181 | "\n", 182 | "# 如果不存在 SMP2018 数据,则下载\n", 183 | "if (not os.path.exists('./data/train.json')) or (not os.path.exists('./data/dev.json')):\n", 184 | " raw_train = requests.get(raw_train_data_url) \n", 185 | " raw_test = requests.get(raw_test_data_url) \n", 186 | " if not os.path.exists('./data'):\n", 187 | " os.makedirs('./data')\n", 188 | " with open(\"./data/train.json\", \"wb\") as code:\n", 189 | " code.write(raw_train.content)\n", 190 | " with open(\"./data/dev.json\", \"wb\") as code:\n", 191 | " code.write(raw_test.content)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 4, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "def get_json_data(path):\n", 201 | " # read data\n", 202 | " data_df = pd.read_json(path)\n", 203 | " # change row and colunm\n", 204 | " data_df = data_df.transpose()\n", 205 | " # change colunm order\n", 206 | " data_df = data_df[['query', 'label']]\n", 207 | " return data_df" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 5, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "train_data_df = get_json_data(path=\"data/train.json\")\n", 217 | "\n", 218 | "test_data_df = get_json_data(path=\"data/dev.json\")" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 6, 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "data": { 228 | "text/html": [ 229 | "
\n", 230 | "\n", 243 | "\n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | "
querylabel
0今天东莞天气如何weather
1从观音桥到重庆市图书馆怎么走map
2鸭蛋怎么腌?cookbook
3怎么治疗牛皮癣health
4唠什么chat
\n", 279 | "
" 280 | ], 281 | "text/plain": [ 282 | " query label\n", 283 | "0 今天东莞天气如何 weather\n", 284 | "1 从观音桥到重庆市图书馆怎么走 map\n", 285 | "2 鸭蛋怎么腌? cookbook\n", 286 | "3 怎么治疗牛皮癣 health\n", 287 | "4 唠什么 chat" 288 | ] 289 | }, 290 | "execution_count": 6, 291 | "metadata": {}, 292 | "output_type": "execute_result" 293 | } 294 | ], 295 | "source": [ 296 | "train_data_df.head()" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 7, 302 | "metadata": {}, 303 | "outputs": [ 304 | { 305 | "data": { 306 | "text/html": [ 307 | "
\n", 308 | "\n", 321 | "\n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | "
querylabel
0毛泽东的诗哦。poetry
1有房有车吗微笑chat
22013年亚洲冠军联赛恒广州恒大比赛时间。match
3若相惜不弃下一句是什么?poetry
4苹果翻译成英语translation
\n", 357 | "
" 358 | ], 359 | "text/plain": [ 360 | " query label\n", 361 | "0 毛泽东的诗哦。 poetry\n", 362 | "1 有房有车吗微笑 chat\n", 363 | "2 2013年亚洲冠军联赛恒广州恒大比赛时间。 match\n", 364 | "3 若相惜不弃下一句是什么? poetry\n", 365 | "4 苹果翻译成英语 translation" 366 | ] 367 | }, 368 | "execution_count": 7, 369 | "metadata": {}, 370 | "output_type": "execute_result" 371 | } 372 | ], 373 | "source": [ 374 | "test_data_df.head()" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 8, 380 | "metadata": {}, 381 | "outputs": [ 382 | { 383 | "data": { 384 | "text/html": [ 385 | "
\n", 386 | "\n", 399 | "\n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | "
querylabel
count22992299
unique229931
top中国新闻网网站chat
freq1455
\n", 430 | "
" 431 | ], 432 | "text/plain": [ 433 | " query label\n", 434 | "count 2299 2299\n", 435 | "unique 2299 31\n", 436 | "top 中国新闻网网站 chat\n", 437 | "freq 1 455" 438 | ] 439 | }, 440 | "execution_count": 8, 441 | "metadata": {}, 442 | "output_type": "execute_result" 443 | } 444 | ], 445 | "source": [ 446 | "train_data_df.describe()" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 9, 452 | "metadata": {}, 453 | "outputs": [ 454 | { 455 | "data": { 456 | "text/html": [ 457 | "
\n", 458 | "\n", 471 | "\n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | "
querylabel
count770770
unique77031
top查下安徽电视台今天节目单chat
freq1154
\n", 502 | "
" 503 | ], 504 | "text/plain": [ 505 | " query label\n", 506 | "count 770 770\n", 507 | "unique 770 31\n", 508 | "top 查下安徽电视台今天节目单 chat\n", 509 | "freq 1 154" 510 | ] 511 | }, 512 | "execution_count": 9, 513 | "metadata": {}, 514 | "output_type": "execute_result" 515 | } 516 | ], 517 | "source": [ 518 | "test_data_df.describe()" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": 10, 524 | "metadata": {}, 525 | "outputs": [], 526 | "source": [ 527 | "# 获取所以标签,也就是分类的类别\n", 528 | "labels = list(set(train_data_df['label'].tolist()))" 529 | ] 530 | }, 531 | { 532 | "cell_type": "raw", 533 | "metadata": {}, 534 | "source": [ 535 | "# All labels\n", 536 | "labels = ['website', 'tvchannel', 'lottery', 'chat', 'match',\n", 537 | " 'datetime', 'weather', 'bus', 'novel', 'video', 'riddle',\n", 538 | " 'calc', 'telephone', 'health', 'contacts', 'epg', 'app', 'music',\n", 539 | " 'cookbook', 'stock', 'map', 'message', 'poetry', 'cinemas', 'news',\n", 540 | " 'flight', 'translation', 'train', 'schedule', 'radio', 'email']" 541 | ] 542 | }, 543 | { 544 | "cell_type": "code", 545 | "execution_count": 11, 546 | "metadata": {}, 547 | "outputs": [ 548 | { 549 | "name": "stdout", 550 | "output_type": "stream", 551 | "text": [ 552 | "label_numbers:\t 31\n" 553 | ] 554 | } 555 | ], 556 | "source": [ 557 | "label_numbers = len(labels)\n", 558 | "print('label_numbers:\\t', label_numbers)" 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "metadata": {}, 564 | "source": [ 565 | "## 标签和对应ID的映射字典" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": 12, 571 | "metadata": {}, 572 | "outputs": [], 573 | "source": [ 574 | "label_2_index_dict = dict([(label, index) for index, label in enumerate(labels)])\n", 575 | "index_2_label_dict = dict([(index, label) for index, label in enumerate(labels)])" 576 | ] 577 | }, 578 | { 579 | "cell_type": "markdown", 580 | "metadata": {}, 581 | "source": [ 582 | "---" 583 | ] 584 | }, 585 | { 586 | "cell_type": "markdown", 587 | "metadata": {}, 588 | "source": [ 589 | "## [结巴分词](https://github.com/fxsjy/jieba)示例,下面将使用结巴分词对原数据进行处理" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": 13, 595 | "metadata": {}, 596 | "outputs": [ 597 | { 598 | "name": "stderr", 599 | "output_type": "stream", 600 | "text": [ 601 | "Building prefix dict from the default dictionary ...\n", 602 | "Loading model from cache /tmp/jieba.cache\n", 603 | "Loading model cost 0.903 seconds.\n", 604 | "Prefix dict has been built succesfully.\n" 605 | ] 606 | }, 607 | { 608 | "name": "stdout", 609 | "output_type": "stream", 610 | "text": [ 611 | "['他', '来到', '了', '网易', '杭研', '大厦']\n" 612 | ] 613 | } 614 | ], 615 | "source": [ 616 | "seg_list = jieba.cut(\"他来到了网易杭研大厦\") # 默认是精确模式\n", 617 | "print(list(seg_list))" 618 | ] 619 | }, 620 | { 621 | "cell_type": "markdown", 622 | "metadata": {}, 623 | "source": [ 624 | "---" 625 | ] 626 | }, 627 | { 628 | "cell_type": "markdown", 629 | "metadata": {}, 630 | "source": [ 631 | "# 序列化" 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": 14, 637 | "metadata": {}, 638 | "outputs": [], 639 | "source": [ 640 | "def use_jieba_cut(a_sentence):\n", 641 | " return list(jieba.cut(a_sentence))\n", 642 | "\n", 643 | "train_data_df['cut_query'] = train_data_df['query'].apply(use_jieba_cut)\n", 644 | "test_data_df['cut_query'] = test_data_df['query'].apply(use_jieba_cut)" 645 | ] 646 | }, 647 | { 648 | "cell_type": "code", 649 | "execution_count": 15, 650 | "metadata": {}, 651 | 
"outputs": [ 652 | { 653 | "data": { 654 | "text/html": [ 655 | "
\n", 656 | "\n", 669 | "\n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | "
querylabelcut_query
0今天东莞天气如何weather[今天, 东莞, 天气, 如何]
1从观音桥到重庆市图书馆怎么走map[从, 观音桥, 到, 重庆市, 图书馆, 怎么, 走]
2鸭蛋怎么腌?cookbook[鸭蛋, 怎么, 腌, ?]
3怎么治疗牛皮癣health[怎么, 治疗, 牛皮癣]
4唠什么chat[唠, 什么]
5阳澄湖大闸蟹的做法。cookbook[阳澄湖, 大闸蟹, 的, 做法, 。]
6昆山大润发在哪里map[昆山, 大润发, 在, 哪里]
7红烧肉怎么做?嗯?cookbook[红烧肉, 怎么, 做, ?, 嗯, ?]
8南京到厦门的火车票train[南京, 到, 厦门, 的, 火车票]
96的平方calc[6, 的, 平方]
\n", 741 | "
" 742 | ], 743 | "text/plain": [ 744 | " query label cut_query\n", 745 | "0 今天东莞天气如何 weather [今天, 东莞, 天气, 如何]\n", 746 | "1 从观音桥到重庆市图书馆怎么走 map [从, 观音桥, 到, 重庆市, 图书馆, 怎么, 走]\n", 747 | "2 鸭蛋怎么腌? cookbook [鸭蛋, 怎么, 腌, ?]\n", 748 | "3 怎么治疗牛皮癣 health [怎么, 治疗, 牛皮癣]\n", 749 | "4 唠什么 chat [唠, 什么]\n", 750 | "5 阳澄湖大闸蟹的做法。 cookbook [阳澄湖, 大闸蟹, 的, 做法, 。]\n", 751 | "6 昆山大润发在哪里 map [昆山, 大润发, 在, 哪里]\n", 752 | "7 红烧肉怎么做?嗯? cookbook [红烧肉, 怎么, 做, ?, 嗯, ?]\n", 753 | "8 南京到厦门的火车票 train [南京, 到, 厦门, 的, 火车票]\n", 754 | "9 6的平方 calc [6, 的, 平方]" 755 | ] 756 | }, 757 | "execution_count": 15, 758 | "metadata": {}, 759 | "output_type": "execute_result" 760 | } 761 | ], 762 | "source": [ 763 | "train_data_df.head(10)" 764 | ] 765 | }, 766 | { 767 | "cell_type": "code", 768 | "execution_count": 16, 769 | "metadata": {}, 770 | "outputs": [], 771 | "source": [ 772 | "# 获取数据的所有词汇\n", 773 | "def get_all_vocab_from_data(data, colunm_name):\n", 774 | " train_vocab_list = []\n", 775 | " max_cut_query_lenth = 0\n", 776 | " for cut_query in data[colunm_name]:\n", 777 | " if len(cut_query) > max_cut_query_lenth:\n", 778 | " max_cut_query_lenth = len(cut_query)\n", 779 | " train_vocab_list += cut_query\n", 780 | " return train_vocab_list, max_cut_query_lenth " 781 | ] 782 | }, 783 | { 784 | "cell_type": "code", 785 | "execution_count": 17, 786 | "metadata": {}, 787 | "outputs": [], 788 | "source": [ 789 | "train_vocab_list, max_cut_query_lenth = get_all_vocab_from_data(train_data_df, 'cut_query')" 790 | ] 791 | }, 792 | { 793 | "cell_type": "code", 794 | "execution_count": 18, 795 | "metadata": {}, 796 | "outputs": [ 797 | { 798 | "name": "stdout", 799 | "output_type": "stream", 800 | "text": [ 801 | "Number of words:\t 11498\n" 802 | ] 803 | } 804 | ], 805 | "source": [ 806 | "print('Number of words:\\t', len(train_vocab_list))" 807 | ] 808 | }, 809 | { 810 | "cell_type": "code", 811 | "execution_count": 19, 812 | "metadata": {}, 813 | "outputs": [ 814 | { 815 | "name": "stdout", 816 | "output_type": "stream", 817 | "text": [ 818 | "max_cut_query_lenth:\t 26\n" 819 | ] 820 | } 821 | ], 822 | "source": [ 823 | "print('max_cut_query_lenth:\\t', max_cut_query_lenth)" 824 | ] 825 | }, 826 | { 827 | "cell_type": "code", 828 | "execution_count": 20, 829 | "metadata": {}, 830 | "outputs": [], 831 | "source": [ 832 | "test_vocab_list, test_max_cut_query_lenth = get_all_vocab_from_data(train_data_df, 'cut_query')" 833 | ] 834 | }, 835 | { 836 | "cell_type": "code", 837 | "execution_count": 21, 838 | "metadata": {}, 839 | "outputs": [ 840 | { 841 | "name": "stdout", 842 | "output_type": "stream", 843 | "text": [ 844 | "test_max_cut_query_lenth:\t 26\n" 845 | ] 846 | } 847 | ], 848 | "source": [ 849 | "print('test_max_cut_query_lenth:\\t', test_max_cut_query_lenth)" 850 | ] 851 | }, 852 | { 853 | "cell_type": "code", 854 | "execution_count": 22, 855 | "metadata": {}, 856 | "outputs": [ 857 | { 858 | "data": { 859 | "text/plain": [ 860 | "['今天', '东莞', '天气', '如何', '从', '观音桥', '到', '重庆市', '图书馆', '怎么']" 861 | ] 862 | }, 863 | "execution_count": 22, 864 | "metadata": {}, 865 | "output_type": "execute_result" 866 | } 867 | ], 868 | "source": [ 869 | "train_vocab_list[:10]" 870 | ] 871 | }, 872 | { 873 | "cell_type": "code", 874 | "execution_count": 23, 875 | "metadata": {}, 876 | "outputs": [], 877 | "source": [ 878 | "train_vocab_counter = collections.Counter(train_vocab_list)" 879 | ] 880 | }, 881 | { 882 | "cell_type": "code", 883 | "execution_count": 24, 884 | "metadata": {}, 885 | "outputs": [ 886 | { 887 | "name": "stdout", 888 | "output_type": 
"stream", 889 | "text": [ 890 | "Number of different words:\t 2887\n" 891 | ] 892 | } 893 | ], 894 | "source": [ 895 | "print('Number of different words:\\t', len(train_vocab_counter.keys()))" 896 | ] 897 | }, 898 | { 899 | "cell_type": "markdown", 900 | "metadata": {}, 901 | "source": [ 902 | "## 不同种类的词汇个数,预留一个位置给不存在的词汇(不存在的词汇标记为0) " 903 | ] 904 | }, 905 | { 906 | "cell_type": "code", 907 | "execution_count": 26, 908 | "metadata": {}, 909 | "outputs": [], 910 | "source": [ 911 | "max_features = len(train_vocab_counter.keys()) + 1" 912 | ] 913 | }, 914 | { 915 | "cell_type": "code", 916 | "execution_count": 27, 917 | "metadata": {}, 918 | "outputs": [ 919 | { 920 | "name": "stdout", 921 | "output_type": "stream", 922 | "text": [ 923 | "2888\n" 924 | ] 925 | } 926 | ], 927 | "source": [ 928 | "print(max_features)" 929 | ] 930 | }, 931 | { 932 | "cell_type": "code", 933 | "execution_count": 28, 934 | "metadata": {}, 935 | "outputs": [ 936 | { 937 | "data": { 938 | "text/plain": [ 939 | "[('的', 605),\n", 940 | " ('。', 341),\n", 941 | " ('我', 320),\n", 942 | " ('你', 297),\n", 943 | " ('怎么', 273),\n", 944 | " ('?', 251),\n", 945 | " ('什么', 210),\n", 946 | " ('到', 165),\n", 947 | " ('给', 154),\n", 948 | " ('做', 148)]" 949 | ] 950 | }, 951 | "execution_count": 28, 952 | "metadata": {}, 953 | "output_type": "execute_result" 954 | } 955 | ], 956 | "source": [ 957 | "# 10 words with the highest frequency\n", 958 | "train_vocab_counter.most_common(10)" 959 | ] 960 | }, 961 | { 962 | "cell_type": "markdown", 963 | "metadata": {}, 964 | "source": [ 965 | "## 统计低频词语" 966 | ] 967 | }, 968 | { 969 | "cell_type": "code", 970 | "execution_count": 29, 971 | "metadata": {}, 972 | "outputs": [ 973 | { 974 | "name": "stdout", 975 | "output_type": "stream", 976 | "text": [ 977 | "word_times_zero:\t 1978\n", 978 | "word_times_zero/all:\t 0.685140284031867\n" 979 | ] 980 | } 981 | ], 982 | "source": [ 983 | "word_times_zero = 0\n", 984 | "for word, word_times in train_vocab_counter.items():\n", 985 | " if word_times <=1:\n", 986 | " word_times_zero+=1\n", 987 | "print('word_times_zero:\\t', word_times_zero)\n", 988 | "print('word_times_zero/all:\\t', word_times_zero/len(train_vocab_counter))" 989 | ] 990 | }, 991 | { 992 | "cell_type": "markdown", 993 | "metadata": {}, 994 | "source": [ 995 | "## 制作词汇字典" 996 | ] 997 | }, 998 | { 999 | "cell_type": "code", 1000 | "execution_count": 30, 1001 | "metadata": {}, 1002 | "outputs": [], 1003 | "source": [ 1004 | "def create_train_vocab_dict(train_vocab_counter):\n", 1005 | " word_2_index, index_2_word = {}, {}\n", 1006 | " # Reserve 0 for masking via pad_sequences\n", 1007 | " index_number = 1\n", 1008 | " for word, word_times in train_vocab_counter.most_common():\n", 1009 | " word_2_index[word] = index_number\n", 1010 | " index_2_word[index_number] = word\n", 1011 | " index_number += 1\n", 1012 | " return word_2_index, index_2_word " 1013 | ] 1014 | }, 1015 | { 1016 | "cell_type": "code", 1017 | "execution_count": 31, 1018 | "metadata": {}, 1019 | "outputs": [], 1020 | "source": [ 1021 | "word_2_index_dict, index_2_word_dict = create_train_vocab_dict(train_vocab_counter)" 1022 | ] 1023 | }, 1024 | { 1025 | "cell_type": "code", 1026 | "execution_count": 32, 1027 | "metadata": {}, 1028 | "outputs": [ 1029 | { 1030 | "name": "stdout", 1031 | "output_type": "stream", 1032 | "text": [ 1033 | "1 2\n" 1034 | ] 1035 | } 1036 | ], 1037 | "source": [ 1038 | "print(word_2_index_dict['的'], word_2_index_dict['。'])" 1039 | ] 1040 | }, 1041 | { 1042 | "cell_type": "code", 1043 | 
"execution_count": 33, 1044 | "metadata": {}, 1045 | "outputs": [ 1046 | { 1047 | "name": "stdout", 1048 | "output_type": "stream", 1049 | "text": [ 1050 | "的 。\n" 1051 | ] 1052 | } 1053 | ], 1054 | "source": [ 1055 | "print(index_2_word_dict[1], index_2_word_dict[2])" 1056 | ] 1057 | }, 1058 | { 1059 | "cell_type": "code", 1060 | "execution_count": 34, 1061 | "metadata": {}, 1062 | "outputs": [ 1063 | { 1064 | "name": "stdout", 1065 | "output_type": "stream", 1066 | "text": [ 1067 | "今天东莞天气如何 weather ['今天', '东莞', '天气', '如何']\n", 1068 | "从观音桥到重庆市图书馆怎么走 map ['从', '观音桥', '到', '重庆市', '图书馆', '怎么', '走']\n", 1069 | "鸭蛋怎么腌? cookbook ['鸭蛋', '怎么', '腌', '?']\n", 1070 | "怎么治疗牛皮癣 health ['怎么', '治疗', '牛皮癣']\n", 1071 | "唠什么 chat ['唠', '什么']\n", 1072 | "阳澄湖大闸蟹的做法。 cookbook ['阳澄湖', '大闸蟹', '的', '做法', '。']\n", 1073 | "昆山大润发在哪里 map ['昆山', '大润发', '在', '哪里']\n", 1074 | "红烧肉怎么做?嗯? cookbook ['红烧肉', '怎么', '做', '?', '嗯', '?']\n", 1075 | "南京到厦门的火车票 train ['南京', '到', '厦门', '的', '火车票']\n", 1076 | "6的平方 calc ['6', '的', '平方']\n" 1077 | ] 1078 | } 1079 | ], 1080 | "source": [ 1081 | "pq= 0\n", 1082 | "for index, row in train_data_df.iterrows():\n", 1083 | " print(row[0], row[1], row[2])\n", 1084 | " pq+=1\n", 1085 | " if pq==10:\n", 1086 | " break" 1087 | ] 1088 | }, 1089 | { 1090 | "cell_type": "code", 1091 | "execution_count": 35, 1092 | "metadata": {}, 1093 | "outputs": [ 1094 | { 1095 | "data": { 1096 | "text/plain": [ 1097 | "0" 1098 | ] 1099 | }, 1100 | "execution_count": 35, 1101 | "metadata": {}, 1102 | "output_type": "execute_result" 1103 | } 1104 | ], 1105 | "source": [ 1106 | "word_2_index_dict.get('的2', 0)" 1107 | ] 1108 | }, 1109 | { 1110 | "cell_type": "code", 1111 | "execution_count": 36, 1112 | "metadata": {}, 1113 | "outputs": [], 1114 | "source": [ 1115 | "def vectorize_data(data, label_2_index_dict, word_2_index_dict, max_cut_query_lenth):\n", 1116 | " x_train = []\n", 1117 | " y_train = []\n", 1118 | " for index, row in data.iterrows():\n", 1119 | " query_sentence = row[2]\n", 1120 | " label = row[1]\n", 1121 | " # 字典找不到的情况下用 0 填充\n", 1122 | " x = [word_2_index_dict.get(w, 0) for w in query_sentence]\n", 1123 | " y = [label_2_index_dict[label]]\n", 1124 | " x_train.append(x)\n", 1125 | " y_train.append(y)\n", 1126 | " return (pad_sequences(x_train, maxlen=max_cut_query_lenth),\n", 1127 | " pad_sequences(y_train, maxlen=1))" 1128 | ] 1129 | }, 1130 | { 1131 | "cell_type": "code", 1132 | "execution_count": 37, 1133 | "metadata": {}, 1134 | "outputs": [], 1135 | "source": [ 1136 | "x_train, y_train = vectorize_data(train_data_df, label_2_index_dict, word_2_index_dict, max_cut_query_lenth)" 1137 | ] 1138 | }, 1139 | { 1140 | "cell_type": "code", 1141 | "execution_count": 38, 1142 | "metadata": {}, 1143 | "outputs": [], 1144 | "source": [ 1145 | "x_test, y_test = vectorize_data(test_data_df, label_2_index_dict, word_2_index_dict, test_max_cut_query_lenth)" 1146 | ] 1147 | }, 1148 | { 1149 | "cell_type": "code", 1150 | "execution_count": 39, 1151 | "metadata": {}, 1152 | "outputs": [ 1153 | { 1154 | "name": "stdout", 1155 | "output_type": "stream", 1156 | "text": [ 1157 | "[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 1158 | " 0 0 0 0 33 318 27 90] [7]\n" 1159 | ] 1160 | } 1161 | ], 1162 | "source": [ 1163 | "print(x_train[0], y_train[0])" 1164 | ] 1165 | }, 1166 | { 1167 | "cell_type": "code", 1168 | "execution_count": 40, 1169 | "metadata": {}, 1170 | "outputs": [], 1171 | "source": [ 1172 | "y_train = to_categorical(y_train, label_numbers)\n", 1173 | "y_test = to_categorical(y_test, label_numbers)" 1174 | ] 
1175 | }, 1176 | { 1177 | "cell_type": "code", 1178 | "execution_count": 41, 1179 | "metadata": {}, 1180 | "outputs": [ 1181 | { 1182 | "name": "stdout", 1183 | "output_type": "stream", 1184 | "text": [ 1185 | "(2299, 26) (2299, 31)\n" 1186 | ] 1187 | } 1188 | ], 1189 | "source": [ 1190 | "print(x_train.shape, y_train.shape)" 1191 | ] 1192 | }, 1193 | { 1194 | "cell_type": "code", 1195 | "execution_count": 42, 1196 | "metadata": {}, 1197 | "outputs": [ 1198 | { 1199 | "name": "stdout", 1200 | "output_type": "stream", 1201 | "text": [ 1202 | "(770, 26) (770, 31)\n" 1203 | ] 1204 | } 1205 | ], 1206 | "source": [ 1207 | "print(x_test.shape, y_test.shape)" 1208 | ] 1209 | }, 1210 | { 1211 | "cell_type": "markdown", 1212 | "metadata": {}, 1213 | "source": [ 1214 | "# 存储预处理过的数据" 1215 | ] 1216 | }, 1217 | { 1218 | "cell_type": "code", 1219 | "execution_count": 43, 1220 | "metadata": {}, 1221 | "outputs": [ 1222 | { 1223 | "name": "stdout", 1224 | "output_type": "stream", 1225 | "text": [ 1226 | "<class 'numpy.ndarray'>\n" 1227 | ] 1228 | } 1229 | ], 1230 | "source": [ 1231 | "print(type(x_test))" 1232 | ] 1233 | }, 1234 | { 1235 | "cell_type": "code", 1236 | "execution_count": 44, 1237 | "metadata": {}, 1238 | "outputs": [], 1239 | "source": [ 1240 | "np.savez(\"preprocessed_data\", x_train, y_train, x_test, y_test)" 1241 | ] 1242 | }, 1243 | { 1244 | "cell_type": "markdown", 1245 | "metadata": {}, 1246 | "source": [ 1247 | "## 直接加载预处理的数据" 1248 | ] 1249 | }, 1250 | { 1251 | "cell_type": "code", 1252 | "execution_count": 4, 1253 | "metadata": {}, 1254 | "outputs": [], 1255 | "source": [ 1256 | "# 是否直接加载已经预处理好的数据(True 表示跳过前面的预处理步骤)\n", 1257 | "use_preprocessed_data = True\n", 1258 | "\n", 1259 | "if use_preprocessed_data:\n", 1260 | " preprocessed_data = np.load('preprocessed_data.npz')\n", 1261 | " x_train, y_train, x_test, y_test = preprocessed_data['arr_0'], preprocessed_data['arr_1'], preprocessed_data['arr_2'], preprocessed_data['arr_3']" 1262 | ] 1263 | }, 1264 | { 1265 | "cell_type": "code", 1266 | "execution_count": 5, 1267 | "metadata": {}, 1268 | "outputs": [ 1269 | { 1270 | "name": "stdout", 1271 | "output_type": "stream", 1272 | "text": [ 1273 | "(2299, 26) (2299, 31)\n" 1274 | ] 1275 | } 1276 | ], 1277 | "source": [ 1278 | "print(x_train.shape, y_train.shape)" 1279 | ] 1280 | }, 1281 | { 1282 | "cell_type": "markdown", 1283 | "metadata": {}, 1284 | "source": [ 1285 | "# 设计模型" 1286 | ] 1287 | }, 1288 | { 1289 | "cell_type": "code", 1290 | "execution_count": 6, 1291 | "metadata": {}, 1292 | "outputs": [], 1293 | "source": [ 1294 | "def create_SMP2018_lstm_model(max_features, max_cut_query_lenth, label_numbers):\n", 1295 | " model = Sequential()\n", 1296 | " model.add(Embedding(input_dim=max_features, output_dim=32, input_length=max_cut_query_lenth))\n", 1297 | " model.add(LSTM(units=64, dropout=0.2, recurrent_dropout=0.2))\n", 1298 | " model.add(Dense(label_numbers, activation='softmax'))\n", 1299 | " # try using different optimizers and different optimizer configs\n", 1300 | " model.compile(loss='categorical_crossentropy',\n", 1301 | " optimizer='adam',\n", 1302 | " metrics=[f1])\n", 1303 | "\n", 1304 | " plot_model(model, to_file='SMP2018_lstm_model.png', show_shapes=True)\n", 1305 | " \n", 1306 | " return model" 1307 | ] 1308 | }, 1309 | { 1310 | "cell_type": "markdown", 1311 | "metadata": {}, 1312 | "source": [ 1313 | "# 训练模型" 1314 | ] 1315 | }, 1316 | { 1317 | "cell_type": "code", 1318 | "execution_count": 7, 1319 | "metadata": {}, 1320 | "outputs": [ 1321 | { 1322 | "name": "stdout", 1323 | "output_type":
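
`create_SMP2018_lstm_model` compiles with a custom `f1` metric that is defined earlier in the full notebook (outside this excerpt). For reference, a commonly used batch-wise Keras implementation looks like the sketch below; it may differ from the notebook's actual definition, and it only approximates F1 per batch rather than over the whole validation set:

```python
from keras import backend as K

def f1(y_true, y_pred):
    # Batch-wise precision and recall from rounded predictions.
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    recall = true_positives / (possible_positives + K.epsilon())
    return 2 * (precision * recall) / (precision + recall + K.epsilon())
```
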
"stream", 1324 | "text": [ 1325 | "not find max_features variable, use default max_features values:\t2888\n", 1326 | "not find max_cut_query_lenth, use default max_features values:\t26\n", 1327 | "not find label_numbers, use default max_features values:\t31\n" 1328 | ] 1329 | } 1330 | ], 1331 | "source": [ 1332 | "if 'max_features' not in dir():\n", 1333 | " max_features = 2888\n", 1334 | " print('not find max_features variable, use default max_features values:\\t{}'.format(max_features))\n", 1335 | "if 'max_cut_query_lenth' not in dir():\n", 1336 | " max_cut_query_lenth = 26\n", 1337 | " print('not find max_cut_query_lenth, use default max_features values:\\t{}'.format(max_cut_query_lenth))\n", 1338 | "if 'label_numbers' not in dir():\n", 1339 | " label_numbers = 31\n", 1340 | " print('not find label_numbers, use default max_features values:\\t{}'.format(label_numbers))" 1341 | ] 1342 | }, 1343 | { 1344 | "cell_type": "code", 1345 | "execution_count": 8, 1346 | "metadata": {}, 1347 | "outputs": [], 1348 | "source": [ 1349 | "model = create_SMP2018_lstm_model(max_features, max_cut_query_lenth, label_numbers)" 1350 | ] 1351 | }, 1352 | { 1353 | "cell_type": "code", 1354 | "execution_count": 9, 1355 | "metadata": {}, 1356 | "outputs": [], 1357 | "source": [ 1358 | "batch_size = 20\n", 1359 | "epochs = 300" 1360 | ] 1361 | }, 1362 | { 1363 | "cell_type": "code", 1364 | "execution_count": null, 1365 | "metadata": {}, 1366 | "outputs": [ 1367 | { 1368 | "name": "stdout", 1369 | "output_type": "stream", 1370 | "text": [ 1371 | "Train...\n", 1372 | "Train on 1839 samples, validate on 460 samples\n", 1373 | "Epoch 1/300\n", 1374 | "1839/1839 [==============================] - 6s 3ms/step - loss: 3.1404 - f1: 0.0000e+00 - val_loss: 2.9658 - val_f1: 0.0000e+00\n", 1375 | "Epoch 2/300\n", 1376 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.8634 - f1: 0.0404 - val_loss: 2.5949 - val_f1: 0.1618\n", 1377 | "Epoch 3/300\n", 1378 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.3841 - f1: 0.3354 - val_loss: 2.1469 - val_f1: 0.4080\n", 1379 | "Epoch 4/300\n", 1380 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.9530 - f1: 0.4240 - val_loss: 1.8311 - val_f1: 0.4429\n", 1381 | "Epoch 5/300\n", 1382 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.5153 - f1: 0.5092 - val_loss: 1.4660 - val_f1: 0.5133\n", 1383 | "Epoch 6/300\n", 1384 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.1055 - f1: 0.6257 - val_loss: 1.2311 - val_f1: 0.6446\n", 1385 | "Epoch 7/300\n", 1386 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.7985 - f1: 0.7558 - val_loss: 1.0519 - val_f1: 0.6857\n", 1387 | "Epoch 8/300\n", 1388 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.5887 - f1: 0.8245 - val_loss: 0.9113 - val_f1: 0.7443\n", 1389 | "Epoch 9/300\n", 1390 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.4365 - f1: 0.8729 - val_loss: 0.8589 - val_f1: 0.7589\n", 1391 | "Epoch 10/300\n", 1392 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.3196 - f1: 0.9178 - val_loss: 0.8198 - val_f1: 0.7948\n", 1393 | "Epoch 11/300\n", 1394 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.2584 - f1: 0.9379 - val_loss: 0.7777 - val_f1: 0.8046\n", 1395 | "Epoch 12/300\n", 1396 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.1941 - f1: 0.9593 - val_loss: 0.7518 - val_f1: 0.8343\n", 1397 | "Epoch 13/300\n", 1398 | 
"1839/1839 [==============================] - 5s 3ms/step - loss: 0.1540 - f1: 0.9726 - val_loss: 0.7506 - val_f1: 0.8322\n", 1399 | "Epoch 14/300\n", 1400 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.1220 - f1: 0.9808 - val_loss: 0.7529 - val_f1: 0.8195\n", 1401 | "Epoch 15/300\n", 1402 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.1044 - f1: 0.9837 - val_loss: 0.7723 - val_f1: 0.8226\n", 1403 | "Epoch 16/300\n", 1404 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0884 - f1: 0.9868 - val_loss: 0.7465 - val_f1: 0.8326\n", 1405 | "Epoch 17/300\n", 1406 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0722 - f1: 0.9900 - val_loss: 0.7687 - val_f1: 0.8240\n", 1407 | "Epoch 18/300\n", 1408 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0594 - f1: 0.9941 - val_loss: 0.7584 - val_f1: 0.8252\n", 1409 | "Epoch 19/300\n", 1410 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0497 - f1: 0.9947 - val_loss: 0.7572 - val_f1: 0.8302\n", 1411 | "Epoch 20/300\n", 1412 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0437 - f1: 0.9958 - val_loss: 0.7714 - val_f1: 0.8260\n", 1413 | "Epoch 21/300\n", 1414 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0398 - f1: 0.9972 - val_loss: 0.7631 - val_f1: 0.8246\n", 1415 | "Epoch 22/300\n", 1416 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0321 - f1: 0.9973 - val_loss: 0.7698 - val_f1: 0.8276\n", 1417 | "Epoch 23/300\n", 1418 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0313 - f1: 0.9967 - val_loss: 0.7809 - val_f1: 0.8288\n", 1419 | "Epoch 24/300\n", 1420 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0272 - f1: 0.9989 - val_loss: 0.7797 - val_f1: 0.8218\n", 1421 | "Epoch 25/300\n", 1422 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0240 - f1: 0.9986 - val_loss: 0.7531 - val_f1: 0.8345\n", 1423 | "Epoch 26/300\n", 1424 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0236 - f1: 0.9978 - val_loss: 0.7988 - val_f1: 0.8201\n", 1425 | "Epoch 27/300\n", 1426 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0207 - f1: 0.9986 - val_loss: 0.8156 - val_f1: 0.8259\n", 1427 | "Epoch 28/300\n", 1428 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0185 - f1: 0.9995 - val_loss: 0.7938 - val_f1: 0.8185\n", 1429 | "Epoch 29/300\n", 1430 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0189 - f1: 0.9981 - val_loss: 0.7839 - val_f1: 0.8169\n", 1431 | "Epoch 30/300\n", 1432 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0145 - f1: 1.0000 - val_loss: 0.8001 - val_f1: 0.8296\n", 1433 | "Epoch 31/300\n", 1434 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0163 - f1: 0.9984 - val_loss: 0.8265 - val_f1: 0.8116\n", 1435 | "Epoch 32/300\n", 1436 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0167 - f1: 0.9970 - val_loss: 0.8117 - val_f1: 0.8320\n", 1437 | "Epoch 33/300\n", 1438 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0131 - f1: 0.9997 - val_loss: 0.8121 - val_f1: 0.8224\n", 1439 | "Epoch 34/300\n", 1440 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0098 - f1: 1.0000 - val_loss: 0.8158 - val_f1: 0.8277\n", 1441 | "Epoch 35/300\n", 1442 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0133 - f1: 
0.9986 - val_loss: 0.8314 - val_f1: 0.8242\n", 1443 | "Epoch 36/300\n", 1444 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0099 - f1: 1.0000 - val_loss: 0.8447 - val_f1: 0.8231\n", 1445 | "Epoch 37/300\n", 1446 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0081 - f1: 1.0000 - val_loss: 0.8237 - val_f1: 0.8312\n", 1447 | "Epoch 38/300\n", 1448 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0117 - f1: 0.9986 - val_loss: 0.8239 - val_f1: 0.8155\n", 1449 | "Epoch 39/300\n", 1450 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0082 - f1: 1.0000 - val_loss: 0.8470 - val_f1: 0.8204\n", 1451 | "Epoch 40/300\n", 1452 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0072 - f1: 1.0000 - val_loss: 0.8471 - val_f1: 0.8262\n", 1453 | "Epoch 41/300\n", 1454 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0098 - f1: 0.9992 - val_loss: 0.8262 - val_f1: 0.8323\n", 1455 | "Epoch 42/300\n", 1456 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0110 - f1: 0.9978 - val_loss: 0.8577 - val_f1: 0.8205\n", 1457 | "Epoch 43/300\n", 1458 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0069 - f1: 0.9997 - val_loss: 0.8587 - val_f1: 0.8226\n", 1459 | "Epoch 44/300\n", 1460 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0059 - f1: 0.9995 - val_loss: 0.8217 - val_f1: 0.8253\n", 1461 | "Epoch 45/300\n", 1462 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0054 - f1: 0.9995 - val_loss: 0.8342 - val_f1: 0.8269\n", 1463 | "Epoch 46/300\n", 1464 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0044 - f1: 1.0000 - val_loss: 0.8494 - val_f1: 0.8310\n", 1465 | "Epoch 47/300\n", 1466 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0040 - f1: 1.0000 - val_loss: 0.8496 - val_f1: 0.8341\n", 1467 | "Epoch 48/300\n", 1468 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0044 - f1: 1.0000 - val_loss: 0.8640 - val_f1: 0.8269\n", 1469 | "Epoch 49/300\n", 1470 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0049 - f1: 1.0000 - val_loss: 0.8453 - val_f1: 0.8321\n", 1471 | "Epoch 50/300\n", 1472 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0033 - f1: 1.0000 - val_loss: 0.9022 - val_f1: 0.8279\n", 1473 | "Epoch 51/300\n", 1474 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0074 - f1: 0.9989 - val_loss: 0.9145 - val_f1: 0.8171\n", 1475 | "Epoch 52/300\n", 1476 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0050 - f1: 0.9995 - val_loss: 0.9031 - val_f1: 0.8159\n", 1477 | "Epoch 53/300\n", 1478 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0083 - f1: 0.9981 - val_loss: 0.9456 - val_f1: 0.8084\n", 1479 | "Epoch 54/300\n", 1480 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0057 - f1: 0.9995 - val_loss: 0.8972 - val_f1: 0.8265\n", 1481 | "Epoch 55/300\n", 1482 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0027 - f1: 1.0000 - val_loss: 0.8921 - val_f1: 0.8284\n", 1483 | "Epoch 56/300\n", 1484 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0027 - f1: 1.0000 - val_loss: 0.8885 - val_f1: 0.8318\n", 1485 | "Epoch 57/300\n", 1486 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0024 - f1: 1.0000 - val_loss: 0.9059 - val_f1: 0.8283\n", 1487 | "Epoch 58/300\n", 1488 | 
"1839/1839 [==============================] - 5s 3ms/step - loss: 0.0028 - f1: 1.0000 - val_loss: 0.9045 - val_f1: 0.8233\n", 1489 | "Epoch 59/300\n", 1490 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0022 - f1: 1.0000 - val_loss: 0.9238 - val_f1: 0.8302\n", 1491 | "Epoch 60/300\n", 1492 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0024 - f1: 0.9997 - val_loss: 0.9383 - val_f1: 0.8209\n", 1493 | "Epoch 61/300\n", 1494 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0030 - f1: 1.0000 - val_loss: 0.9409 - val_f1: 0.8157\n", 1495 | "Epoch 62/300\n", 1496 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0023 - f1: 1.0000 - val_loss: 0.9529 - val_f1: 0.8255\n", 1497 | "Epoch 63/300\n", 1498 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0046 - f1: 0.9997 - val_loss: 0.9899 - val_f1: 0.8158\n", 1499 | "Epoch 64/300\n", 1500 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0040 - f1: 0.9995 - val_loss: 0.9625 - val_f1: 0.8138\n", 1501 | "Epoch 65/300\n", 1502 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0039 - f1: 0.9997 - val_loss: 0.9493 - val_f1: 0.8135\n", 1503 | "Epoch 66/300\n", 1504 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0024 - f1: 1.0000 - val_loss: 0.9872 - val_f1: 0.8151\n", 1505 | "Epoch 67/300\n", 1506 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0058 - f1: 0.9989 - val_loss: 0.9106 - val_f1: 0.8146\n", 1507 | "Epoch 68/300\n", 1508 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0045 - f1: 0.9997 - val_loss: 0.9383 - val_f1: 0.8191\n", 1509 | "Epoch 69/300\n", 1510 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0018 - f1: 1.0000 - val_loss: 0.9366 - val_f1: 0.8184\n", 1511 | "Epoch 70/300\n", 1512 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0022 - f1: 1.0000 - val_loss: 1.0150 - val_f1: 0.8079\n", 1513 | "Epoch 71/300\n", 1514 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0018 - f1: 1.0000 - val_loss: 0.9735 - val_f1: 0.8136\n", 1515 | "Epoch 72/300\n", 1516 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0028 - f1: 1.0000 - val_loss: 0.9194 - val_f1: 0.8333\n", 1517 | "Epoch 73/300\n", 1518 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0016 - f1: 1.0000 - val_loss: 0.9224 - val_f1: 0.8321\n", 1519 | "Epoch 74/300\n", 1520 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 0.9445 - val_f1: 0.8249\n", 1521 | "Epoch 75/300\n", 1522 | "1839/1839 [==============================] - 6s 3ms/step - loss: 9.3293e-04 - f1: 1.0000 - val_loss: 0.9333 - val_f1: 0.8304\n", 1523 | "Epoch 76/300\n", 1524 | "1839/1839 [==============================] - 6s 3ms/step - loss: 9.6054e-04 - f1: 1.0000 - val_loss: 0.9360 - val_f1: 0.8264\n", 1525 | "Epoch 77/300\n", 1526 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0034 - f1: 0.9992 - val_loss: 0.9103 - val_f1: 0.8299\n", 1527 | "Epoch 78/300\n", 1528 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0018 - f1: 1.0000 - val_loss: 0.9219 - val_f1: 0.8332\n", 1529 | "Epoch 79/300\n", 1530 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 0.9166 - val_f1: 0.8372\n", 1531 | "Epoch 80/300\n", 1532 | "1839/1839 [==============================] - 5s 3ms/step - loss: 
8.4605e-04 - f1: 1.0000 - val_loss: 0.9176 - val_f1: 0.8362\n", 1533 | "Epoch 81/300\n", 1534 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0015 - f1: 0.9995 - val_loss: 0.9786 - val_f1: 0.8179\n", 1535 | "Epoch 82/300\n", 1536 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 0.9525 - val_f1: 0.8336\n", 1537 | "Epoch 83/300\n", 1538 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 0.9509 - val_f1: 0.8278\n", 1539 | "Epoch 84/300\n", 1540 | "1839/1839 [==============================] - 5s 3ms/step - loss: 6.4968e-04 - f1: 1.0000 - val_loss: 0.9370 - val_f1: 0.8280\n", 1541 | "Epoch 85/300\n", 1542 | "1839/1839 [==============================] - 5s 3ms/step - loss: 8.7447e-04 - f1: 1.0000 - val_loss: 0.9853 - val_f1: 0.8197\n", 1543 | "Epoch 86/300\n", 1544 | "1839/1839 [==============================] - 5s 3ms/step - loss: 9.7422e-04 - f1: 1.0000 - val_loss: 0.9614 - val_f1: 0.8287\n", 1545 | "Epoch 87/300\n", 1546 | "1839/1839 [==============================] - 5s 3ms/step - loss: 6.7887e-04 - f1: 1.0000 - val_loss: 0.9664 - val_f1: 0.8255\n", 1547 | "Epoch 88/300\n", 1548 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.4960e-04 - f1: 1.0000 - val_loss: 0.9559 - val_f1: 0.8309\n", 1549 | "Epoch 89/300\n", 1550 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.1952e-04 - f1: 1.0000 - val_loss: 0.9667 - val_f1: 0.8296\n", 1551 | "Epoch 90/300\n", 1552 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.7391e-04 - f1: 1.0000 - val_loss: 0.9579 - val_f1: 0.8348\n", 1553 | "Epoch 91/300\n", 1554 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.7227e-04 - f1: 1.0000 - val_loss: 0.9605 - val_f1: 0.8326\n", 1555 | "Epoch 92/300\n", 1556 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.7311e-04 - f1: 1.0000 - val_loss: 0.9678 - val_f1: 0.8298\n", 1557 | "Epoch 93/300\n", 1558 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.2191e-04 - f1: 1.0000 - val_loss: 0.9664 - val_f1: 0.8352\n", 1559 | "Epoch 94/300\n", 1560 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.5526e-04 - f1: 1.0000 - val_loss: 0.9601 - val_f1: 0.8335\n", 1561 | "Epoch 95/300\n", 1562 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.0580e-04 - f1: 1.0000 - val_loss: 0.9861 - val_f1: 0.8351\n", 1563 | "Epoch 96/300\n", 1564 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 1.0249 - val_f1: 0.8131\n", 1565 | "Epoch 97/300\n", 1566 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.6099e-04 - f1: 1.0000 - val_loss: 1.0059 - val_f1: 0.8183\n", 1567 | "Epoch 98/300\n", 1568 | "1839/1839 [==============================] - 6s 3ms/step - loss: 8.8925e-04 - f1: 1.0000 - val_loss: 1.0204 - val_f1: 0.8181\n", 1569 | "Epoch 99/300\n", 1570 | "1839/1839 [==============================] - 6s 3ms/step - loss: 7.3515e-04 - f1: 1.0000 - val_loss: 1.0013 - val_f1: 0.8184\n", 1571 | "Epoch 100/300\n", 1572 | "1839/1839 [==============================] - 6s 3ms/step - loss: 8.3419e-04 - f1: 1.0000 - val_loss: 1.0341 - val_f1: 0.8164\n", 1573 | "Epoch 101/300\n", 1574 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0035 - f1: 0.9989 - val_loss: 0.9958 - val_f1: 0.8191\n", 1575 | "Epoch 102/300\n", 1576 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0053 
- f1: 0.9989 - val_loss: 1.0122 - val_f1: 0.8265\n", 1577 | "Epoch 103/300\n", 1578 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0013 - f1: 1.0000 - val_loss: 1.0103 - val_f1: 0.8209\n", 1579 | "Epoch 104/300\n", 1580 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0053 - f1: 0.9989 - val_loss: 1.1036 - val_f1: 0.8036\n", 1581 | "Epoch 105/300\n", 1582 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0020 - f1: 1.0000 - val_loss: 1.0385 - val_f1: 0.8145\n", 1583 | "Epoch 106/300\n", 1584 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 1.0522 - val_f1: 0.8209\n", 1585 | "Epoch 107/300\n", 1586 | "1839/1839 [==============================] - 6s 3ms/step - loss: 0.0010 - f1: 1.0000 - val_loss: 1.1168 - val_f1: 0.8123\n", 1587 | "Epoch 108/300\n", 1588 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0040 - f1: 0.9989 - val_loss: 1.1633 - val_f1: 0.8004\n", 1589 | "Epoch 109/300\n", 1590 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0034 - f1: 0.9995 - val_loss: 1.0392 - val_f1: 0.8165\n", 1591 | "Epoch 110/300\n", 1592 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0053 - f1: 0.9995 - val_loss: 0.9886 - val_f1: 0.8305\n", 1593 | "Epoch 111/300\n", 1594 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 1.0026 - val_f1: 0.8277\n", 1595 | "Epoch 112/300\n", 1596 | "1839/1839 [==============================] - 5s 3ms/step - loss: 9.9556e-04 - f1: 1.0000 - val_loss: 0.9939 - val_f1: 0.8292\n", 1597 | "Epoch 113/300\n", 1598 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0021 - f1: 0.9995 - val_loss: 1.0387 - val_f1: 0.8113\n", 1599 | "Epoch 114/300\n", 1600 | "1839/1839 [==============================] - 5s 3ms/step - loss: 8.0578e-04 - f1: 1.0000 - val_loss: 1.0518 - val_f1: 0.8205\n", 1601 | "Epoch 115/300\n", 1602 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.6815e-04 - f1: 1.0000 - val_loss: 1.0576 - val_f1: 0.8170\n", 1603 | "Epoch 116/300\n", 1604 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.7931e-04 - f1: 1.0000 - val_loss: 1.0544 - val_f1: 0.8152\n", 1605 | "Epoch 117/300\n", 1606 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.7737e-04 - f1: 1.0000 - val_loss: 1.0665 - val_f1: 0.8173\n", 1607 | "Epoch 118/300\n", 1608 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.3346e-04 - f1: 1.0000 - val_loss: 1.0667 - val_f1: 0.8168\n", 1609 | "Epoch 119/300\n", 1610 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.2447e-04 - f1: 1.0000 - val_loss: 1.0798 - val_f1: 0.8110\n", 1611 | "Epoch 120/300\n", 1612 | "1839/1839 [==============================] - 5s 3ms/step - loss: 6.6263e-04 - f1: 1.0000 - val_loss: 1.0843 - val_f1: 0.8114\n", 1613 | "Epoch 121/300\n", 1614 | "1839/1839 [==============================] - 5s 3ms/step - loss: 7.0956e-04 - f1: 1.0000 - val_loss: 1.0713 - val_f1: 0.8091\n", 1615 | "Epoch 122/300\n", 1616 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.3038e-04 - f1: 1.0000 - val_loss: 1.0521 - val_f1: 0.8218\n", 1617 | "Epoch 123/300\n", 1618 | "1839/1839 [==============================] - 5s 3ms/step - loss: 7.8382e-04 - f1: 1.0000 - val_loss: 1.0743 - val_f1: 0.8143\n", 1619 | "Epoch 124/300\n", 1620 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.2436e-04 - f1: 
1.0000 - val_loss: 1.0851 - val_f1: 0.8174\n", 1621 | "Epoch 125/300\n", 1622 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.3625e-04 - f1: 1.0000 - val_loss: 1.1127 - val_f1: 0.8153\n", 1623 | "Epoch 126/300\n", 1624 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.4003e-04 - f1: 1.0000 - val_loss: 1.0980 - val_f1: 0.8134\n", 1625 | "Epoch 127/300\n", 1626 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.1112e-04 - f1: 1.0000 - val_loss: 1.0975 - val_f1: 0.8158\n", 1627 | "Epoch 128/300\n", 1628 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.7792e-04 - f1: 1.0000 - val_loss: 1.0916 - val_f1: 0.8216\n", 1629 | "Epoch 129/300\n", 1630 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.5881e-04 - f1: 1.0000 - val_loss: 1.0802 - val_f1: 0.8241\n", 1631 | "Epoch 130/300\n", 1632 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.0872e-04 - f1: 1.0000 - val_loss: 1.0879 - val_f1: 0.8220\n", 1633 | "Epoch 131/300\n", 1634 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.8239e-04 - f1: 1.0000 - val_loss: 1.0904 - val_f1: 0.8256\n", 1635 | "Epoch 132/300\n", 1636 | "1839/1839 [==============================] - 6s 3ms/step - loss: 2.1858e-04 - f1: 1.0000 - val_loss: 1.0941 - val_f1: 0.8200\n", 1637 | "Epoch 133/300\n", 1638 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.3054e-04 - f1: 1.0000 - val_loss: 1.0942 - val_f1: 0.8181\n", 1639 | "Epoch 134/300\n", 1640 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.2964e-04 - f1: 1.0000 - val_loss: 1.0947 - val_f1: 0.8209\n", 1641 | "Epoch 135/300\n", 1642 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.8267e-04 - f1: 1.0000 - val_loss: 1.0943 - val_f1: 0.8270\n", 1643 | "Epoch 136/300\n", 1644 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.2101e-04 - f1: 1.0000 - val_loss: 1.0975 - val_f1: 0.8256\n", 1645 | "Epoch 137/300\n", 1646 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.3244e-04 - f1: 1.0000 - val_loss: 1.0967 - val_f1: 0.8294\n", 1647 | "Epoch 138/300\n", 1648 | "1839/1839 [==============================] - 6s 3ms/step - loss: 9.5624e-05 - f1: 1.0000 - val_loss: 1.1052 - val_f1: 0.8288\n", 1649 | "Epoch 139/300\n", 1650 | "1839/1839 [==============================] - 6s 3ms/step - loss: 3.0982e-04 - f1: 1.0000 - val_loss: 1.1630 - val_f1: 0.8185\n", 1651 | "Epoch 140/300\n", 1652 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.4933e-04 - f1: 1.0000 - val_loss: 1.1253 - val_f1: 0.8189\n", 1653 | "Epoch 141/300\n", 1654 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.3635e-04 - f1: 1.0000 - val_loss: 1.1227 - val_f1: 0.8225\n", 1655 | "Epoch 142/300\n", 1656 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.1006e-04 - f1: 1.0000 - val_loss: 1.1184 - val_f1: 0.8251\n", 1657 | "Epoch 143/300\n", 1658 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.2128e-04 - f1: 1.0000 - val_loss: 1.1187 - val_f1: 0.8309\n", 1659 | "Epoch 144/300\n", 1660 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.1148e-04 - f1: 1.0000 - val_loss: 1.1111 - val_f1: 0.8275\n", 1661 | "Epoch 145/300\n", 1662 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.2230e-04 - f1: 1.0000 - val_loss: 1.1228 - val_f1: 0.8313\n", 1663 | "Epoch 146/300\n", 1664 | "1839/1839 [==============================] - 5s 
3ms/step - loss: 1.0086e-04 - f1: 1.0000 - val_loss: 1.1273 - val_f1: 0.8293\n", 1665 | "Epoch 147/300\n", 1666 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.5006e-04 - f1: 1.0000 - val_loss: 1.1418 - val_f1: 0.8300\n", 1667 | "Epoch 148/300\n", 1668 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0024 - f1: 0.9995 - val_loss: 1.1586 - val_f1: 0.8162\n", 1669 | "Epoch 149/300\n", 1670 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.8215e-04 - f1: 1.0000 - val_loss: 1.0818 - val_f1: 0.8347\n", 1671 | "Epoch 150/300\n", 1672 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.4316e-04 - f1: 1.0000 - val_loss: 1.0817 - val_f1: 0.8312\n", 1673 | "Epoch 151/300\n", 1674 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.3920e-04 - f1: 1.0000 - val_loss: 1.0768 - val_f1: 0.8339\n", 1675 | "Epoch 152/300\n", 1676 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.1098e-04 - f1: 1.0000 - val_loss: 1.0783 - val_f1: 0.8360\n", 1677 | "Epoch 153/300\n", 1678 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.2897e-04 - f1: 1.0000 - val_loss: 1.0821 - val_f1: 0.8310\n", 1679 | "Epoch 154/300\n", 1680 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.6505e-04 - f1: 1.0000 - val_loss: 1.0777 - val_f1: 0.8308\n", 1681 | "Epoch 155/300\n", 1682 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0015 - f1: 0.9989 - val_loss: 1.1645 - val_f1: 0.8087\n", 1683 | "Epoch 156/300\n", 1684 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0011 - f1: 1.0000 - val_loss: 1.0931 - val_f1: 0.8377\n", 1685 | "Epoch 157/300\n", 1686 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.3027e-04 - f1: 1.0000 - val_loss: 1.1130 - val_f1: 0.8289\n", 1687 | "Epoch 158/300\n", 1688 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.2346e-04 - f1: 1.0000 - val_loss: 1.1253 - val_f1: 0.8290\n", 1689 | "Epoch 159/300\n", 1690 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0015 - f1: 0.9995 - val_loss: 1.0921 - val_f1: 0.8284\n", 1691 | "Epoch 160/300\n", 1692 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0042 - f1: 0.9989 - val_loss: 1.1114 - val_f1: 0.8177\n", 1693 | "Epoch 161/300\n", 1694 | "1839/1839 [==============================] - 5s 3ms/step - loss: 0.0021 - f1: 0.9995 - val_loss: 1.0611 - val_f1: 0.8253\n", 1695 | "Epoch 162/300\n", 1696 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.5093e-04 - f1: 1.0000 - val_loss: 1.0799 - val_f1: 0.8289\n", 1697 | "Epoch 163/300\n", 1698 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.7466e-04 - f1: 1.0000 - val_loss: 1.0400 - val_f1: 0.8241\n", 1699 | "Epoch 164/300\n", 1700 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.2356e-04 - f1: 1.0000 - val_loss: 1.0486 - val_f1: 0.8242\n", 1701 | "Epoch 165/300\n", 1702 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.2108e-04 - f1: 1.0000 - val_loss: 1.0388 - val_f1: 0.8244\n", 1703 | "Epoch 166/300\n", 1704 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.6586e-04 - f1: 1.0000 - val_loss: 1.0512 - val_f1: 0.8222\n", 1705 | "Epoch 167/300\n", 1706 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.1654e-04 - f1: 1.0000 - val_loss: 1.0569 - val_f1: 0.8237\n", 1707 | "Epoch 168/300\n", 1708 | "1839/1839 
[==============================] - 5s 3ms/step - loss: 2.1511e-04 - f1: 1.0000 - val_loss: 1.0831 - val_f1: 0.8220\n", 1709 | "Epoch 169/300\n", 1710 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.7949e-04 - f1: 1.0000 - val_loss: 1.0916 - val_f1: 0.8185\n", 1711 | "Epoch 170/300\n", 1712 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.3409e-04 - f1: 1.0000 - val_loss: 1.0991 - val_f1: 0.8198\n", 1713 | "Epoch 171/300\n", 1714 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.0768e-04 - f1: 1.0000 - val_loss: 1.1065 - val_f1: 0.8176\n", 1715 | "Epoch 172/300\n", 1716 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.2211e-04 - f1: 1.0000 - val_loss: 1.1115 - val_f1: 0.8155\n", 1717 | "Epoch 173/300\n", 1718 | "1839/1839 [==============================] - 5s 3ms/step - loss: 2.1365e-04 - f1: 1.0000 - val_loss: 1.1072 - val_f1: 0.8128\n", 1719 | "Epoch 174/300\n", 1720 | "1839/1839 [==============================] - 5s 3ms/step - loss: 8.8997e-05 - f1: 1.0000 - val_loss: 1.1027 - val_f1: 0.8172\n", 1721 | "Epoch 175/300\n", 1722 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.0339e-04 - f1: 1.0000 - val_loss: 1.0983 - val_f1: 0.8200\n", 1723 | "Epoch 176/300\n", 1724 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.2617e-04 - f1: 1.0000 - val_loss: 1.1052 - val_f1: 0.8191\n", 1725 | "Epoch 177/300\n", 1726 | "1839/1839 [==============================] - 5s 3ms/step - loss: 6.2080e-05 - f1: 1.0000 - val_loss: 1.1063 - val_f1: 0.8178\n", 1727 | "Epoch 178/300\n", 1728 | "1839/1839 [==============================] - 5s 3ms/step - loss: 8.9384e-05 - f1: 1.0000 - val_loss: 1.1094 - val_f1: 0.8225\n", 1729 | "Epoch 179/300\n", 1730 | "1839/1839 [==============================] - 5s 3ms/step - loss: 7.3541e-05 - f1: 1.0000 - val_loss: 1.1141 - val_f1: 0.8212\n", 1731 | "Epoch 180/300\n", 1732 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.3658e-04 - f1: 1.0000 - val_loss: 1.1153 - val_f1: 0.8277\n", 1733 | "Epoch 181/300\n", 1734 | "1839/1839 [==============================] - 5s 3ms/step - loss: 6.2540e-05 - f1: 1.0000 - val_loss: 1.1166 - val_f1: 0.8268\n", 1735 | "Epoch 182/300\n", 1736 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.7952e-04 - f1: 1.0000 - val_loss: 1.1366 - val_f1: 0.8256\n", 1737 | "Epoch 183/300\n", 1738 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.6589e-05 - f1: 1.0000 - val_loss: 1.1376 - val_f1: 0.8260\n", 1739 | "Epoch 184/300\n", 1740 | "1839/1839 [==============================] - 5s 3ms/step - loss: 7.5782e-05 - f1: 1.0000 - val_loss: 1.1378 - val_f1: 0.8295\n", 1741 | "Epoch 185/300\n", 1742 | "1839/1839 [==============================] - 5s 3ms/step - loss: 8.7658e-05 - f1: 1.0000 - val_loss: 1.1392 - val_f1: 0.8247\n", 1743 | "Epoch 186/300\n", 1744 | "1839/1839 [==============================] - 5s 3ms/step - loss: 6.0105e-05 - f1: 1.0000 - val_loss: 1.1360 - val_f1: 0.8307\n", 1745 | "Epoch 187/300\n", 1746 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.7825e-05 - f1: 1.0000 - val_loss: 1.1352 - val_f1: 0.8317\n", 1747 | "Epoch 188/300\n", 1748 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.0959e-05 - f1: 1.0000 - val_loss: 1.1367 - val_f1: 0.8331\n", 1749 | "Epoch 189/300\n", 1750 | "1839/1839 [==============================] - 5s 3ms/step - loss: 7.8447e-05 - f1: 1.0000 - val_loss: 1.1430 - val_f1: 0.8286\n", 1751 | 
"Epoch 190/300\n", 1752 | "1839/1839 [==============================] - 5s 3ms/step - loss: 5.0137e-05 - f1: 1.0000 - val_loss: 1.1394 - val_f1: 0.8309\n", 1753 | "Epoch 191/300\n", 1754 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.1537e-05 - f1: 1.0000 - val_loss: 1.1423 - val_f1: 0.8339\n", 1755 | "Epoch 192/300\n", 1756 | "1839/1839 [==============================] - 5s 3ms/step - loss: 8.0265e-05 - f1: 1.0000 - val_loss: 1.1264 - val_f1: 0.8333\n", 1757 | "Epoch 193/300\n", 1758 | "1839/1839 [==============================] - 5s 3ms/step - loss: 4.7630e-05 - f1: 1.0000 - val_loss: 1.1351 - val_f1: 0.8315\n", 1759 | "Epoch 194/300\n", 1760 | "1839/1839 [==============================] - 5s 3ms/step - loss: 9.6157e-05 - f1: 1.0000 - val_loss: 1.1569 - val_f1: 0.8323\n", 1761 | "Epoch 195/300\n", 1762 | "1839/1839 [==============================] - 6s 3ms/step - loss: 4.0880e-05 - f1: 1.0000 - val_loss: 1.1556 - val_f1: 0.8376\n", 1763 | "Epoch 196/300\n", 1764 | "1839/1839 [==============================] - 5s 3ms/step - loss: 1.0660e-04 - f1: 1.0000 - val_loss: 1.1486 - val_f1: 0.8388\n", 1765 | "Epoch 197/300\n", 1766 | "1839/1839 [==============================] - 6s 3ms/step - loss: 4.7847e-05 - f1: 1.0000 - val_loss: 1.1383 - val_f1: 0.8360\n", 1767 | "Epoch 198/300\n", 1768 | "1839/1839 [==============================] - 6s 3ms/step - loss: 4.5485e-05 - f1: 1.0000 - val_loss: 1.1384 - val_f1: 0.8355\n", 1769 | "Epoch 199/300\n", 1770 | "1839/1839 [==============================] - 7s 4ms/step - loss: 6.6492e-05 - f1: 1.0000 - val_loss: 1.1453 - val_f1: 0.8361\n", 1771 | "Epoch 200/300\n", 1772 | "1839/1839 [==============================] - 7s 4ms/step - loss: 3.0428e-05 - f1: 1.0000 - val_loss: 1.1482 - val_f1: 0.8382\n", 1773 | "Epoch 201/300\n", 1774 | "1839/1839 [==============================] - 7s 4ms/step - loss: 3.2437e-05 - f1: 1.0000 - val_loss: 1.1498 - val_f1: 0.8383\n", 1775 | "Epoch 202/300\n", 1776 | "1839/1839 [==============================] - 7s 4ms/step - loss: 3.4816e-05 - f1: 1.0000 - val_loss: 1.1570 - val_f1: 0.8361\n", 1777 | "Epoch 203/300\n", 1778 | "1839/1839 [==============================] - 6s 4ms/step - loss: 3.9865e-05 - f1: 1.0000 - val_loss: 1.1833 - val_f1: 0.8354\n", 1779 | "Epoch 204/300\n", 1780 | "1839/1839 [==============================] - 6s 3ms/step - loss: 6.4216e-05 - f1: 1.0000 - val_loss: 1.1805 - val_f1: 0.8315\n", 1781 | "Epoch 205/300\n", 1782 | "1839/1839 [==============================] - 6s 3ms/step - loss: 4.5160e-05 - f1: 1.0000 - val_loss: 1.1617 - val_f1: 0.8412\n", 1783 | "Epoch 206/300\n", 1784 | "1839/1839 [==============================] - 6s 3ms/step - loss: 4.1668e-05 - f1: 1.0000 - val_loss: 1.1689 - val_f1: 0.8394\n", 1785 | "Epoch 207/300\n", 1786 | "1839/1839 [==============================] - 6s 3ms/step - loss: 2.4791e-05 - f1: 1.0000 - val_loss: 1.1736 - val_f1: 0.8394\n", 1787 | "Epoch 208/300\n", 1788 | "1839/1839 [==============================] - 6s 3ms/step - loss: 2.3336e-05 - f1: 1.0000 - val_loss: 1.1760 - val_f1: 0.8379\n", 1789 | "Epoch 209/300\n", 1790 | "1839/1839 [==============================] - 6s 3ms/step - loss: 2.2662e-05 - f1: 1.0000 - val_loss: 1.1767 - val_f1: 0.8362\n", 1791 | "Epoch 210/300\n", 1792 | "1839/1839 [==============================] - 6s 3ms/step - loss: 2.6002e-05 - f1: 1.0000 - val_loss: 1.1850 - val_f1: 0.8359\n", 1793 | "Epoch 211/300\n", 1794 | "1839/1839 [==============================] - 6s 3ms/step - loss: 1.9812e-05 - f1: 1.0000 - val_loss: 
1.1848 - val_f1: 0.8351\n", 1795 | "Epoch 212/300\n", 1796 | "1839/1839 [==============================] - 6s 3ms/step - loss: 3.1796e-05 - f1: 1.0000 - val_loss: 1.1862 - val_f1: 0.8332\n", 1797 | "Epoch 213/300\n", 1798 | "1839/1839 [==============================] - 6s 3ms/step - loss: 2.5655e-05 - f1: 1.0000 - val_loss: 1.1919 - val_f1: 0.8299\n", 1799 | "Epoch 214/300\n", 1800 | "1839/1839 [==============================] - 5s 3ms/step - loss: 3.0894e-05 - f1: 1.0000 - val_loss: 1.1916 - val_f1: 0.8326\n", 1801 | "Epoch 215/300\n", 1802 | "1839/1839 [==============================] - 6s 3ms/step - loss: 2.0926e-05 - f1: 1.0000 - val_loss: 1.1944 - val_f1: 0.8318\n", 1803 | "Epoch 216/300\n", 1804 | " 60/1839 [..............................] - ETA: 5s - loss: 9.7486e-06 - f1: 1.0000" 1805 | ] 1806 | } 1807 | ], 1808 | "source": [ 1809 | "print('Train...')\n", 1810 | "model.fit(x_train, y_train,\n", 1811 | " batch_size=batch_size,\n", 1812 | " epochs=epochs,\n", 1813 | " callbacks=[TensorBoard(log_dir='../logs/{}'.format(\"SMP2018_lstm_{}\".format(get_customization_time())))],\n", 1814 | " validation_split=0.2\n", 1815 | " )" 1816 | ] 1817 | }, 1818 | { 1819 | "cell_type": "markdown", 1820 | "metadata": {}, 1821 | "source": [ 1822 | "# 评估模型" 1823 | ] 1824 | }, 1825 | { 1826 | "cell_type": "code", 1827 | "execution_count": 53, 1828 | "metadata": {}, 1829 | "outputs": [ 1830 | { 1831 | "name": "stdout", 1832 | "output_type": "stream", 1833 | "text": [ 1834 | "770/770 [==============================] - 0s 240us/step\n", 1835 | "Test score: 0.7415416103291821\n", 1836 | "Test f1: 0.8223602949798882\n" 1837 | ] 1838 | } 1839 | ], 1840 | "source": [ 1841 | "score = model.evaluate(x_test, y_test,\n", 1842 | " batch_size=batch_size, verbose=1)\n", 1843 | "\n", 1844 | "print('Test score:', score[0])\n", 1845 | "print('Test f1:', score[1])" 1846 | ] 1847 | }, 1848 | { 1849 | "cell_type": "code", 1850 | "execution_count": 54, 1851 | "metadata": {}, 1852 | "outputs": [], 1853 | "source": [ 1854 | "y_hat_test = model.predict(x_test)" 1855 | ] 1856 | }, 1857 | { 1858 | "cell_type": "code", 1859 | "execution_count": 55, 1860 | "metadata": {}, 1861 | "outputs": [ 1862 | { 1863 | "name": "stdout", 1864 | "output_type": "stream", 1865 | "text": [ 1866 | "(770, 31)\n" 1867 | ] 1868 | } 1869 | ], 1870 | "source": [ 1871 | "print(y_hat_test.shape)" 1872 | ] 1873 | }, 1874 | { 1875 | "cell_type": "markdown", 1876 | "metadata": {}, 1877 | "source": [ 1878 | "## 将 one-hot 张量转换成对应的整数" 1879 | ] 1880 | }, 1881 | { 1882 | "cell_type": "code", 1883 | "execution_count": 56, 1884 | "metadata": {}, 1885 | "outputs": [], 1886 | "source": [ 1887 | "y_pred = np.argmax(y_hat_test, axis=1).tolist()" 1888 | ] 1889 | }, 1890 | { 1891 | "cell_type": "code", 1892 | "execution_count": 57, 1893 | "metadata": {}, 1894 | "outputs": [], 1895 | "source": [ 1896 | "y_true = np.argmax(y_test, axis=1).tolist()" 1897 | ] 1898 | }, 1899 | { 1900 | "cell_type": "markdown", 1901 | "metadata": {}, 1902 | "source": [ 1903 | "## 查看多分类的 准确率、召回率、F1 值" 1904 | ] 1905 | }, 1906 | { 1907 | "cell_type": "code", 1908 | "execution_count": 58, 1909 | "metadata": {}, 1910 | "outputs": [ 1911 | { 1912 | "name": "stdout", 1913 | "output_type": "stream", 1914 | "text": [ 1915 | " precision recall f1-score support\n", 1916 | "\n", 1917 | " 0 1.00 0.90 0.95 21\n", 1918 | " 1 0.86 0.75 0.80 8\n", 1919 | " 2 1.00 0.95 0.98 21\n", 1920 | " 3 0.52 0.57 0.54 23\n", 1921 | " 4 0.91 0.91 0.91 11\n", 1922 | " 5 0.82 0.97 0.89 34\n", 1923 | " 6 0.25 0.17 0.20 6\n", 
1924 | " 7 0.86 0.86 0.86 22\n", 1925 | " 8 1.00 0.88 0.93 8\n", 1926 | " 9 0.89 1.00 0.94 8\n", 1927 | " 10 0.95 0.95 0.95 21\n", 1928 | " 11 1.00 0.62 0.77 8\n", 1929 | " 12 0.62 0.70 0.66 60\n", 1930 | " 13 0.86 0.90 0.88 20\n", 1931 | " 14 0.55 0.58 0.56 19\n", 1932 | " 15 0.76 0.78 0.77 36\n", 1933 | " 16 0.87 0.90 0.89 154\n", 1934 | " 17 0.57 0.50 0.53 8\n", 1935 | " 18 0.86 0.75 0.80 8\n", 1936 | " 19 0.74 0.95 0.83 21\n", 1937 | " 20 0.87 0.83 0.85 24\n", 1938 | " 21 1.00 0.75 0.86 8\n", 1939 | " 22 0.67 0.67 0.67 9\n", 1940 | " 23 1.00 1.00 1.00 8\n", 1941 | " 24 0.62 0.56 0.59 18\n", 1942 | " 25 0.92 1.00 0.96 24\n", 1943 | " 26 0.75 0.30 0.43 10\n", 1944 | " 27 0.71 0.55 0.62 22\n", 1945 | " 28 0.71 0.65 0.68 23\n", 1946 | " 29 0.73 0.61 0.67 18\n", 1947 | " 30 0.94 0.94 0.94 89\n", 1948 | "\n", 1949 | " micro avg 0.82 0.82 0.82 770\n", 1950 | " macro avg 0.80 0.76 0.77 770\n", 1951 | "weighted avg 0.82 0.82 0.81 770\n", 1952 | "\n" 1953 | ] 1954 | } 1955 | ], 1956 | "source": [ 1957 | "print(classification_report(y_true, y_pred))" 1958 | ] 1959 | } 1960 | ], 1961 | "metadata": { 1962 | "kernelspec": { 1963 | "display_name": "Python 3", 1964 | "language": "python", 1965 | "name": "python3" 1966 | }, 1967 | "language_info": { 1968 | "codemirror_mode": { 1969 | "name": "ipython", 1970 | "version": 3 1971 | }, 1972 | "file_extension": ".py", 1973 | "mimetype": "text/x-python", 1974 | "name": "python", 1975 | "nbconvert_exporter": "python", 1976 | "pygments_lexer": "ipython3", 1977 | "version": "3.6.5" 1978 | } 1979 | }, 1980 | "nbformat": 4, 1981 | "nbformat_minor": 2 1982 | } 1983 | --------------------------------------------------------------------------------
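
Two closing observations on the run above. First, training F1 reaches 1.0 within a few dozen epochs while `val_f1` plateaus around 0.83 and `val_loss` keeps climbing, so the model overfits long before epoch 300; early stopping plus checkpointing would keep the best weights instead (a sketch; the checkpoint file name is illustrative):

```python
from keras.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    # Stop training once val_loss has not improved for 10 epochs.
    EarlyStopping(monitor='val_loss', patience=10),
    # Keep the weights of the best epoch rather than the last one.
    ModelCheckpoint('SMP2018_lstm_best.h5', monitor='val_loss',
                    save_best_only=True),
]
# model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
#           validation_split=0.2, callbacks=callbacks)
```

Second, the `classification_report` above prints bare class indices; passing `target_names` maps each row back to its label via `index_2_label_dict`:

```python
from sklearn.metrics import classification_report

# Readable per-class report, e.g. "chat", "weather", ... instead of 0..30.
target_names = [index_2_label_dict[i] for i in range(label_numbers)]
print(classification_report(y_true, y_pred, target_names=target_names))
```
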