├── README.md ├── RECENT_chatobot.ipynb ├── chatbot_keras.py ├── conversation_spliter.py ├── decoder_inputs.txt ├── encoder_inputs.txt ├── padded_decoder_sequences.txt ├── padded_encoder_sequences.txt ├── prepare_data.py ├── seq2seq_model.py └── training_chatbot.py /README.md: -------------------------------------------------------------------------------- 1 | # Automatic Encoder-Decoder Seq2Seq: English Chatbot 2 | ## English Chatbot Using Seq2Seq with an Encoder-Decoder LSTM 3 | ![](https://cdn-images-1.medium.com/max/2560/1*1I2tTjCkMHlQ-r73eRn4ZQ.png) 4 | 5 | ## Introduction 6 | 7 | Seq2seq (sequence-to-sequence) is a model whose input and output are both sequences: it converts one sequence into another. The idea is simple: prepare two RNNs, one for the input side (the encoder) and one for the output side (the decoder), and connect them through an intermediate state. 8 | The encoder reads the input sequence step by step and summarizes it in its final hidden state. That state is passed to the decoder, which generates the output sequence one step at a time. In this project, both the encoder and the decoder are LSTMs.
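
The encoder-to-decoder hand-off described above can be sketched in plain Python. This is a toy illustration under stated assumptions, not the project's model: it uses a vanilla tanh RNN cell instead of an LSTM, random untrained weights, and hypothetical sizes, token ids, and helper names (`encode`, `decode`, `rnn_step`); the project itself builds its model with Keras.

```python
import math
import random


def matvec(matrix, vector):
    """Multiply a matrix (a list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def one_hot(index, size):
    """Return a one-hot vector of length `size` with a 1.0 at `index`."""
    vec = [0.0] * size
    vec[index] = 1.0
    return vec


def rnn_step(w_in, w_rec, token, state):
    """One vanilla RNN step: h_t = tanh(W_in x_t + W_rec h_{t-1})."""
    x = one_hot(token, len(w_in[0]))
    return [math.tanh(a + b) for a, b in zip(matvec(w_in, x), matvec(w_rec, state))]


def encode(tokens, w_in, w_rec):
    """Read the whole input sequence; the final hidden state summarizes it."""
    state = [0.0] * len(w_rec)
    for token in tokens:
        state = rnn_step(w_in, w_rec, token, state)
    return state


def decode(context, w_in, w_rec, w_out, bos, eos, max_len):
    """Greedy decoding: start from the encoder state and feed each prediction back in."""
    state, token, output = context, bos, []
    for _ in range(max_len):
        state = rnn_step(w_in, w_rec, token, state)
        scores = matvec(w_out, state)
        token = max(range(len(scores)), key=scores.__getitem__)  # most likely next token
        if token == eos:
            break
        output.append(token)
    return output


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


# Toy sizes and special-token ids (hypothetical, not taken from this repository).
VOCAB, HIDDEN = 10, 8
BOS, EOS = 1, 2
rng = random.Random(0)

# The encoder and the decoder each have their own (untrained, random) weights.
enc_in, enc_rec = random_matrix(HIDDEN, VOCAB, rng), random_matrix(HIDDEN, HIDDEN, rng)
dec_in, dec_rec = random_matrix(HIDDEN, VOCAB, rng), random_matrix(HIDDEN, HIDDEN, rng)
dec_out = random_matrix(VOCAB, HIDDEN, rng)

context = encode([5, 7, 3], enc_in, enc_rec)  # input sentence as token ids
reply = decode(context, dec_in, dec_rec, dec_out, BOS, EOS, max_len=12)
print(len(context), reply)
```

The reply is meaningless because nothing is trained; the point is the wiring. The encoder's final state becomes the decoder's initial state, and at inference time the decoder feeds each predicted token back in until it emits the end token or reaches `max_len`. During training, the notebook instead feeds the decoder the ground-truth previous words (teacher forcing), and it performs the same hand-off by passing `initial_state=[state_h, state_c]` to the decoder LSTM.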
9 | 10 | ## Technical Specifications 11 | 12 | | Item | Detail | 13 | |:-----------:|:------------------------------------------------| 14 | | Environment | macOS Mojave 10.14.3 | 15 | | Language | Python | 16 | | Library | Keras, scikit-learn, NumPy, matplotlib, pandas, Seaborn | 17 | | Dataset | [Tab-delimited Bilingual Sentence Pairs](http://www.manythings.org/anki/) | 18 | | Algorithm | Encoder-Decoder LSTM | 19 | 20 | ## References 21 | 22 | - [Machine Translation using Sequence-to-Sequence Learning](https://nextjournal.com/gkoehler/machine-translation-seq2seq-cpu) 23 | - [Chatbots with Seq2Seq: Learn to build a chatbot using TensorFlow](http://complx.me/2016-06-28-easy-seq2seq/) 24 | - [Generative Model Chatbots](https://medium.com/botsupply/generative-model-chatbots-e422ab08461e) 25 | - [How I Used Deep Learning To Train A Chatbot To Talk Like Me (Sorta)](https://adeshpande3.github.io/How-I-Used-Deep-Learning-to-Train-a-Chatbot-to-Talk-Like-Me) 26 | - [The Basics of LSTM You Were Too Embarrassed to Ask About (in Japanese)](https://www.hellocybernetics.tech/entry/2017/05/06/182757) 27 | -------------------------------------------------------------------------------- /RECENT_chatobot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "A_pjoByYRMTv" 8 | }, 9 | "source": [ 10 | "# Seq2Seq: Encoder-Decoder Chatbot " 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "colab_type": "text", 17 | "id": "4n0pEEarRMTx" 18 | }, 19 | "source": [ 20 | "![](https://cdn-images-1.medium.com/max/2560/1*1I2tTjCkMHlQ-r73eRn4ZQ.png)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 0, 26 | "metadata": { 27 | "colab": {}, 28 | "colab_type": "code", 29 | "id": "XgAIoL02RMTy" 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "import numpy as np\n", 34 | "import pandas as pd\n", 35 | "import string\n", 36 | "import pickle\n", 37 | "import operator\n",
38 | "import matplotlib.pyplot as plt\n", 39 | "%matplotlib inline" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": { 45 | "colab_type": "text", 46 | "id": "dHaiKoheRMT5" 47 | }, 48 | "source": [ 49 | "## Step 1. Import Data" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 0, 55 | "metadata": { 56 | "colab": {}, 57 | "colab_type": "code", 58 | "id": "jh3QLAlCRMT6" 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "# Load the conversation data from the .txt file\n", 63 | "import codecs\n", 64 | "\n", 65 | "with codecs.open(\"movie_lines.txt\", \"rb\", encoding=\"utf-8\", errors=\"ignore\") as f:\n", 66 | " lines = f.read().split(\"\\n\")\n", 67 | " conversations = []\n", 68 | " for line in lines:\n", 69 | " data = line.split(\" +++$+++ \")\n", 70 | " conversations.append(data)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 7, 76 | "metadata": { 77 | "colab": { 78 | "base_uri": "https://localhost:8080/", 79 | "height": 119 80 | }, 81 | "colab_type": "code", 82 | "id": "6M3eZnuPRMT9", 83 | "outputId": "c7420e45-8e12-4feb-bf58-7c208d14e842" 84 | }, 85 | "outputs": [ 86 | { 87 | "data": { 88 | "text/plain": [ 89 | "[['L1045', 'u0', 'm0', 'BIANCA', 'They do not!'],\n", 90 | " ['L1044', 'u2', 'm0', 'CAMERON', 'They do to!'],\n", 91 | " ['L985', 'u0', 'm0', 'BIANCA', 'I hope so.'],\n", 92 | " ['L984', 'u2', 'm0', 'CAMERON', 'She okay?'],\n", 93 | " ['L925', 'u0', 'm0', 'BIANCA', \"Let's go.\"],\n", 94 | " ['L924', 'u2', 'm0', 'CAMERON', 'Wow']]" 95 | ] 96 | }, 97 | "execution_count": 7, 98 | "metadata": { 99 | "tags": [] 100 | }, 101 | "output_type": "execute_result" 102 | } 103 | ], 104 | "source": [ 105 | "conversations[:6]" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 0, 111 | "metadata": { 112 | "colab": {}, 113 | "colab_type": "code", 114 | "id": "4hRs6j-vRMUE" 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "# Keep only the line ID (without the leading \"L\") and the utterance text\n", 119 | "chats = {}\n", 120 | "for tokens in conversations:\n",
121 | " if len(tokens) > 4:\n", 122 | " idx = tokens[0][1:]\n", 123 | " chat = tokens[4]\n", 124 | " chats[int(idx)] = chat" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": { 131 | "colab": { 132 | "base_uri": "https://localhost:8080/", 133 | "height": 20454 134 | }, 135 | "colab_type": "code", 136 | "id": "Q5DrAY8PRMUN", 137 | "outputId": "35500f14-06a6-4eda-f163-397f0320a4c1" 138 | }, 139 | "outputs": [], 140 | "source": [ 141 | "# Sort the (line ID, utterance) pairs by line ID\n", 142 | "sorted_chats = sorted(chats.items(), key = lambda x: x[0])\n", 143 | "sorted_chats" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 0, 149 | "metadata": { 150 | "colab": {}, 151 | "colab_type": "code", 152 | "id": "azB8ddZcRMUS" 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "# Group consecutive line IDs into conversations: { conversation id: [utterances] }\n", 157 | "conves_dict = {}\n", 158 | "counter = 1\n", 159 | "conves_ids = []\n", 160 | "for i in range(1, len(sorted_chats)+1):\n", 161 | " if i < len(sorted_chats):\n", 162 | " if (sorted_chats[i][0] - sorted_chats[i-1][0]) == 1:\n", 163 | " # Add the previous line only if it is not already in the current conversation\n", 164 | " if sorted_chats[i-1][1] not in conves_ids:\n", 165 | " conves_ids.append(sorted_chats[i-1][1])\n", 166 | " conves_ids.append(sorted_chats[i][1])\n", 167 | " elif (sorted_chats[i][0] - sorted_chats[i-1][0]) > 1: \n", 168 | " conves_dict[counter] = conves_ids\n", 169 | " conves_ids = []\n", 170 | " counter += 1\n", 171 | "if conves_ids:\n", 172 | " conves_dict[counter] = conves_ids # store the last conversation, which the loop never flushes" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": { 179 | "colab": { 180 | "base_uri": "https://localhost:8080/", 181 | "height": 117184 182 | }, 183 | "colab_type": "code", 184 | "id": "siMFeqSuRMUV", 185 | "outputId": "8a84cd2e-1da3-4a1a-ac69-39b96489c07c" 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "conves_dict" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 0, 195 | "metadata": { 196 | "colab": {}, 197 |
"colab_type": "code", 198 | "id": "vLbT8fzGRMUa" 199 | }, 200 | "outputs": [], 201 | "source": [ 202 | "context_and_target = []\n", 203 | "for conves in conves_dict.values():\n", 204 | " # Drop the last utterance if the conversation has an odd number of lines (it has no reply)\n", 205 | " if len(conves) % 2 != 0:\n", 206 | " conves = conves[:-1]\n", 207 | " for i in range(0, len(conves), 2):\n", 208 | " context_and_target.append((conves[i], conves[i+1]))" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 14, 214 | "metadata": { 215 | "colab": { 216 | "base_uri": "https://localhost:8080/", 217 | "height": 153 218 | }, 219 | "colab_type": "code", 220 | "id": "wBJDX_kLRMUd", 221 | "outputId": "faf0d417-4f01-4296-9381-1e7ed10ca469" 222 | }, 223 | "outputs": [ 224 | { 225 | "data": { 226 | "text/plain": [ 227 | "[('Did you change your hair?', 'No.'),\n", 228 | " ('I missed you.',\n", 229 | " 'It says here you exposed yourself to a group of freshmen girls.'),\n", 230 | " ('It was a bratwurst. I was eating lunch.',\n", 231 | " 'With the teeth of your zipper?'),\n", 232 | " ('You the new guy?', 'So they tell me...'),\n", 233 | " (\"C'mon.
I'm supposed to give you the tour.\",\n", 234 | " 'So -- which Dakota you from?')]" 235 | ] 236 | }, 237 | "execution_count": 14, 238 | "metadata": { 239 | "tags": [] 240 | }, 241 | "output_type": "execute_result" 242 | } 243 | ], 244 | "source": [ 245 | "# The (context, target) pairs are ready\n", 246 | "context_and_target[:5]" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 0, 252 | "metadata": { 253 | "colab": {}, 254 | "colab_type": "code", 255 | "id": "nqBGBxadRMUi" 256 | }, 257 | "outputs": [], 258 | "source": [ 259 | "context, target = zip(*context_and_target)" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 0, 265 | "metadata": { 266 | "colab": {}, 267 | "colab_type": "code", 268 | "id": "20Wf7yq3RMUl" 269 | }, 270 | "outputs": [], 271 | "source": [ 272 | "context = list(context)\n", 273 | "target = list(target)" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 17, 279 | "metadata": { 280 | "colab": { 281 | "base_uri": "https://localhost:8080/", 282 | "height": 102 283 | }, 284 | "colab_type": "code", 285 | "id": "ANNT_oibRMUp", 286 | "outputId": "20198454-ce47-44bd-df38-9a68a782347b" 287 | }, 288 | "outputs": [ 289 | { 290 | "data": { 291 | "text/plain": [ 292 | "['Did you change your hair?',\n", 293 | " 'I missed you.',\n", 294 | " 'It was a bratwurst. I was eating lunch.',\n", 295 | " 'You the new guy?',\n", 296 | " \"C'mon.
I'm supposed to give you the tour.\"]" 297 | ] 298 | }, 299 | "execution_count": 17, 300 | "metadata": { 301 | "tags": [] 302 | }, 303 | "output_type": "execute_result" 304 | } 305 | ], 306 | "source": [ 307 | "context[:5]" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 18, 313 | "metadata": { 314 | "colab": { 315 | "base_uri": "https://localhost:8080/", 316 | "height": 357 317 | }, 318 | "colab_type": "code", 319 | "id": "kDb0GTENRMUw", 320 | "outputId": "0f72dffa-69f8-46a2-c2b8-257b02c7bfc6" 321 | }, 322 | "outputs": [ 323 | { 324 | "data": { 325 | "text/plain": [ 326 | "['No.',\n", 327 | " 'It says here you exposed yourself to a group of freshmen girls.',\n", 328 | " 'With the teeth of your zipper?',\n", 329 | " 'So they tell me...',\n", 330 | " 'So -- which Dakota you from?',\n", 331 | " 'I was kidding. People actually live there?',\n", 332 | " 'How many people were in your old school?',\n", 333 | " 'Get out!',\n", 334 | " 'Couple thousand. Most of them evil',\n", 335 | " 'Yeah, but these guys have never seen a horse. They just jack off to Clint Eastwood.',\n", 336 | " 'You burn, you pine, you perish?',\n", 337 | " \"Bianca Stratford. Sophomore. Don't even think about it\",\n", 338 | " \"I could start with your haircut, but it doesn't matter. She's not allowed to date until her older sister does. And that's an impossibility.\",\n", 339 | " 'Expressing my opinion is not a terrorist action.',\n", 340 | " 'I still maintain that he kicked himself in the balls. 
I was merely a spectator.',\n", 341 | " 'Tempestuous?',\n", 342 | " 'Patrick Verona Random skid.',\n", 343 | " \"I'm sure he's completely incapable of doing anything that interesting.\",\n", 344 | " 'Block E?',\n", 345 | " 'Just a little.']" 346 | ] 347 | }, 348 | "execution_count": 18, 349 | "metadata": { 350 | "tags": [] 351 | }, 352 | "output_type": "execute_result" 353 | } 354 | ], 355 | "source": [ 356 | "target[:20]" 357 | ] 358 | }, 359 | { 360 | "cell_type": "markdown", 361 | "metadata": { 362 | "colab_type": "text", 363 | "id": "bl2kl5WTRMU3" 364 | }, 365 | "source": [ 366 | "## Step 2. Text Preprocessing" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 0, 372 | "metadata": { 373 | "colab": {}, 374 | "colab_type": "code", 375 | "id": "pqqjGJsGRMU5" 376 | }, 377 | "outputs": [], 378 | "source": [ 379 | "# from my_seq2seq_text_cleanear import text_modifier, nonalpha_remover\n", 380 | "import re\n", 381 | "MAX_LEN = 12" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": 173, 387 | "metadata": { 388 | "colab": {}, 389 | "colab_type": "code", 390 | "id": "kXCK1rHRRMU-" 391 | }, 392 | "outputs": [], 393 | "source": [ 394 | "def clean_text(text):\n", 395 | " '''Clean text by removing unnecessary characters and altering the format of words.'''\n", 396 | "\n", 397 | " text = text.lower()\n", 398 | " \n", 399 | " text = re.sub(r\"i'm\", \"i am\", text)\n", 400 | " text = re.sub(r\"he's\", \"he is\", text)\n", 401 | " text = re.sub(r\"she's\", \"she is\", text)\n", 402 | " text = re.sub(r\"it's\", \"it is\", text)\n", 403 | " text = re.sub(r\"that's\", \"that is\", text)\n", 404 | " text = re.sub(r\"what's\", \"what is\", text)\n", 405 | " text = re.sub(r\"where's\", \"where is\", text)\n", 406 | " text = re.sub(r\"how's\", \"how is\", text)\n", 407 | " text = re.sub(r\"\\'ll\", \" will\", text)\n", 408 | " text = re.sub(r\"\\'ve\", \" have\", text)\n", 409 | " text = re.sub(r\"\\'re\", \" are\",
text)\n", 410 | " text = re.sub(r\"\\'d\", \" would\", text)\n", 411 | " text = re.sub(r\"\\'re\", \" are\", text)\n", 412 | " text = re.sub(r\"won't\", \"will not\", text)\n", 413 | " text = re.sub(r\"can't\", \"cannot\", text)\n", 414 | " text = re.sub(r\"n't\", \" not\", text)\n", 415 | " text = re.sub(r\"n'\", \"ng\", text)\n", 416 | " text = re.sub(r\"'bout\", \"about\", text)\n", 417 | " text = re.sub(r\"'til\", \"until\", text)\n", 418 | " text = re.sub(r\"[-()\\\"#/@;:<>{}`+=~|.!?,]\", \"\", text)\n", 419 | " \n", 420 | " return text" 421 | ] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "metadata": { 426 | "colab_type": "text", 427 | "id": "iv5ggyQhRMVF" 428 | }, 429 | "source": [ 430 | "### 2-1. Clean Text" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 0, 436 | "metadata": { 437 | "colab": {}, 438 | "colab_type": "code", 439 | "id": "5WG930lmRMVH" 440 | }, 441 | "outputs": [], 442 | "source": [ 443 | "tidy_target = []\n", 444 | "for conve in target:\n", 445 | " text = clean_text(conve)\n", 446 | " tidy_target.append(text)" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 22, 452 | "metadata": { 453 | "colab": { 454 | "base_uri": "https://localhost:8080/", 455 | "height": 357 456 | }, 457 | "colab_type": "code", 458 | "id": "lil1q_dlRMVK", 459 | "outputId": "c9b3559d-ea48-4105-88ed-e99d911b02e6" 460 | }, 461 | "outputs": [ 462 | { 463 | "data": { 464 | "text/plain": [ 465 | "['no',\n", 466 | " 'it says here you exposed yourself to a group of freshmen girls',\n", 467 | " 'with the teeth of your zipper',\n", 468 | " 'so they tell me',\n", 469 | " 'so which dakota you from',\n", 470 | " 'i was kidding people actually live there',\n", 471 | " 'how many people were in your old school',\n", 472 | " 'get out',\n", 473 | " 'couple thousand most of them evil',\n", 474 | " 'yeah but these guys have never seen a horse they just jack off to clint eastwood',\n", 475 | " 'you burn you pine you 
perish',\n", 476 | " 'bianca stratford sophomore do not even think about it',\n", 477 | " 'i could start with your haircut but it does not matter she is not allowed to date until her older sister does and that is an impossibility',\n", 478 | " 'expressing my opinion is not a terrorist action',\n", 479 | " 'i still maintain that he kicked himself in the balls i was merely a spectator',\n", 480 | " 'tempestuous',\n", 481 | " 'patrick verona random skid',\n", 482 | " 'i am sure he is completely incapable of doing anything that interesting',\n", 483 | " 'block e',\n", 484 | " 'just a little']" 485 | ] 486 | }, 487 | "execution_count": 22, 488 | "metadata": { 489 | "tags": [] 490 | }, 491 | "output_type": "execute_result" 492 | } 493 | ], 494 | "source": [ 495 | "tidy_target[:20]" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": 0, 501 | "metadata": { 502 | "colab": {}, 503 | "colab_type": "code", 504 | "id": "1u7QXY-TRMVN" 505 | }, 506 | "outputs": [], 507 | "source": [ 508 | "tidy_context = []\n", 509 | "for conve in context:\n", 510 | " text = clean_text(conve)\n", 511 | " tidy_context.append(text)" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": 24, 517 | "metadata": { 518 | "colab": { 519 | "base_uri": "https://localhost:8080/", 520 | "height": 377 521 | }, 522 | "colab_type": "code", 523 | "id": "hUfGBwOqRMVP", 524 | "outputId": "60e91895-6e3b-434d-e9b4-ee0b624deb80" 525 | }, 526 | "outputs": [ 527 | { 528 | "data": { 529 | "text/plain": [ 530 | "['did you change your hair',\n", 531 | " 'i missed you',\n", 532 | " 'it was a bratwurst i was eating lunch',\n", 533 | " 'you the new guy',\n", 534 | " \"c'mon i am supposed to give you the tour\",\n", 535 | " 'north actually how would you ',\n", 536 | " 'yeah a couple we are outnumbered by the cows though',\n", 537 | " 'thirtytwo',\n", 538 | " 'how many people go here',\n", 539 | " 'that i am used to',\n", 540 | " 'that girl i ',\n", 541 | " 'who is she',\n", 542 
| " 'why not',\n", 543 | " 'katarina stratford my my you have been terrorizing ms blaise again',\n", 544 | " \"well yes compared to your other choices of expression this year today's events are quite mild by the way bobby rictor's gonad retrieval operation went quite well in case you are interested\",\n", 545 | " 'the point is kat people perceive you as somewhat ',\n", 546 | " \"who's that\",\n", 547 | " 'that is pat verona the one who was gone for a year i heard he was doing porn movies',\n", 548 | " 'he always look so',\n", 549 | " 'mandella eat starving yourself is a very slow way to die']" 550 | ] 551 | }, 552 | "execution_count": 24, 553 | "metadata": { 554 | "tags": [] 555 | }, 556 | "output_type": "execute_result" 557 | } 558 | ], 559 | "source": [ 560 | "tidy_context[:20]" 561 | ] 562 | }, 563 | { 564 | "cell_type": "code", 565 | "execution_count": 0, 566 | "metadata": { 567 | "colab": {}, 568 | "colab_type": "code", 569 | "id": "TMRuuv8ZRMVX" 570 | }, 571 | "outputs": [], 572 | "source": [ 573 | "# Wrap each decoder input (target sentence) in start- and end-of-sentence tags\n", 574 | "bos = \" \"\n", 575 | "eos = \" \"\n", 576 | "final_target = [bos + conve + eos for conve in tidy_target] \n", 577 | "encoder_inputs = tidy_context\n", 578 | "decoder_inputs = final_target" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": 4, 584 | "metadata": { 585 | "colab": {}, 586 | "colab_type": "code", 587 | "id": "Rwj5qO6-RMVS" 588 | }, 589 | "outputs": [], 590 | "source": [ 591 | "import codecs\n", 592 | "with codecs.open(\"encoder_inputs.txt\", \"rb\", encoding=\"utf-8\", errors=\"ignore\") as f:\n", 593 | " lines = f.read().split(\"\\n\")\n", 594 | " encoder_text = []\n", 595 | " for line in lines:\n", 596 | " data = line.split(\"\\n\")[0]\n", 597 | " encoder_text.append(data)" 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": 51, 603 | "metadata": {}, 604 | "outputs": [ 605 | { 606 | "data": { 607 | "text/plain": [ 608 | "143865" 609 | ] 610 | }, 611 |
"execution_count": 51, 612 | "metadata": {}, 613 | "output_type": "execute_result" 614 | } 615 | ], 616 | "source": [ 617 | "len(encoder_text)" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": null, 623 | "metadata": { 624 | "colab": {}, 625 | "colab_type": "code", 626 | "id": "pgjJwZnYRMVU" 627 | }, 628 | "outputs": [], 629 | "source": [ 630 | "encoder_text" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": 6, 636 | "metadata": {}, 637 | "outputs": [], 638 | "source": [ 639 | "with codecs.open(\"decoder_inputs.txt\", \"rb\", encoding=\"utf-8\", errors=\"ignore\") as f:\n", 640 | " lines = f.read().split(\"\\n\")\n", 641 | " decoder_text = []\n", 642 | " for line in lines:\n", 643 | " data = line.split(\"\\n\")[0]\n", 644 | " decoder_text.append(data)" 645 | ] 646 | }, 647 | { 648 | "cell_type": "code", 649 | "execution_count": null, 650 | "metadata": {}, 651 | "outputs": [], 652 | "source": [ 653 | "decoder_text" 654 | ] 655 | }, 656 | { 657 | "cell_type": "markdown", 658 | "metadata": { 659 | "colab_type": "text", 660 | "id": "UMXtB9i6RMVa" 661 | }, 662 | "source": [ 663 | "### 2-2. 
MAKE VOCABULARY" 664 | ] 665 | }, 666 | { 667 | "cell_type": "code", 668 | "execution_count": 0, 669 | "metadata": { 670 | "colab": {}, 671 | "colab_type": "code", 672 | "id": "BQWanckCRMVb" 673 | }, 674 | "outputs": [], 675 | "source": [ 676 | "# First, check the size of the raw vocabulary (needs full_text, which is defined further below)\n", 677 | "dictionary = set() # a set keeps the membership test fast\n", 678 | "for text in full_text:\n", 679 | " words = text.split()\n", 680 | " for i in range(0, len(words)):\n", 681 | " if words[i] not in dictionary:\n", 682 | " dictionary.add(words[i])" 683 | ] 684 | }, 685 | { 686 | "cell_type": "code", 687 | "execution_count": 8, 688 | "metadata": { 689 | "colab": {}, 690 | "colab_type": "code", 691 | "id": "DR7nq44URMVf", 692 | "scrolled": true 693 | }, 694 | "outputs": [ 695 | { 696 | "name": "stderr", 697 | "output_type": "stream", 698 | "text": [ 699 | "Using TensorFlow backend.\n" 700 | ] 701 | } 702 | ], 703 | "source": [ 704 | "from keras.preprocessing.text import Tokenizer\n", 705 | "VOCAB_SIZE = 14999\n", 706 | "tokenizer = Tokenizer(num_words=VOCAB_SIZE)" 707 | ] 708 | }, 709 | { 710 | "cell_type": "code", 711 | "execution_count": 9, 712 | "metadata": {}, 713 | "outputs": [], 714 | "source": [ 715 | "full_text = encoder_text + decoder_text" 716 | ] 717 | }, 718 | { 719 | "cell_type": "code", 720 | "execution_count": 10, 721 | "metadata": { 722 | "colab": { 723 | "base_uri": "https://localhost:8080/", 724 | "height": 34 725 | }, 726 | "colab_type": "code", 727 | "id": "Q0RgA9ssRMVj", 728 | "outputId": "4ab16108-1ce4-4c32-d2f4-a9f4d7c66b80" 729 | }, 730 | "outputs": [ 731 | { 732 | "data": { 733 | "text/plain": [ 734 | "65283" 735 | ] 736 | }, 737 | "execution_count": 10, 738 | "metadata": {}, 739 | "output_type": "execute_result" 740 | } 741 | ], 742 | "source": [ 743 | "# Build the word index from all encoder and decoder sentences\n", 744 | "tokenizer.fit_on_texts(full_text)\n", 745 | "word_index = tokenizer.word_index\n", 746 | "len(word_index)" 747 | ] 748 | }, 749 | { 750 | "cell_type": "code", 751 | "execution_count": 66, 752 | "metadata": { 753 | "colab": {}, 754 |
"colab_type": "code", 755 | "id": "wIT8hFwjRMVn" 756 | }, 757 | "outputs": [], 758 | "source": [ 759 | "# Build the reverse dictionary (index -> word) for the most frequent words (indices 1 to 14999)\n", 760 | "index2word = {}\n", 761 | "for k, v in word_index.items():\n", 762 | " if v < 15000:\n", 763 | " index2word[v] = k\n", 764 | " if v > 15000:\n", 765 | " continue" 766 | ] 767 | }, 768 | { 769 | "cell_type": "code", 770 | "execution_count": null, 771 | "metadata": {}, 772 | "outputs": [], 773 | "source": [ 774 | "index2word" 775 | ] 776 | }, 777 | { 778 | "cell_type": "code", 779 | "execution_count": 68, 780 | "metadata": {}, 781 | "outputs": [], 782 | "source": [ 783 | "word2index = {}\n", 784 | "for k, v in index2word.items():\n", 785 | " word2index[v] = k" 786 | ] 787 | }, 788 | { 789 | "cell_type": "code", 790 | "execution_count": null, 791 | "metadata": {}, 792 | "outputs": [], 793 | "source": [ 794 | "word2index" 795 | ] 796 | }, 797 | { 798 | "cell_type": "code", 799 | "execution_count": 71, 800 | "metadata": { 801 | "colab": {}, 802 | "colab_type": "code", 803 | "id": "Vi5Zp56PZwWy" 804 | }, 805 | "outputs": [ 806 | { 807 | "data": { 808 | "text/plain": [ 809 | "True" 810 | ] 811 | }, 812 | "execution_count": 71, 813 | "metadata": {}, 814 | "output_type": "execute_result" 815 | } 816 | ], 817 | "source": [ 818 | "len(word2index) == len(index2word)" 819 | ] 820 | }, 821 | { 822 | "cell_type": "code", 823 | "execution_count": 70, 824 | "metadata": {}, 825 | "outputs": [ 826 | { 827 | "data": { 828 | "text/plain": [ 829 | "14999" 830 | ] 831 | }, 832 | "execution_count": 70, 833 | "metadata": {}, 834 | "output_type": "execute_result" 835 | } 836 | ], 837 | "source": [ 838 | "len(index2word)" 839 | ] 840 | }, 841 | { 842 | "cell_type": "markdown", 843 | "metadata": { 844 | "colab_type": "text", 845 | "id": "0ErckVpJRMVp" 846 | }, 847 | "source": [ 848 | "### 2-3.
ONE-HOT VECTORIZER" 849 | ] 850 | }, 851 | { 852 | "cell_type": "code", 853 | "execution_count": 13, 854 | "metadata": { 855 | "colab": {}, 856 | "colab_type": "code", 857 | "id": "Dc9GMlYdRMVq" 858 | }, 859 | "outputs": [], 860 | "source": [ 861 | "# Turn each encoder sentence into a sequence of word indices\n", 862 | "encoder_sequences = tokenizer.texts_to_sequences(encoder_text)\n", 863 | "# encoder_sequences = np.array(encoder_sequences)" 864 | ] 865 | }, 866 | { 867 | "cell_type": "code", 868 | "execution_count": 14, 869 | "metadata": { 870 | "colab": {}, 871 | "colab_type": "code", 872 | "id": "EB6wvoIFRMVs" 873 | }, 874 | "outputs": [], 875 | "source": [ 876 | "# Do the same for the decoder sentences\n", 877 | "decoder_sequences = tokenizer.texts_to_sequences(decoder_text)\n", 878 | "# decoder_sequences = np.array(decoder_sequences)" 879 | ] 880 | }, 881 | { 882 | "cell_type": "code", 883 | "execution_count": null, 884 | "metadata": {}, 885 | "outputs": [], 886 | "source": [ 887 | "encoder_sequences" 888 | ] 889 | }, 890 | { 891 | "cell_type": "code", 892 | "execution_count": 16, 893 | "metadata": {}, 894 | "outputs": [], 895 | "source": [ 896 | "for seqs in encoder_sequences:\n", 897 | " for seq in seqs:\n", 898 | " if seq > 14999:\n", 899 | " print(seq)\n", 900 | " break" 901 | ] 902 | }, 903 | { 904 | "cell_type": "code", 905 | "execution_count": 139, 906 | "metadata": {}, 907 | "outputs": [ 908 | { 909 | "data": { 910 | "text/plain": [ 911 | "15000" 912 | ] 913 | }, 914 | "execution_count": 139, 915 | "metadata": {}, 916 | "output_type": "execute_result" 917 | } 918 | ], 919 | "source": [ 920 | "VOCAB_SIZE = len(index2word) + 1\n", 921 | "VOCAB_SIZE" 922 | ] 923 | }, 924 | { 925 | "cell_type": "code", 926 | "execution_count": 53, 927 | "metadata": {}, 928 | "outputs": [ 929 | { 930 | "data": { 931 | "text/plain": [ 932 | "(143865, 20, 15000)" 933 | ] 934 | }, 935 | "execution_count": 53, 936 | "metadata": {}, 937 | "output_type": "execute_result" 938 | } 939 | ], 940 | "source": [ 941 | "decoder_output_data.shape"
942 | ] 943 | }, 944 | { 945 | "cell_type": "code", 946 | "execution_count": null, 947 | "metadata": {}, 948 | "outputs": [], 949 | "source": [ 950 | "decoder_sequences" 951 | ] 952 | }, 953 | { 954 | "cell_type": "code", 955 | "execution_count": 98, 956 | "metadata": { 957 | "colab": {}, 958 | "colab_type": "code", 959 | "id": "KkUuN7EdRMVz" 960 | }, 961 | "outputs": [], 962 | "source": [ 963 | "import numpy as np\n", 964 | "MAX_LEN = 20\n", 965 | "num_samples = len(encoder_sequences)\n", 966 | "decoder_output_data = np.zeros((num_samples, MAX_LEN, VOCAB_SIZE), dtype=\"float32\") # 143865 x 20 x 15000 float32 values, about 173 GB if fully materialized" 967 | ] 968 | }, 969 | { 970 | "cell_type": "code", 971 | "execution_count": 130, 972 | "metadata": { 973 | "colab": {}, 974 | "colab_type": "code", 975 | "id": "LTGqtHmlRMV4" 976 | }, 977 | "outputs": [], 978 | "source": [ 979 | "# One-hot 3D target tensor: the target at step j-1 is the decoder input token at step j (teacher forcing); decoder_input_data is created in the padding step\n", 980 | "for i, seqs in enumerate(decoder_input_data):\n", 981 | " for j, seq in enumerate(seqs):\n", 982 | " if j > 0:\n", 983 | " decoder_output_data[i][j - 1][seq] = 1." 984 | ] 985 | }, 986 | { 987 | "cell_type": "code", 988 | "execution_count": 134, 989 | "metadata": {}, 990 | "outputs": [ 991 | { 992 | "data": { 993 | "text/plain": [ 994 | "(143865, 20, 15000)" 995 | ] 996 | }, 997 | "execution_count": 134, 998 | "metadata": {}, 999 | "output_type": "execute_result" 1000 | } 1001 | ], 1002 | "source": [ 1003 | "decoder_output_data.shape" 1004 | ] 1005 | }, 1006 | { 1007 | "cell_type": "markdown", 1008 | "metadata": { 1009 | "colab_type": "text", 1010 | "id": "VXYJEes1RMV9" 1011 | }, 1012 | "source": [ 1013 | "### 2-4.
PADDING" 1014 | ] 1015 | }, 1016 | { 1017 | "cell_type": "code", 1018 | "execution_count": 128, 1019 | "metadata": { 1020 | "colab": { 1021 | "base_uri": "https://localhost:8080/", 1022 | "height": 215 1023 | }, 1024 | "colab_type": "code", 1025 | "id": "d_qBdh0eRMV-", 1026 | "outputId": "fcfb794d-4945-4d18-f1ec-5c4e42712bbb" 1027 | }, 1028 | "outputs": [], 1029 | "source": [ 1030 | "from keras.preprocessing.sequence import pad_sequences\n", 1031 | "encoder_input_data = pad_sequences(encoder_sequences, maxlen=MAX_LEN, dtype='int32', padding='post', truncating='post')\n", 1032 | "decoder_input_data = pad_sequences(decoder_sequences, maxlen=MAX_LEN, dtype='int32', padding='post', truncating='post')" 1033 | ] 1034 | }, 1035 | { 1036 | "cell_type": "code", 1037 | "execution_count": 129, 1038 | "metadata": {}, 1039 | "outputs": [ 1040 | { 1041 | "data": { 1042 | "text/plain": [ 1043 | "array([ 1, 32, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 1044 | " 0, 0, 0], dtype=int32)" 1045 | ] 1046 | }, 1047 | "execution_count": 129, 1048 | "metadata": {}, 1049 | "output_type": "execute_result" 1050 | } 1051 | ], 1052 | "source": [ 1053 | "decoder_input_data[0]" 1054 | ] 1055 | }, 1056 | { 1057 | "cell_type": "markdown", 1058 | "metadata": { 1059 | "colab_type": "text", 1060 | "id": "1IvPATnlRMWB" 1061 | }, 1062 | "source": [ 1063 | "### 2-5. 
Word Embeddings: Pretrained GloVe Vectors" 1064 | ] 1065 | }, 1066 | { 1067 | "cell_type": "code", 1068 | "execution_count": 57, 1069 | "metadata": { 1070 | "colab": {}, 1071 | "colab_type": "code", 1072 | "id": "lbEykxwmRMWC" 1073 | }, 1074 | "outputs": [ 1075 | { 1076 | "name": "stdout", 1077 | "output_type": "stream", 1078 | "text": [ 1079 | "GloVe loaded!\n" 1080 | ] 1081 | } 1082 | ], 1083 | "source": [ 1084 | "# Map each word to its 50-dimensional GloVe vector\n", 1085 | "embeddings_index = {}\n", 1086 | "with open('glove.6B.50d.txt', encoding='utf-8') as f:\n", 1087 | " for line in f:\n", 1088 | " values = line.split()\n", 1089 | " word = values[0]\n", 1090 | " coefs = np.asarray(values[1:], dtype='float32')\n", 1091 | " embeddings_index[word] = coefs\n", 1092 | "\n", 1093 | "print(\"GloVe loaded!\")" 1094 | ] 1095 | }, 1096 | { 1097 | "cell_type": "code", 1098 | "execution_count": 59, 1099 | "metadata": { 1100 | "colab": {}, 1101 | "colab_type": "code", 1102 | "id": "HUXce9sDRMWI" 1103 | }, 1104 | "outputs": [], 1105 | "source": [ 1106 | "embedding_dimention = 50\n", 1107 | "def embedding_matrix_creater(embedding_dimention, word_index):\n", 1108 | " embedding_matrix = np.zeros((len(word_index) + 1, embedding_dimention))\n", 1109 | " for word, i in word_index.items():\n", 1110 | " embedding_vector = embeddings_index.get(word)\n", 1111 | " if embedding_vector is not None:\n", 1112 | " # words not found in embedding index will be all-zeros.\n", 1113 | " embedding_matrix[i] = embedding_vector\n", 1114 | " return embedding_matrix" 1115 | ] 1116 | }, 1117 | { 1118 | "cell_type": "code", 1119 | "execution_count": 137, 1120 | "metadata": {}, 1121 | "outputs": [], 1122 | "source": [ 1123 | "embedding_matrix = embedding_matrix_creater(embedding_dimention, word_index=word2index)" 1124 | ] 1125 | }, 1126 | { 1127 | "cell_type": "code", 1128 | "execution_count": 140, 1129 | "metadata": {}, 1130 | "outputs": [], 1131 | "source": [ 1132 | "embed_layer = Embedding(input_dim=VOCAB_SIZE, output_dim=50, trainable=True,)\n",
"embed_layer.build((None,))\n", 1134 | "embed_layer.set_weights([embedding_matrix])" 1135 | ] 1136 | }, 1137 | { 1138 | "cell_type": "markdown", 1139 | "metadata": { 1140 | "colab_type": "text", 1141 | "id": "HBLzb6z0RMWM" 1142 | }, 1143 | "source": [ 1144 | "## Step 3. Build Seq2Seq Model" 1145 | ] 1146 | }, 1147 | { 1148 | "cell_type": "code", 1149 | "execution_count": 60, 1150 | "metadata": { 1151 | "colab": {}, 1152 | "colab_type": "code", 1153 | "id": "YW3imyr0RMWX" 1154 | }, 1155 | "outputs": [], 1156 | "source": [ 1157 | "from keras.layers import Embedding\n", 1158 | "from keras.layers import Input, Dense, LSTM, TimeDistributed\n", 1159 | "from keras.models import Model" 1160 | ] 1161 | }, 1162 | { 1163 | "cell_type": "code", 1164 | "execution_count": 149, 1165 | "metadata": {}, 1166 | "outputs": [], 1167 | "source": [ 1168 | "def seq2seq_model_builder(HIDDEN_DIM=300):\n", 1169 | " \n", 1170 | " encoder_inputs = Input(shape=(MAX_LEN, ), dtype='int32',)\n", 1171 | " encoder_embedding = embed_layer(encoder_inputs)\n", 1172 | " encoder_LSTM = LSTM(HIDDEN_DIM, return_state=True)\n", 1173 | " encoder_outputs, state_h, state_c = encoder_LSTM(encoder_embedding)\n", 1174 | " \n", 1175 | " decoder_inputs = Input(shape=(MAX_LEN, ), dtype='int32',)\n", 1176 | " decoder_embedding = embed_layer(decoder_inputs)\n", 1177 | " decoder_LSTM = LSTM(HIDDEN_DIM, return_state=True, return_sequences=True)\n", 1178 | " decoder_outputs, _, _ = decoder_LSTM(decoder_embedding, initial_state=[state_h, state_c])\n", 1179 | " \n", 1180 | " # dense_layer = Dense(VOCAB_SIZE, activation='softmax')\n", 1181 | " outputs = TimeDistributed(Dense(VOCAB_SIZE, activation='softmax'))(decoder_outputs)\n", 1182 | " model = Model([encoder_inputs, decoder_inputs], outputs)\n", 1183 | " \n", 1184 | " return model" 1185 | ] 1186 | }, 1187 | { 1188 | "cell_type": "code", 1189 | "execution_count": 150, 1190 | "metadata": {}, 1191 | "outputs": [], 1192 | "source": [ 1193 | "model = 
seq2seq_model_builder(HIDDEN_DIM=300)" 1194 | ] 1195 | }, 1196 | { 1197 | "cell_type": "code", 1198 | "execution_count": 151, 1199 | "metadata": {}, 1200 | "outputs": [ 1201 | { 1202 | "name": "stdout", 1203 | "output_type": "stream", 1204 | "text": [ 1205 | "__________________________________________________________________________________________________\n", 1206 | "Layer (type) Output Shape Param # Connected to \n", 1207 | "==================================================================================================\n", 1208 | "input_10 (InputLayer) (None, 20) 0 \n", 1209 | "__________________________________________________________________________________________________\n", 1210 | "input_9 (InputLayer) (None, 20) 0 \n", 1211 | "__________________________________________________________________________________________________\n", 1212 | "embedding_3 (Embedding) (None, 20, 50) 750000 input_9[0][0] \n", 1213 | " input_10[0][0] \n", 1214 | "__________________________________________________________________________________________________\n", 1215 | "lstm_11 (LSTM) [(None, 300), (None, 421200 embedding_3[8][0] \n", 1216 | "__________________________________________________________________________________________________\n", 1217 | "lstm_12 (LSTM) [(None, 20, 300), (N 421200 embedding_3[9][0] \n", 1218 | " lstm_11[0][1] \n", 1219 | " lstm_11[0][2] \n", 1220 | "__________________________________________________________________________________________________\n", 1221 | "time_distributed_4 (TimeDistrib (None, 20, 15000) 4515000 lstm_12[0][0] \n", 1222 | "==================================================================================================\n", 1223 | "Total params: 6,107,400\n", 1224 | "Trainable params: 6,107,400\n", 1225 | "Non-trainable params: 0\n", 1226 | "__________________________________________________________________________________________________\n" 1227 | ] 1228 | } 1229 | ], 1230 | "source": [ 1231 | "model.summary()" 1232 | ] 1233 | }, 
1234 | { 1235 | "cell_type": "code", 1236 | "execution_count": 155, 1237 | "metadata": {}, 1238 | "outputs": [ 1239 | { 1240 | "data": { 1241 | "text/plain": [ 1242 | "'/Users/akr712/Desktop/CHATBOT'" 1243 | ] 1244 | }, 1245 | "execution_count": 155, 1246 | "metadata": {}, 1247 | "output_type": "execute_result" 1248 | } 1249 | ], 1250 | "source": [ 1251 | "pwd" 1252 | ] 1253 | }, 1254 | { 1255 | "cell_type": "code", 1256 | "execution_count": null, 1257 | "metadata": {}, 1258 | "outputs": [], 1259 | "source": [ 1260 | "from keras.utils import plot_model\n", 1261 | "plot_model(model, to_file='seq2seq.png') # relative path: writes into the working directory" 1262 | ] 1263 | }, 1264 | { 1265 | "cell_type": "code", 1266 | "execution_count": 154, 1267 | "metadata": { 1268 | "colab": {}, 1269 | "colab_type": "code", 1270 | "id": "rIFru1mFRMWd" 1271 | }, 1272 | "outputs": [], 1273 | "source": [ 1274 | "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])" 1275 | ] 1276 | }, 1277 | { 1278 | "cell_type": "markdown", 1279 | "metadata": { 1280 | "colab_type": "text", 1281 | "id": "g8jouTFzRMWh" 1282 | }, 1283 | "source": [ 1284 | "## Step 4.
Training Model" 1285 | ] 1286 | }, 1287 | { 1288 | "cell_type": "code", 1289 | "execution_count": 164, 1290 | "metadata": {}, 1291 | "outputs": [], 1292 | "source": [ 1293 | "BATCH_SIZE = 32\n", 1294 | "EPOCHS = 5" 1295 | ] 1296 | }, 1297 | { 1298 | "cell_type": "code", 1299 | "execution_count": 163, 1300 | "metadata": { 1301 | "colab": {}, 1302 | "colab_type": "code", 1303 | "id": "qGVVQvhFRMWq" 1304 | }, 1305 | "outputs": [ 1306 | { 1307 | "data": { 1308 | "text/plain": [ 1309 | "(143865, 20)" 1310 | ] 1311 | }, 1312 | "execution_count": 163, 1313 | "metadata": {}, 1314 | "output_type": "execute_result" 1315 | } 1316 | ], 1317 | "source": [ 1318 | "encoder_input_data.shape" 1319 | ] 1320 | }, 1321 | { 1322 | "cell_type": "code", 1323 | "execution_count": 165, 1324 | "metadata": { 1325 | "colab": {}, 1326 | "colab_type": "code", 1327 | "id": "p4i_CsA4RMWk" 1328 | }, 1329 | "outputs": [ 1330 | { 1331 | "name": "stdout", 1332 | "output_type": "stream", 1333 | "text": [ 1334 | "Epoch 1/5\n", 1335 | "143865/143865 [==============================] - 5913s 41ms/step - loss: 0.9308 - acc: 0.8280\n", 1336 | "Epoch 2/5\n", 1337 | "143865/143865 [==============================] - 5848s 41ms/step - loss: 0.0447 - acc: 0.9449\n", 1338 | "Epoch 3/5\n", 1339 | "143865/143865 [==============================] - 5494s 38ms/step - loss: 0.0052 - acc: 0.9493\n", 1340 | "Epoch 4/5\n", 1341 | "143865/143865 [==============================] - 5753s 40ms/step - loss: 0.0016 - acc: 0.9498\n", 1342 | "Epoch 5/5\n", 1343 | "143865/143865 [==============================] - 4970s 35ms/step - loss: 8.2432e-04 - acc: 0.9499\n" 1344 | ] 1345 | } 1346 | ], 1347 | "source": [ 1348 | "history = model.fit([encoder_input_data, decoder_input_data], \n", 1349 | " decoder_output_data, \n", 1350 | " epochs=EPOCHS, \n", 1351 | " batch_size=BATCH_SIZE)" 1352 | ] 1353 | }, 1354 | { 1355 | "cell_type": "markdown", 1356 | "metadata": {}, 1357 | "source": [ 1358 | "#### Visualize Learning History" 1359 | ] 
1360 | }, 1361 | { 1362 | "cell_type": "code", 1363 | "execution_count": 183, 1364 | "metadata": { 1365 | "colab": {}, 1366 | "colab_type": "code", 1367 | "id": "Mj9pi9UGRMWn" 1368 | }, 1369 | "outputs": [ 1370 | { 1371 | "data": { 1372 | "image/png": "\n", 1373 | "text/plain": [ 1374 | "
" 1375 | ] 1376 | }, 1377 | "metadata": { 1378 | "needs_background": "light" 1379 | }, 1380 | "output_type": "display_data" 1381 | } 1382 | ], 1383 | "source": [ 1384 | "# Visualize the training accuracy\n", 1385 | "import matplotlib.pyplot as plt\n", 1386 | "%matplotlib inline\n", 1387 | "\n", 1388 | "plt.figure(figsize=(10, 6))\n", 1389 | "plt.plot(history.history['acc'])\n", 1390 | "#plt.plot(history.history['val_acc'])\n", 1391 | "plt.title('model accuracy')\n", 1392 | "plt.ylabel('accuracy')\n", 1393 | "plt.xlabel('epoch')\n", 1394 | "# plt.legend(['train', 'test'], loc='upper left')\n", 1395 | "plt.show()" 1396 | ] 1397 | }, 1398 | { 1399 | "cell_type": "code", 1400 | "execution_count": 184, 1401 | "metadata": {}, 1402 | "outputs": [ 1403 | { 1404 | "data": { 1405 | "image/png": "\n", 1406 | "text/plain": [ 1407 | "
" 1408 | ] 1409 | }, 1410 | "metadata": { 1411 | "needs_background": "light" 1412 | }, 1413 | "output_type": "display_data" 1414 | } 1415 | ], 1416 | "source": [ 1417 | "# Visualize the training loss\n", 1418 | "plt.figure(figsize=(10, 6))\n", 1419 | "plt.plot(history.history['loss'])\n", 1420 | "# plt.plot(history.history['val_loss'])\n", 1421 | "plt.title('model loss')\n", 1422 | "plt.ylabel('loss')\n", 1423 | "plt.xlabel('epoch')\n", 1424 | "# plt.legend(['train', 'test'], loc='upper left')\n", 1425 | "plt.show()" 1426 | ] 1427 | }, 1428 | { 1429 | "cell_type": "code", 1430 | "execution_count": null, 1431 | "metadata": {}, 1432 | "outputs": [], 1433 | "source": [ 1434 | "# Save the model architecture as JSON\n", 1435 | "with open('seq2seq.json', 'w') as f:\n", 1436 | " f.write(model.to_json())\n", 1437 | "# Save the trained weights\n", 1438 | "model.save_weights('seq2seq.h5')\n", 1439 | "print(\"Saved Model!\")" 1440 | ] 1441 | }, 1442 | { 1443 | "cell_type": "code", 1444 | "execution_count": 187, 1445 | "metadata": { 1446 | "colab": {}, 1447 | "colab_type": "code", 1448 | "id": "4nigsv81RMXA" 1449 | }, 1450 | "outputs": [ 1451 | { 1452 | "name": "stdout", 1453 | "output_type": "stream", 1454 | "text": [ 1455 | "Saved Model!\n" 1456 | ] 1457 | } 1458 | ], 1459 | "source": [ 1460 | "# Save the architecture (JSON) and the weights (HDF5)\n", 1461 | "model_json = model.to_json()\n", 1462 | "with open(\"model.json\", \"w\") as json_file:\n", 1463 | " json_file.write(model_json)\n", 1464 | "\n", 1465 | "model.save_weights(\"chatbot_model.h5\")\n", 1466 | "print(\"Saved Model!\")" 1467 | ] 1468 | }, 1469 | { 1470 | "cell_type": "code", 1471 | "execution_count": 191, 1472 | "metadata": {}, 1473 | "outputs": [], 1474 | "source": [ 1475 | "json_string = model.to_json()\n", 1476 | "with open('seq2seq.json', 'w') as f: f.write(json_string)\n", 1477 | "model.save_weights('seq2seq_weights.h5')" 1478 | ] 1479 | }, 1480 | { 1481 | "cell_type": "code", 1482 | "execution_count": 192, 1483 | "metadata": {}, 1484 | "outputs": [ 1485 | { 1486 | "name": "stdout", 1487 | "output_type": "stream", 1488 |
"text": [ 1489 | "1_0306_chatobot3.ipynb glove.6B.50d.txt\r\n", 1490 | "1_0306_chatobot4.ipynb model.json\r\n", 1491 | "apple_orange_model.json movie_lines.txt\r\n", 1492 | "apple_orange_weights.h5 padded_decoder_sequences.txt\r\n", 1493 | "chatbot_model.h5 padded_encoder_sequences.txt\r\n", 1494 | "decoder_inputs.txt seq2seq.json\r\n", 1495 | "encoder_inputs.txt seq2seq_weights.h5\r\n" 1496 | ] 1497 | } 1498 | ], 1499 | "source": [ 1500 | "%ls" 1501 | ] 1502 | }, 1503 | { 1504 | "cell_type": "code", 1505 | "execution_count": 190, 1506 | "metadata": {}, 1507 | "outputs": [ 1508 | { 1509 | "data": { 1510 | "text/plain": [ 1511 | "'/Users/akr712/Desktop/CHATBOT'" 1512 | ] 1513 | }, 1514 | "execution_count": 190, 1515 | "metadata": {}, 1516 | "output_type": "execute_result" 1517 | } 1518 | ], 1519 | "source": [ 1520 | "pwd" 1521 | ] 1522 | } 1523 | ], 1524 | "metadata": { 1525 | "colab": { 1526 | "name": "1. 0306_chatobot3.ipynb", 1527 | "provenance": [], 1528 | "version": "0.3.2" 1529 | }, 1530 | "kernelspec": { 1531 | "display_name": "Python 3", 1532 | "language": "python", 1533 | "name": "python3" 1534 | }, 1535 | "language_info": { 1536 | "codemirror_mode": { 1537 | "name": "ipython", 1538 | "version": 3 1539 | }, 1540 | "file_extension": ".py", 1541 | "mimetype": "text/x-python", 1542 | "name": "python", 1543 | "nbconvert_exporter": "python", 1544 | "pygments_lexer": "ipython3", 1545 | "version": "3.5.1" 1546 | } 1547 | }, 1548 | "nbformat": 4, 1549 | "nbformat_minor": 1 1550 | } 1551 | -------------------------------------------------------------------------------- /chatbot_keras.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from keras.layers import Input, Embedding, LSTM, Dense, RepeatVector, Bidirectional, Dropout, merge 4 | from keras.optimizers import Adam, SGD 5 | from keras.models import Model 6 | from keras.models import Sequential 7 | from keras.layers import Activation, Dense 8 | from keras.callbacks 
import EarlyStopping 9 | from keras.preprocessing import sequence 10 | 11 | import keras.backend as K 12 | import numpy as np 13 | np.random.seed(1234) # for reproducibility 14 | import pickle as cPickle 15 | import theano.tensor as T 16 | import os 17 | import pandas as pd 18 | import sys 19 | import matplotlib.pyplot as plt 20 | 21 | 22 | """ 23 | Build Your Chatbot!!! 24 | Sentence-level context vectors 25 | Word-level semantic vectors 26 | """ 27 | 28 | # params 29 | WORD2VEC_DIMS = 100 30 | DOC2VEC_DIMS = 300 31 | 32 | DICTIONARY_SIZE = 10000 33 | MAX_INPUT_LENGTH = 30 34 | MAX_OUTPUT_LENGTH = 30 35 | 36 | NUM_HIDDEN_UNITS = 256 37 | BATCH_SIZE = 64 38 | NUM_EPOCHS = 100 39 | 40 | NUM_SUBSETS = 1 41 | 42 | PATIENCE = 0 43 | DROPOUT = .25 44 | N_TEST = 100 45 | 46 | CALL_BACKS = EarlyStopping(monitor='val_loss', patience=PATIENCE) 47 | 48 | # files 49 | vocabulary_file = 'vocabulary_movie' 50 | questions_file = 'Padded_context' 51 | answers_file = 'Padded_answers' 52 | weights_file = 'my_model_weights20.h5' 53 | GLOVE_DIR = './glove.6B/' 54 | 55 | # padding and buckets 56 | 57 | BOS = "" 58 | EOS = "" 59 | PAD = "" 60 | 61 | BUCKETS = [(5,10),(10,15),(15,25),(20,30)] 62 | 63 | def print_result(input_seq): 64 | # greedy decoding: start from BOS and append the most probable word at each step 65 | ans_partial = np.zeros((1, MAX_INPUT_LENGTH)) 66 | ans_partial[0, -1] = 2 # the index of the symbol BOS (begin of sentence) 67 | for k in range(MAX_INPUT_LENGTH - 1): 68 | ye = model.predict([input_seq, ans_partial]) 69 | mp = np.argmax(ye) 70 | ans_partial[0, 0:-1] = ans_partial[0, 1:] 71 | ans_partial[0, -1] = mp 72 | text = '' 73 | for k in ans_partial[0]: 74 | k = k.astype(int) 75 | if k < (DICTIONARY_SIZE-2): 76 | w = vocabulary[k] 77 | text = text + w[0] + ' ' 78 | return(text) 79 | 80 | 81 | # ====================================================================== 82 | # Reading a pre-trained word embedding and adapting it to our vocabulary: 83 | # ====================================================================== 84 | 85 | # Build the word -> vector dictionary 86 | word2vec_index = {} 87 | f =
open(os.path.join(GLOVE_DIR, "glove.6B.100d.txt"), encoding="utf-8") 88 | for line in f: 89 | words = line.split() 90 | word = words[0] 91 | index = np.asarray(words[1:], dtype="float32") 92 | word2vec_index[word] = index 93 | f.close() 94 | 95 | print("Number of word vectors loaded:", len(word2vec_index)) 96 | 97 | word_embedding_matrix = np.zeros((DICTIONARY_SIZE, WORD2VEC_DIMS)) 98 | # Load vocabulary 99 | vocabulary = cPickle.load(open(vocabulary_file, 'rb')) 100 | 101 | i = 0 102 | for word in vocabulary: 103 | word2vec = word2vec_index.get(word[0]) 104 | if word2vec is not None: 105 | word_embedding_matrix[i] = word2vec 106 | i += 1 107 | 108 | 109 | # ====================================================================== 110 | # Keras model of the chatbot: 111 | # ====================================================================== 112 | 113 | ADAM = Adam(lr=0.00005) 114 | 115 | """ 116 | Input Layer #Document*2 117 | """ 118 | input_context = Input(shape=(MAX_INPUT_LENGTH,), dtype="int32", name="input_context") 119 | input_answer = Input(shape=(MAX_INPUT_LENGTH,), dtype="int32", name="input_answer") 120 | 121 | """ 122 | Embedding Layer: converts positive integers (indices) into dense vectors of fixed size. 123 | ・input_dim: positive integer; the vocabulary size, i.e. the maximum input index + 1. 124 | ・output_dim: integer >= 0; the dimension of the dense embedding. 125 | ・input_length: length of the input sequences (constant). Required if this layer is followed by Flatten and then a Dense layer (without it, the shape of the Dense output cannot be computed).
126 | """ 127 | # If saved weights exist they are loaded after the model is built, so the GloVe initialization is skipped 128 | if os.path.isfile(weights_file): 129 | Shared_Embedding = Embedding(input_dim=DICTIONARY_SIZE, output_dim=WORD2VEC_DIMS, input_length=MAX_INPUT_LENGTH,) 130 | else: 131 | Shared_Embedding = Embedding(input_dim=DICTIONARY_SIZE, output_dim=WORD2VEC_DIMS, input_length=MAX_INPUT_LENGTH, 132 | weights=[word_embedding_matrix]) 133 | 134 | """ 135 | Shared Embedding Layer #Doc2Vec(Document*2) 136 | """ 137 | shared_embedding_context = Shared_Embedding(input_context) 138 | shared_embedding_answer = Shared_Embedding(input_answer) 139 | 140 | """ 141 | LSTM Layer # 142 | """ 143 | Encoder_LSTM = LSTM(units=DOC2VEC_DIMS, kernel_initializer="lecun_uniform") 144 | Decoder_LSTM = LSTM(units=DOC2VEC_DIMS, kernel_initializer="lecun_uniform") 145 | embedding_context = Encoder_LSTM(shared_embedding_context) 146 | embedding_answer = Decoder_LSTM(shared_embedding_answer) 147 | 148 | """ 149 | Merge Layer # 150 | """ 151 | merge_layer = merge([embedding_context, embedding_answer], mode='concat', concat_axis=1) 152 | 153 | """ 154 | Dense Layer # 155 | """ 156 | dence_layer = Dense(DICTIONARY_SIZE // 2, activation="relu")(merge_layer) 157 | 158 | """ 159 | Output Layer # 160 | """ 161 | outputs = Dense(DICTIONARY_SIZE, activation="softmax")(dence_layer) 162 | 163 | """ 164 | Modeling 165 | """ 166 | model = Model(inputs=[input_context, input_answer], outputs=[outputs]) 167 | model.compile(loss="categorical_crossentropy", optimizer=ADAM) 168 | 169 | if os.path.isfile(weights_file): 170 | model.load_weights(weights_file) 171 | 172 | 173 | # ====================================================================== 174 | # Loading the data: 175 | # ====================================================================== 176 | 177 | Q = cPickle.load(open(questions_file, 'rb')) 178 | A = cPickle.load(open(answers_file, 'rb')) 179 | N_SAMPLES, N_WORDS = A.shape 180 | 181 | Q_test = Q[0:N_TEST,:] 182 | A_test = A[0:N_TEST,:] 183 | Q = Q[N_TEST:,:] # everything after the test slice is used for training 184 | A = A[N_TEST:,:] 185 | 186 |
print("Number of Samples = %d"%(N_SAMPLES - N_TEST)) 187 | Step = int(np.around((N_SAMPLES - N_TEST) / NUM_SUBSETS)) # range() below needs an int 188 | SAMPLE_ROUNDS = Step * NUM_SUBSETS 189 | 190 | 191 | # ====================================================================== 192 | # Bot training: 193 | # ====================================================================== 194 | 195 | x = range(0, NUM_EPOCHS) 196 | VALID_LOSS = np.zeros(NUM_EPOCHS) 197 | TRAIN_LOSS = np.zeros(NUM_EPOCHS) 198 | 199 | for n_epoch in range(NUM_EPOCHS): 200 | # Loop over training batches due to memory constraints 201 | for n_batch in range(0, SAMPLE_ROUNDS, Step): 202 | 203 | Q2 = Q[n_batch:n_batch+Step] 204 | s = Q2.shape 205 | counter = 0 206 | for sentence in A[n_batch:n_batch+Step]: 207 | l = np.where(sentence==3) # the position of the symbol EOS 208 | limit = l[0][0] 209 | counter += limit # one example per target position 1..limit 210 | 211 | question = np.zeros((counter, MAX_INPUT_LENGTH)) 212 | answer = np.zeros((counter, MAX_INPUT_LENGTH)) 213 | target = np.zeros((counter, DICTIONARY_SIZE)) 214 | 215 | # Loop over the training examples: 216 | counter = 0 217 | for i, sentence in enumerate(A[n_batch:n_batch+Step]): 218 | ans_partial = np.zeros((1, MAX_INPUT_LENGTH)) 219 | 220 | # Loop over the positions of the current target output (the current output sequence) 221 | l = np.where(sentence==3) # the position of the symbol EOS 222 | limit = l[0][0] 223 | 224 | for k in range(1, limit+1): 225 | # Mapping the target output (the next output word) to a one-hot vector: 226 | target_word = np.zeros((1, DICTIONARY_SIZE)) 227 | target_word[0, sentence[k]] = 1 228 | 229 | # preparing the partial answer to input: 230 | ans_partial[0,-k:] = sentence[0:k] 231 | 232 | # store one teacher-forcing example (context, partial answer, next word): 233 | 234 | question[counter, :] = Q2[i:i+1] 235 | answer[counter, :] = ans_partial 236 | target[counter, :] = target_word 237 | counter += 1 238 | 239 | print('Training epoch: %d, Training examples: %d - %d'%(n_epoch, n_batch, n_batch + Step))
240 | model.fit([question, answer], target, batch_size=BATCH_SIZE, epochs=1) 241 | 242 | test_input = Q_test[41:42] 243 | print(print_result(test_input)) 244 | train_input = Q[41:42] 245 | print(print_result(train_input)) 246 | 247 | model.save_weights(weights_file, overwrite=True) 248 | -------------------------------------------------------------------------------- /conversation_spliter.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | text = open('dialog_simple', 'r') 5 | q = open('context', 'w') 6 | a = open('answers', 'w') 7 | pre_pre_previous_raw='' 8 | pre_previous_raw='' 9 | previous_raw='' 10 | person = ' ' 11 | previous_person=' ' 12 | 13 | l1 = ['won’t','won\'t','wouldn’t','wouldn\'t','’m', '’re', '’ve', '’ll', '’s','’d', 'n’t', '\'m', '\'re', '\'ve', '\'ll', '\'s', '\'d', 'can\'t', 'n\'t', 'B: ', 'A: ', ',', ';', '.', '?', '!', ':', '. ?', ', .', '. ,', 'EOS', 'BOS', 'eos', 'bos'] 14 | l2 = ['will not','will not','would not','would not',' am', ' are', ' have', ' will', ' is', ' had', ' not', ' am', ' are', ' have', ' will', ' is', ' had', 'can not', ' not', '', '', ' ,', ' ;', ' .', ' ?', ' !', ' :', '? ', '.', ',', '', '', '', ''] 15 | l3 = ['-', '_', ' *', ' /', '* ', '/ ', '\"', ' \\"', '\\ ', '--', '...', '. .
.'] 16 | 17 | for i, raw_word in enumerate(text): 18 | pos = raw_word.find('+++$+++') 19 | 20 | if pos > -1: 21 | person = raw_word[pos+7:pos+10] 22 | raw_word = raw_word[pos+8:] 23 | while pos > -1: 24 | pos = raw_word.find('+++$+++') 25 | raw_word = raw_word[pos+2:] 26 | 27 | raw_word = raw_word.replace('$+++','') 28 | previous_person = person 29 | 30 | for j, term in enumerate(l1): 31 | raw_word = raw_word.replace(term,l2[j]) 32 | 33 | for term in l3: 34 | raw_word = raw_word.replace(term,' ') 35 | 36 | raw_word = raw_word.lower() 37 | 38 | if i>0: 39 | q.write(pre_previous_raw[:-1] + ' ' + previous_raw[:-1]+ '\n') # in text mode Python translates \n to os.linesep 40 | a.write(raw_word[:-1]+ '\n') 41 | 42 | pre_pre_previous_raw = pre_previous_raw 43 | pre_previous_raw = previous_raw 44 | previous_raw = raw_word 45 | text.close() 46 | q.close() 47 | a.close() 48 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 | np.random.seed(1234) # for reproducibility 5 | import pandas as pd 6 | import os 7 | import csv 8 | import nltk 9 | import itertools 10 | import operator 11 | import pickle 12 | 13 | from keras.preprocessing import sequence 14 | from scipy import sparse, io 15 | from numpy.random import permutation 16 | import re 17 | 18 | questions_file = 'context' 19 | answers_file = 'answers' 20 | vocabulary_file = 'vocabulary_movie' 21 | padded_questions_file = 'Padded_context' 22 | padded_answers_file = 'Padded_answers' 23 | unknown_token = 'something' 24 | 25 | vocabulary_size = 7000 26 | max_features = vocabulary_size 27 | maxlen_input = 50 28 | maxlen_output = 50 # cut texts after this number of words 29 | 30 | print ("Reading the context data...") 31 | q = open(questions_file, 'r') 32 | questions = q.read() 33 | print ("Reading the answer data...") 34 | a = open(answers_file, 'r') 35 | answers = a.read()
36 | all_text = answers + questions # renamed to avoid shadowing the built-in all() 37 | print ("Tokenizing the answers...") 38 | paragraphs_a = [p for p in answers.split('\n')] 39 | paragraphs_b = [p for p in all_text.split('\n')] 40 | paragraphs_a = ['BOS '+p+' EOS' for p in paragraphs_a] 41 | paragraphs_b = ['BOS '+p+' EOS' for p in paragraphs_b] 42 | paragraphs_b = ' '.join(paragraphs_b) 43 | tokenized_text = paragraphs_b.split() 44 | paragraphs_q = [p for p in questions.split('\n') ] 45 | tokenized_answers = [p.split() for p in paragraphs_a] 46 | tokenized_questions = [p.split() for p in paragraphs_q] 47 | 48 | ### Counting the word frequencies: 49 | ##word_freq = nltk.FreqDist(itertools.chain(tokenized_text)) 50 | ##print ("Found %d unique word tokens." % len(word_freq.items())) 51 | ## 52 | ### Getting the most common words and build index_to_word and word_to_index vectors: 53 | ##vocab = word_freq.most_common(vocabulary_size-1) 54 | ## 55 | ### Saving vocabulary (pickle needs a binary file): 56 | ##with open(vocabulary_file, 'wb') as v: 57 | ## pickle.dump(vocab, v) 58 | 59 | vocab = pickle.load(open(vocabulary_file, 'rb')) 60 | 61 | 62 | index_to_word = [x[0] for x in vocab] 63 | index_to_word.append(unknown_token) 64 | word_to_index = dict([(w,i) for i,w in enumerate(index_to_word)]) 65 | 66 | print ("Using vocabulary of size %d." % len(index_to_word)) 67 | print ("The least frequent word in our vocabulary is '%s' and appeared %d times."
% (vocab[-1][0], vocab[-1][1])) 68 | 69 | # Replacing all words not in our vocabulary with the unknown token: 70 | for i, sent in enumerate(tokenized_answers): 71 | tokenized_answers[i] = [w if w in word_to_index else unknown_token for w in sent] 72 | 73 | for i, sent in enumerate(tokenized_questions): 74 | tokenized_questions[i] = [w if w in word_to_index else unknown_token for w in sent] 75 | 76 | # Creating the training data (ragged lists; pad_sequences accepts them directly): 77 | X = [[word_to_index[w] for w in sent] for sent in tokenized_questions] 78 | Y = [[word_to_index[w] for w in sent] for sent in tokenized_answers] 79 | 80 | Q = sequence.pad_sequences(X, maxlen=maxlen_input) 81 | A = sequence.pad_sequences(Y, maxlen=maxlen_output, padding='post') 82 | 83 | with open(padded_questions_file, 'wb') as q: 84 | pickle.dump(Q, q) 85 | 86 | with open(padded_answers_file, 'wb') as a: 87 | pickle.dump(A, a) 88 | -------------------------------------------------------------------------------- /seq2seq_model.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Glove pre-trained word embedding 4 | """ 5 | 6 | from keras.layers import Input, Embedding, LSTM, Dense, RepeatVector, Bidirectional, Dropout, merge 7 | from keras.optimizers import Adam, SGD 8 | from keras.models import Model 9 | from keras.models import Sequential 10 | from keras.layers import Activation, Dense 11 | from keras.callbacks import EarlyStopping 12 | from keras.preprocessing import sequence 13 | 14 | import keras.backend as K 15 | import numpy as np 16 | np.random.seed(1234) # for reproducibility 17 | import pickle as cPickle 18 | import theano.tensor as T 19 | import os 20 | import pandas as pd 21 | import sys 22 | import matplotlib.pyplot as plt 23 | 24 | 25 | """ 26 | Build Your Chatbot!!!
27 | Sentence-level context vectors 28 | Word-level semantic vectors 29 | """ 30 | 31 | # params 32 | WORD2VEC_DIMS = 100 33 | DOC2VEC_DIMS = 300 34 | 35 | DICTIONARY_SIZE = 10000 36 | MAX_INPUT_LENGTH = 30 37 | MAX_OUTPUT_LENGTH = 30 38 | 39 | NUM_HIDDEN_UNITS = 256 40 | BATCH_SIZE = 64 41 | NUM_EPOCHS = 100 42 | 43 | NUM_SUBSETS = 1 44 | 45 | PATIENCE = 0 46 | DROPOUT = .25 47 | N_TEST = 100 48 | 49 | CALL_BACKS = EarlyStopping(monitor='val_loss', patience=PATIENCE) 50 | 51 | # files 52 | vocabulary_file = 'vocabulary_movie' 53 | questions_file = 'Padded_context' 54 | answers_file = 'Padded_answers' 55 | weights_file = 'my_model_weights20.h5' 56 | GLOVE_DIR = './glove.6B/' 57 | 58 | # padding and buckets 59 | 60 | BOS = "" 61 | EOS = "" 62 | PAD = "" 63 | 64 | BUCKETS = [(5,10),(10,15),(15,25),(20,30)] 65 | 66 | def print_result(input_seq): 67 | # greedy decoding: start from BOS and append the most probable word at each step 68 | ans_partial = np.zeros((1, MAX_INPUT_LENGTH)) 69 | ans_partial[0, -1] = 2 # the index of the symbol BOS (begin of sentence) 70 | for k in range(MAX_INPUT_LENGTH - 1): 71 | ye = model.predict([input_seq, ans_partial]) 72 | mp = np.argmax(ye) 73 | ans_partial[0, 0:-1] = ans_partial[0, 1:] 74 | ans_partial[0, -1] = mp 75 | text = '' 76 | for k in ans_partial[0]: 77 | k = k.astype(int) 78 | if k < (DICTIONARY_SIZE-2): 79 | w = vocabulary[k] 80 | text = text + w[0] + ' ' 81 | return(text) 82 | 83 | 84 | # ====================================================================== 85 | # Reading a pre-trained word embedding and adapting it to our vocabulary: 86 | # ====================================================================== 87 | 88 | # Build the word -> vector dictionary 89 | word2vec_index = {} 90 | f = open(os.path.join(GLOVE_DIR, "glove.6B.100d.txt"), encoding="utf-8") 91 | for line in f: 92 | words = line.split() 93 | word = words[0] 94 | index = np.asarray(words[1:], dtype="float32") 95 | word2vec_index[word] = index 96 | f.close() 97 | 98 | print("Number of word vectors loaded:", len(word2vec_index)) 99 | 100 | word_embedding_matrix = np.zeros((DICTIONARY_SIZE, WORD2VEC_DIMS)) 101 | # Load vocabulary 102
| vocabulary = cPickle.load(open(vocabulary_file, 'rb')) 103 | 104 | i = 0 105 | for word in vocabulary: 106 | word2vec = word2vec_index.get(word[0]) 107 | if word2vec is not None: 108 | word_embedding_matrix[i] = word2vec 109 | i += 1 110 | 111 | 112 | # ====================================================================== 113 | # Keras model of the chatbot: 114 | # ====================================================================== 115 | 116 | ADAM = Adam(lr=0.00005) 117 | 118 | """ 119 | Input Layer #Document*2 120 | """ 121 | input_context = Input(shape=(MAX_INPUT_LENGTH,), dtype="int32", name="input_context") 122 | input_answer = Input(shape=(MAX_INPUT_LENGTH,), dtype="int32", name="input_answer") 123 | 124 | """ 125 | Embedding Layer: converts positive integers (indices) into dense vectors of fixed size. 126 | ・input_dim: positive integer; the vocabulary size, i.e. the maximum input index + 1. 127 | ・output_dim: integer >= 0; the dimension of the dense embedding. 128 | ・input_length: length of the input sequences (constant). Required if this layer is followed by Flatten and then a Dense layer (without it, the shape of the Dense output cannot be computed). 129 | """ 130 | # If saved weights exist they are loaded after the model is built, so the GloVe initialization is skipped 131 | if os.path.isfile(weights_file): 132 | Shared_Embedding = Embedding(input_dim=DICTIONARY_SIZE, output_dim=WORD2VEC_DIMS, input_length=MAX_INPUT_LENGTH,) 133 | else: 134 | Shared_Embedding = Embedding(input_dim=DICTIONARY_SIZE, output_dim=WORD2VEC_DIMS, input_length=MAX_INPUT_LENGTH, 135 | weights=[word_embedding_matrix]) 136 | 137 | """ 138 | Shared Embedding Layer #Doc2Vec(Document*2) 139 | """ 140 | shared_embedding_context = Shared_Embedding(input_context) 141 | shared_embedding_answer = Shared_Embedding(input_answer) 142 | 143 | """ 144 | LSTM Layer # 145 | """ 146 | Encoder_LSTM = LSTM(units=DOC2VEC_DIMS, kernel_initializer="lecun_uniform") 147 | Decoder_LSTM = LSTM(units=DOC2VEC_DIMS, kernel_initializer="lecun_uniform") 148 | embedding_context = Encoder_LSTM(shared_embedding_context) 149 | embedding_answer = Decoder_LSTM(shared_embedding_answer) 150 | 151 | """ 152 | Merge Layer # 153 | """ 154 | merge_layer = merge([embedding_context, embedding_answer], mode='concat',
concat_axis=1) 155 | 156 | """ 157 | Dense Layer # 158 | """ 159 | dence_layer = Dense(DICTIONARY_SIZE // 2, activation="relu")(merge_layer) 160 | 161 | """ 162 | Output Layer # 163 | """ 164 | outputs = Dense(DICTIONARY_SIZE, activation="softmax")(dence_layer) 165 | 166 | """ 167 | Modeling 168 | """ 169 | model = Model(inputs=[input_context, input_answer], outputs=[outputs]) 170 | model.compile(loss="categorical_crossentropy", optimizer=ADAM) 171 | 172 | if os.path.isfile(weights_file): 173 | model.load_weights(weights_file) 174 | 175 | 176 | # ====================================================================== 177 | # Loading the data: 178 | # ====================================================================== 179 | 180 | Q = cPickle.load(open(questions_file, 'rb')) 181 | A = cPickle.load(open(answers_file, 'rb')) 182 | N_SAMPLES, N_WORDS = A.shape 183 | 184 | Q_test = Q[0:N_TEST,:] 185 | A_test = A[0:N_TEST,:] 186 | Q = Q[N_TEST:,:] # everything after the test slice is used for training 187 | A = A[N_TEST:,:] 188 | 189 | print("Number of Samples = %d"%(N_SAMPLES - N_TEST)) 190 | Step = int(np.around((N_SAMPLES - N_TEST) / NUM_SUBSETS)) # range() below needs an int 191 | SAMPLE_ROUNDS = Step * NUM_SUBSETS 192 | 193 | 194 | # ====================================================================== 195 | # Bot training: 196 | # ====================================================================== 197 | 198 | x = range(0, NUM_EPOCHS) 199 | VALID_LOSS = np.zeros(NUM_EPOCHS) 200 | TRAIN_LOSS = np.zeros(NUM_EPOCHS) 201 | 202 | for n_epoch in range(NUM_EPOCHS): 203 | # Loop over training batches due to memory constraints 204 | for n_batch in range(0, SAMPLE_ROUNDS, Step): 205 | 206 | Q2 = Q[n_batch:n_batch+Step] 207 | s = Q2.shape 208 | counter = 0 209 | for sentence in A[n_batch:n_batch+Step]: 210 | l = np.where(sentence==3) # the position of the symbol EOS 211 | limit = l[0][0] 212 | counter += limit # one example per target position 1..limit 213 | 214 | question = np.zeros((counter, MAX_INPUT_LENGTH)) 215 | answer = np.zeros((counter, MAX_INPUT_LENGTH)) 216 | target =
np.zeros((counter, DICTIONARY_SIZE)) 217 | 218 | # Loop over the training examples: 219 | counter = 0 220 | for i, sentence in enumerate(A[n_batch:n_batch+Step]): 221 | ans_partial = np.zeros((1, MAX_INPUT_LENGTH)) 222 | 223 | # Loop over the positions of the current target output (the current output sequence) 224 | l = np.where(sentence==3) # the position of the symbol EOS 225 | limit = l[0][0] 226 | 227 | for k in range(1, limit+1): 228 | # Mapping the target output (the next output word) to a one-hot vector: 229 | target_word = np.zeros((1, DICTIONARY_SIZE)) 230 | target_word[0, sentence[k]] = 1 231 | 232 | # preparing the partial answer to input: 233 | ans_partial[0,-k:] = sentence[0:k] 234 | 235 | # store one teacher-forcing example (context, partial answer, next word): 236 | 237 | question[counter, :] = Q2[i:i+1] 238 | answer[counter, :] = ans_partial 239 | target[counter, :] = target_word 240 | counter += 1 241 | 242 | print('Training epoch: %d, Training examples: %d - %d'%(n_epoch, n_batch, n_batch + Step)) 243 | model.fit([question, answer], target, batch_size=BATCH_SIZE, epochs=1) 244 | 245 | test_input = Q_test[41:42] 246 | print(print_result(test_input)) 247 | train_input = Q[41:42] 248 | print(print_result(train_input)) 249 | 250 | model.save_weights(weights_file, overwrite=True) 251 | -------------------------------------------------------------------------------- /training_chatbot.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samurainote/Automatic-Encoder-Decoder_Seq2Seq_Chatbot/d4ef7a0b8a3760507ebbac34354cbc1469708d2e/training_chatbot.py --------------------------------------------------------------------------------
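The notebook cell `embedding_matrix_creater` and the GloVe loop in `chatbot_keras.py` / `seq2seq_model.py` follow the same recipe:

- Parse each GloVe line into a word and a float vector.
- Copy that vector into the matrix row given by the word's vocabulary index.
- Leave rows for words GloVe does not know as zeros.

Below is a minimal, self-contained sketch of that recipe. It makes a few assumptions:

- `word_index` maps words to indices starting at 1. The notebook's `len(word_index) + 1` row count suggests index 0 is reserved.
- GloVe-formatted lines can come from any iterable, for example an open file.
- `load_glove` and `build_embedding_matrix` are hypothetical helper names, not functions from the repository.

```python
import numpy as np

def load_glove(lines):
    """Parse GloVe text lines ("word v1 v2 ...") into a dict of float32 vectors."""
    index = {}
    for line in lines:
        values = line.split()
        if len(values) < 2:
            continue  # skip blank or malformed lines
        index[values[0]] = np.asarray(values[1:], dtype="float32")
    return index

def build_embedding_matrix(word_index, embeddings_index, dim):
    """Row i holds the vector of the word with index i; unknown words stay all-zero."""
    matrix = np.zeros((len(word_index) + 1, dim), dtype="float32")
    for word, i in word_index.items():
        vector = embeddings_index.get(word)
        if vector is not None:
            matrix[i] = vector
    return matrix

# Toy usage: two 3-dimensional "GloVe" lines instead of glove.6B.50d.txt
glove_lines = ["hello 0.1 0.2 0.3\n", "world 0.4 0.5 0.6\n"]
word_index = {"hello": 1, "world": 2, "chatbot": 3}
matrix = build_embedding_matrix(word_index, load_glove(glove_lines), dim=3)
print(matrix.shape)  # (4, 3)
```

The notebook copies such a matrix into its `Embedding` layer with `set_weights`. That call only succeeds if the layer's `input_dim` equals the matrix's row count, `len(word_index) + 1`.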
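`print_result` in `chatbot_keras.py` and `seq2seq_model.py` decodes greedily. The partial answer is a fixed-width window that starts as zeros with the BOS index (2) in the last slot. At each step, the most probable next word is appended on the right and the window shifts left by one.

The sketch below isolates that loop from Keras. It assumes a hypothetical callable `predict_next(context, partial)` that returns a probability vector over the vocabulary, standing in for `model.predict`. Unlike the original loop, it stops as soon as EOS (index 3 in these scripts) is produced.

```python
import numpy as np

BOS, EOS = 2, 3  # indices the training scripts use for the BOS / EOS symbols

def greedy_decode(predict_next, context, max_len):
    """Sliding-window greedy decoding in the style of print_result."""
    partial = np.zeros((1, max_len), dtype="int64")
    partial[0, -1] = BOS
    generated = []
    for _ in range(max_len - 1):
        probs = predict_next(context, partial)
        next_id = int(np.argmax(probs))
        partial[0, :-1] = partial[0, 1:]  # shift the window one step left
        partial[0, -1] = next_id          # append the predicted word on the right
        if next_id == EOS:
            break
        generated.append(next_id)
    return generated

def toy_predict(context, partial):
    """Hypothetical stand-in model: emits 5, 6, 7 after BOS, then EOS."""
    script = {BOS: 5, 5: 6, 6: 7, 7: EOS}
    probs = np.zeros(10)
    probs[script[int(partial[0, -1])]] = 1.0
    return probs

print(greedy_decode(toy_predict, context=None, max_len=8))  # [5, 6, 7]
```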
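The training loop in `chatbot_keras.py` and `seq2seq_model.py` uses teacher forcing: it expands each padded answer into one example per target word. For each position `k` from 1 up to the EOS position:

- The decoder input is the right-aligned prefix `sentence[0:k]`.
- The target is the one-hot vector of `sentence[k]`.

The sketch below shows that expansion for a single answer. It assumes post-padded answers that start with BOS and contain EOS (index 3), as `prepare_data.py` is meant to produce. `expand_answer` is a hypothetical helper name.

```python
import numpy as np

EOS = 3

def expand_answer(sentence, vocab_size):
    """Return (partial_inputs, one_hot_targets) for every position after BOS."""
    limit = int(np.where(sentence == EOS)[0][0])  # position of the first EOS
    width = len(sentence)
    partials = np.zeros((limit, width), dtype="int64")
    targets = np.zeros((limit, vocab_size), dtype="float32")
    for row, k in enumerate(range(1, limit + 1)):
        partials[row, -k:] = sentence[0:k]  # right-aligned prefix, as in the training loop
        targets[row, sentence[k]] = 1.0     # the word the decoder should predict next
    return partials, targets

# BOS=2, two words (7, 8), EOS=3, then post-padding zeros
answer = np.array([2, 7, 8, 3, 0, 0])
partials, targets = expand_answer(answer, vocab_size=10)
print(targets.argmax(axis=1))  # [7 8 3]
```

Each answer contributes exactly `limit` rows. A batch of answers therefore needs the sum of their `limit` values as pre-allocated rows.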
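`prepare_data.py` prepares its sequences in three steps:

1. It replaces out-of-vocabulary words with the unknown token `something`.
2. It converts words to indices.
3. It pads with Keras's `sequence.pad_sequences`. Questions use the default `padding='pre'` (zeros on the left); answers use `padding='post'`.

The sketch below reproduces that behavior in plain Python so it can be inspected without Keras. `to_indices` and `pad` are hypothetical helpers; `pad` mimics the default `truncating='pre'`, which keeps the last `maxlen` items. The toy vocabulary is invented, with BOS and EOS at indices 2 and 3 as the training scripts assume.

```python
def to_indices(tokens, word_to_index, unknown_token):
    """Replace out-of-vocabulary words with the unknown token, then map to indices."""
    return [word_to_index[w if w in word_to_index else unknown_token] for w in tokens]

def pad(seq, maxlen, padding="pre", value=0):
    """Mimic pad_sequences for one sequence: keep the last maxlen items, then pad."""
    seq = list(seq)[-maxlen:]  # truncating='pre' drops items from the front
    fill = [value] * (maxlen - len(seq))
    return fill + seq if padding == "pre" else seq + fill

# Toy vocabulary; the last entry plays the role of the unknown token
vocab = ["you", "i", "BOS", "EOS", "hello", "there", "something"]
word_to_index = {w: i for i, w in enumerate(vocab)}

question = to_indices("hello stranger".split(), word_to_index, "something")
answer = to_indices("BOS hello there EOS".split(), word_to_index, "something")
print(pad(question, 5))                # [0, 0, 0, 4, 6]
print(pad(answer, 6, padding="post"))  # [2, 4, 5, 3, 0, 0]
```

In `prepare_data.py`, index 0 belongs to the most frequent vocabulary word, and padding also uses 0. That word and padding are therefore indistinguishable to the model.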
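`conversation_spliter.py` normalizes each dialogue line in three passes:

1. It applies the `l1` → `l2` replacements in order. These expand contractions such as `won't` → `will not`, put spaces before punctuation, and drop `BOS`/`EOS` markers.
2. It replaces the `l3` symbols with spaces.
3. It lower-cases the line.

Order matters: `won't` and `can't` must be expanded before the generic `n't` rule.

The sketch below uses a hypothetical, much smaller rule set. The original also handles typographic apostrophes (’). Unlike the original, the sketch collapses repeated spaces.

```python
# Hypothetical subset of the l1 -> l2 rules; "won't" and "can't" must come
# before the generic "n't" rule, otherwise they would become "wo not" / "ca not".
RULES = [("won't", "will not"), ("can't", "can not"), ("n't", " not"),
         ("'m", " am"), ("'re", " are"), (",", " ,"), ("?", " ?"), ("!", " !")]
SPACE_SYMBOLS = ["--", "...", "-"]  # subset of l3: each is replaced by a space

def normalize(line):
    for old, new in RULES:
        line = line.replace(old, new)
    for symbol in SPACE_SYMBOLS:
        line = line.replace(symbol, " ")
    return " ".join(line.lower().split())  # lower-case and collapse runs of spaces

print(normalize("I won't go, I'm tired!"))  # i will not go , i am tired !
```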