├── dic.pkl ├── inv.pkl ├── example_chat.JPG ├── README.md ├── AttentionLayer.py ├── chatbot.py ├── LICENSE └── seq2seq-chatbot-keras-with-attention.ipynb /dic.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pawandeep-prog/keras-seq2seq-chatbot-with-attention/HEAD/dic.pkl -------------------------------------------------------------------------------- /inv.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pawandeep-prog/keras-seq2seq-chatbot-with-attention/HEAD/inv.pkl -------------------------------------------------------------------------------- /example_chat.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pawandeep-prog/keras-seq2seq-chatbot-with-attention/HEAD/example_chat.JPG -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
# keras-seq2seq-chatbot-with-attention
A seq2seq encoder-decoder chatbot built with Keras and Bahdanau attention
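Below is a minimal sketch of how the pieces fit together, condensed from the model cell in the notebook (the VOCAB_SIZE value and the 13-token sequence length come from the notebook's preprocessing; dropout and the GloVe initialisation are omitted here for brevity):

```python
from tensorflow.keras.layers import Input, Embedding, LSTM, Bidirectional, Concatenate, Dense
from tensorflow.keras.models import Model
from AttentionLayer import AttentionLayer

VOCAB_SIZE = 3027   # vocabulary size produced by the notebook's preprocessing
MAX_LEN = 13        # fixed question/answer length

enc_inp = Input(shape=(MAX_LEN,))
dec_inp = Input(shape=(MAX_LEN,))
embed = Embedding(VOCAB_SIZE + 1, 50, input_length=MAX_LEN, trainable=True)  # shared by encoder and decoder

# bidirectional LSTM encoder; forward and backward states are concatenated
enc_out, fh, fc, bh, bc = Bidirectional(LSTM(400, return_sequences=True, return_state=True))(embed(enc_inp))
enc_states = [Concatenate()([fh, bh]), Concatenate()([fc, bc])]

# LSTM decoder initialised with the encoder states
dec_out, _, _ = LSTM(800, return_sequences=True, return_state=True)(embed(dec_inp), initial_state=enc_states)

# Bahdanau attention over the encoder outputs, concatenated onto the decoder outputs
attn_op, _ = AttentionLayer()([enc_out, dec_out])
final = Dense(VOCAB_SIZE, activation='softmax')(Concatenate(axis=-1)([dec_out, attn_op]))

model = Model([enc_inp, dec_inp], final)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
```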

## Files

1. chatbot.py :- file to run the chatbot locally using the saved model
2. seq2seq-chatbot-keras-with-attention.ipynb :- the all-in-one notebook; you just need the datasets listed below to run it, hopefully with no errors. It also saves the model in h5 format

Thanks to thushv89 (https://github.com/thushv89/attention_keras) for the attention layer; all credit for the AttentionLayer class goes to him.
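The two pickle files, dic.pkl and inv.pkl, are the word-to-index and index-to-word dictionaries built during preprocessing; chatbot.py loads them like this:

```python
import pickle

# dic.pkl maps words -> token ids, inv.pkl maps token ids back to words
with open('dic.pkl', 'rb') as f:
    vocab = pickle.load(f)
with open('inv.pkl', 'rb') as f:
    inv_vocab = pickle.load(f)
```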

## How to RUN

1. Run on Kaggle : https://www.kaggle.com/programminghut/seq2seq-chatbot-keras-with-attention
2. Use chatbot.py only after you have run the ipynb file, because the model is saved by that notebook; a minimal loading sketch follows below
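A minimal loading sketch for step 2, mirroring the top of chatbot.py (it assumes chatbot.h5, dic.pkl, inv.pkl and AttentionLayer.py all sit in the working directory):

```python
from tensorflow.keras.models import load_model
from AttentionLayer import AttentionLayer

# the model was saved with a custom layer, so it must be registered on load
model = load_model('chatbot.h5', custom_objects={'AttentionLayer': AttentionLayer()})
model.summary()
```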

## Datasets used

1. glove6b 50d : https://www.kaggle.com/watts2/glove6b50dtxt
2. cornell movie : https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html

chatbot.py is the python file to run locally using the saved model. A short sketch of how the notebook reads these two datasets follows below.
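For reference, this is how the notebook reads the two datasets (the ../input/... paths are the Kaggle dataset mounts used in the ipynb; adjust them if you run elsewhere):

```python
import numpy as np

# Cornell movie-dialogs corpus: raw utterances plus conversation structure
lines = open('../input/chatbot-data/cornell movie-dialogs corpus/movie_lines.txt',
             encoding='utf-8', errors='ignore').read().split('\n')
convers = open('../input/chatbot-data/cornell movie-dialogs corpus/movie_conversations.txt',
               encoding='utf-8', errors='ignore').read().split('\n')

# GloVe 6B 50d: one "word v1 ... v50" entry per line
embeddings_index = {}
with open('../input/glove6b50d/glove.6B.50d.txt', encoding='utf-8') as f:
    for line in f:
        values = line.split()
        embeddings_index[values[0]] = np.asarray(values[1:], dtype='float32')
```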

## Sample chat

![example chat](example_chat.JPG)

If you still have queries, you may contact me here:
facebook : https://m.facebook.com/proogramminghub
twitter : https://twitter.com/programming_hut
github : https://github.com/Pawandeep-prog
discord : https://discord.gg/G5Cunyg
linkedin : https://www.linkedin.com/in/programminghut
youtube : https://www.youtube.com/c/programminghutofficial
38 | 39 | -------------------------------------------------------------------------------- /AttentionLayer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | from tensorflow.python.keras.layers import Layer 4 | from tensorflow.python.keras import backend as K 5 | 6 | 7 | class AttentionLayer(Layer): 8 | """ 9 | This class implements Bahdanau attention (https://arxiv.org/pdf/1409.0473.pdf). 10 | There are three sets of weights introduced W_a, U_a, and V_a 11 | """ 12 | 13 | def __init__(self, **kwargs): 14 | super(AttentionLayer, self).__init__(**kwargs) 15 | 16 | def build(self, input_shape): 17 | assert isinstance(input_shape, list) 18 | # Create a trainable weight variable for this layer. 19 | 20 | self.W_a = self.add_weight(name='W_a', 21 | shape=tf.TensorShape((input_shape[0][2], input_shape[0][2])), 22 | initializer='uniform', 23 | trainable=True) 24 | self.U_a = self.add_weight(name='U_a', 25 | shape=tf.TensorShape((input_shape[1][2], input_shape[0][2])), 26 | initializer='uniform', 27 | trainable=True) 28 | self.V_a = self.add_weight(name='V_a', 29 | shape=tf.TensorShape((input_shape[0][2], 1)), 30 | initializer='uniform', 31 | trainable=True) 32 | 33 | super(AttentionLayer, self).build(input_shape) # Be sure to call this at the end 34 | 35 | def call(self, inputs, verbose=False): 36 | """ 37 | inputs: [encoder_output_sequence, decoder_output_sequence] 38 | """ 39 | assert type(inputs) == list 40 | encoder_out_seq, decoder_out_seq = inputs 41 | if verbose: 42 | print('encoder_out_seq>', encoder_out_seq.shape) 43 | print('decoder_out_seq>', decoder_out_seq.shape) 44 | 45 | def energy_step(inputs, states): 46 | """ Step function for computing energy for a single decoder state 47 | inputs: (batchsize * 1 * de_in_dim) 48 | states: (batchsize * 1 * de_latent_dim) 49 | """ 50 | 51 | assert_msg = "States must be an iterable. Got {} of type {}".format(states, type(states)) 52 | assert isinstance(states, list) or isinstance(states, tuple), assert_msg 53 | 54 | """ Some parameters required for shaping tensors""" 55 | en_seq_len, en_hidden = encoder_out_seq.shape[1], encoder_out_seq.shape[2] 56 | de_hidden = inputs.shape[-1] 57 | 58 | """ Computing S.Wa where S=[s0, s1, ..., si]""" 59 | # <= batch size * en_seq_len * latent_dim 60 | W_a_dot_s = K.dot(encoder_out_seq, self.W_a) 61 | 62 | """ Computing hj.Ua """ 63 | U_a_dot_h = K.expand_dims(K.dot(inputs, self.U_a), 1) # <= batch_size, 1, latent_dim 64 | if verbose: 65 | print('Ua.h>', U_a_dot_h.shape) 66 | 67 | """ tanh(S.Wa + hj.Ua) """ 68 | # <= batch_size*en_seq_len, latent_dim 69 | Ws_plus_Uh = K.tanh(W_a_dot_s + U_a_dot_h) 70 | if verbose: 71 | print('Ws+Uh>', Ws_plus_Uh.shape) 72 | 73 | """ softmax(va.tanh(S.Wa + hj.Ua)) """ 74 | # <= batch_size, en_seq_len 75 | e_i = K.squeeze(K.dot(Ws_plus_Uh, self.V_a), axis=-1) 76 | # <= batch_size, en_seq_len 77 | e_i = K.softmax(e_i) 78 | 79 | if verbose: 80 | print('ei>', e_i.shape) 81 | 82 | return e_i, [e_i] 83 | 84 | def context_step(inputs, states): 85 | """ Step function for computing ci using ei """ 86 | 87 | assert_msg = "States must be an iterable. 
Got {} of type {}".format(states, type(states)) 88 | assert isinstance(states, list) or isinstance(states, tuple), assert_msg 89 | 90 | # <= batch_size, hidden_size 91 | c_i = K.sum(encoder_out_seq * K.expand_dims(inputs, -1), axis=1) 92 | if verbose: 93 | print('ci>', c_i.shape) 94 | return c_i, [c_i] 95 | 96 | fake_state_c = K.sum(encoder_out_seq, axis=1) 97 | fake_state_e = K.sum(encoder_out_seq, axis=2) # <= (batch_size, enc_seq_len, latent_dim 98 | 99 | """ Computing energy outputs """ 100 | # e_outputs => (batch_size, de_seq_len, en_seq_len) 101 | last_out, e_outputs, _ = K.rnn( 102 | energy_step, decoder_out_seq, [fake_state_e], 103 | ) 104 | 105 | """ Computing context vectors """ 106 | last_out, c_outputs, _ = K.rnn( 107 | context_step, e_outputs, [fake_state_c], 108 | ) 109 | 110 | return c_outputs, e_outputs 111 | 112 | def compute_output_shape(self, input_shape): 113 | """ Outputs produced by the layer """ 114 | return [ 115 | tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[1][2])), 116 | tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[0][1])) 117 | ] -------------------------------------------------------------------------------- /chatbot.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.models import load_model, Model 2 | from tensorflow.keras.layers import Input, Concatenate 3 | import tensorflow as tf 4 | import os 5 | from tensorflow.python.keras.layers import Layer 6 | from tensorflow.python.keras import backend as K 7 | import pickle 8 | import numpy as np 9 | import re 10 | from AttentionLayer import AttentionLayer 11 | 12 | with open('dic.pkl', 'rb') as f: 13 | vocab = pickle.load(f) 14 | with open('inv.pkl', 'rb') as f: 15 | inv_vocab = pickle.load(f) 16 | 17 | 18 | def clean_text(txt): 19 | txt = txt.lower() 20 | txt = re.sub(r"i'm", "i am", txt) 21 | txt = re.sub(r"he's", "he is", txt) 22 | txt = re.sub(r"she's", "she is", txt) 23 | txt = re.sub(r"that's", "that is", txt) 24 | txt = re.sub(r"what's", "what is", txt) 25 | txt = re.sub(r"where's", "where is", txt) 26 | txt = re.sub(r"\'ll", " will", txt) 27 | txt = re.sub(r"\'ve", " have", txt) 28 | txt = re.sub(r"\'re", " are", txt) 29 | txt = re.sub(r"\'d", " would", txt) 30 | txt = re.sub(r"won't", "will not", txt) 31 | txt = re.sub(r"can't", "can not", txt) 32 | txt = re.sub(r"[^\w\s]", "", txt) 33 | return txt 34 | 35 | 36 | 37 | attn_layer = AttentionLayer() 38 | 39 | model = load_model('chatbot.h5', custom_objects={'AttentionLayer' : attn_layer}) 40 | 41 | 42 | 43 | encoder_inputs = model.layers[0].input 44 | embed = model.layers[2] 45 | enc_embed = embed(encoder_inputs) 46 | enocoder_layer = model.layers[3] 47 | 48 | encoder_outputs, fstate_h, fstate_c, bstate_h, bstate_c = enocoder_layer(enc_embed) 49 | 50 | h = Concatenate()([fstate_h, bstate_h]) 51 | c = Concatenate()([fstate_c, bstate_c]) 52 | encoder_states = [h, c] 53 | 54 | enc_model = Model(encoder_inputs, 55 | [encoder_outputs, 56 | encoder_states]) 57 | 58 | 59 | latent_dim = 800 60 | 61 | decoder_inputs = model.layers[1].input 62 | decoder_lstm = model.layers[6] 63 | decoder_dense = model.layers[9] 64 | decoder_state_input_h = Input(shape=(latent_dim,), name='input_3') 65 | decoder_state_input_c = Input(shape=(latent_dim,), name='input_4') 66 | 67 | decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c] 68 | 69 | dec_embed = embed(decoder_inputs) 70 | 71 | decoder_outputs, state_h, state_c = decoder_lstm(dec_embed, 
initial_state=decoder_states_inputs) 72 | decoder_states = [state_h, state_c] 73 | 74 | dec_model = Model([decoder_inputs, decoder_states_inputs], [decoder_outputs] + decoder_states) 75 | 76 | dec_dense = model.layers[-1] 77 | attn_layer = model.layers[7] 78 | 79 | from tensorflow.keras.preprocessing.sequence import pad_sequences 80 | print("##########################################") 81 | print("# start chatting ver. 1.0 #") 82 | print("##########################################") 83 | 84 | 85 | prepro1 = "" 86 | while prepro1 != 'q': 87 | 88 | prepro1 = input("you : ") 89 | try: 90 | prepro1 = clean_text(prepro1) 91 | prepro = [prepro1] 92 | 93 | txt = [] 94 | for x in prepro: 95 | lst = [] 96 | for y in x.split(): 97 | try: 98 | lst.append(vocab[y]) 99 | except: 100 | lst.append(vocab['<OUT>']) 101 | txt.append(lst) 102 | txt = pad_sequences(txt, 13, padding='post') 103 | 104 | 105 | ### 106 | enc_op, stat = enc_model.predict( txt ) 107 | 108 | empty_target_seq = np.zeros( ( 1 , 1) ) 109 | empty_target_seq[0, 0] = vocab['<SOS>'] 110 | stop_condition = False 111 | decoded_translation = '' 112 | 113 | 114 | while not stop_condition : 115 | 116 | dec_outputs , h , c = dec_model.predict([ empty_target_seq ] + stat ) 117 | 118 | ### 119 | ########################### 120 | attn_op, attn_state = attn_layer([enc_op, dec_outputs]) 121 | decoder_concat_input = Concatenate(axis=-1)([dec_outputs, attn_op]) 122 | decoder_concat_input = dec_dense(decoder_concat_input) 123 | ########################### 124 | 125 | sampled_word_index = np.argmax( decoder_concat_input[0, -1, :] ) 126 | 127 | sampled_word = inv_vocab[sampled_word_index] + ' ' 128 | 129 | if sampled_word != '<EOS> ': 130 | decoded_translation += sampled_word 131 | 132 | 133 | if sampled_word == '<EOS> ' or len(decoded_translation.split()) > 13: 134 | stop_condition = True 135 | 136 | empty_target_seq = np.zeros( ( 1 , 1 ) ) 137 | empty_target_seq[ 0 , 0 ] = sampled_word_index 138 | stat = [ h , c ] 139 | 140 | print("chatbot attention : ", decoded_translation ) 141 | print("==============================================") 142 | 143 | except: 144 | print("sorry, didn't get you, please type again :( ") 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License.
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /seq2seq-chatbot-keras-with-attention.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", 8 | "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5" 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import numpy as np \n", 13 | "import pandas as pd \n", 14 | "import os" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "source": [ 22 | "# Attention Class" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import tensorflow as tf\n", 32 | "import os\n", 33 | "from tensorflow.python.keras.layers import Layer\n", 34 | "from tensorflow.python.keras import backend as K\n", 35 | "\n", 36 | "\n", 37 | "class AttentionLayer(Layer):\n", 38 | " \"\"\"\n", 39 | " This class implements Bahdanau attention (https://arxiv.org/pdf/1409.0473.pdf).\n", 40 | " There are three sets of weights introduced W_a, U_a, and V_a\n", 41 | " \"\"\"\n", 42 | "\n", 43 | " def __init__(self, **kwargs):\n", 44 | " super(AttentionLayer, self).__init__(**kwargs)\n", 45 | "\n", 46 | " def build(self, input_shape):\n", 47 | " assert isinstance(input_shape, list)\n", 48 | " # Create a trainable weight variable for this layer.\n", 49 | "\n", 50 | " self.W_a = self.add_weight(name='W_a',\n", 51 | " shape=tf.TensorShape((input_shape[0][2], input_shape[0][2])),\n", 52 | " initializer='uniform',\n", 53 | " trainable=True)\n", 54 | " self.U_a = self.add_weight(name='U_a',\n", 55 | " shape=tf.TensorShape((input_shape[1][2], input_shape[0][2])),\n", 56 | " initializer='uniform',\n", 57 | " trainable=True)\n", 58 | " self.V_a = self.add_weight(name='V_a',\n", 59 | " shape=tf.TensorShape((input_shape[0][2], 1)),\n", 60 | " initializer='uniform',\n", 61 | " trainable=True)\n", 62 | "\n", 63 | " super(AttentionLayer, self).build(input_shape) # Be sure to call this at the end\n", 64 | "\n", 65 | " def call(self, inputs, verbose=False):\n", 66 | " \"\"\"\n", 67 | " inputs: [encoder_output_sequence, decoder_output_sequence]\n", 68 | " \"\"\"\n", 69 | " assert type(inputs) == list\n", 70 | " encoder_out_seq, decoder_out_seq = inputs\n", 71 | " if verbose:\n", 72 | " print('encoder_out_seq>', encoder_out_seq.shape)\n", 73 | " print('decoder_out_seq>', decoder_out_seq.shape)\n", 74 | "\n", 75 | " def energy_step(inputs, states):\n", 76 | " \"\"\" Step function for computing energy for a single decoder state\n", 77 | " inputs: (batchsize * 1 * de_in_dim)\n", 78 | " states: (batchsize * 1 * de_latent_dim)\n", 79 | " \"\"\"\n", 80 | "\n", 81 | " assert_msg = \"States must be an iterable. 
Got {} of type {}\".format(states, type(states))\n", 82 | " assert isinstance(states, list) or isinstance(states, tuple), assert_msg\n", 83 | "\n", 84 | " \"\"\" Some parameters required for shaping tensors\"\"\"\n", 85 | " en_seq_len, en_hidden = encoder_out_seq.shape[1], encoder_out_seq.shape[2]\n", 86 | " de_hidden = inputs.shape[-1]\n", 87 | "\n", 88 | " \"\"\" Computing S.Wa where S=[s0, s1, ..., si]\"\"\"\n", 89 | " # <= batch size * en_seq_len * latent_dim\n", 90 | " W_a_dot_s = K.dot(encoder_out_seq, self.W_a)\n", 91 | "\n", 92 | " \"\"\" Computing hj.Ua \"\"\"\n", 93 | " U_a_dot_h = K.expand_dims(K.dot(inputs, self.U_a), 1) # <= batch_size, 1, latent_dim\n", 94 | " if verbose:\n", 95 | " print('Ua.h>', U_a_dot_h.shape)\n", 96 | "\n", 97 | " \"\"\" tanh(S.Wa + hj.Ua) \"\"\"\n", 98 | " # <= batch_size*en_seq_len, latent_dim\n", 99 | " Ws_plus_Uh = K.tanh(W_a_dot_s + U_a_dot_h)\n", 100 | " if verbose:\n", 101 | " print('Ws+Uh>', Ws_plus_Uh.shape)\n", 102 | "\n", 103 | " \"\"\" softmax(va.tanh(S.Wa + hj.Ua)) \"\"\"\n", 104 | " # <= batch_size, en_seq_len\n", 105 | " e_i = K.squeeze(K.dot(Ws_plus_Uh, self.V_a), axis=-1)\n", 106 | " # <= batch_size, en_seq_len\n", 107 | " e_i = K.softmax(e_i)\n", 108 | "\n", 109 | " if verbose:\n", 110 | " print('ei>', e_i.shape)\n", 111 | "\n", 112 | " return e_i, [e_i]\n", 113 | "\n", 114 | " def context_step(inputs, states):\n", 115 | " \"\"\" Step function for computing ci using ei \"\"\"\n", 116 | "\n", 117 | " assert_msg = \"States must be an iterable. Got {} of type {}\".format(states, type(states))\n", 118 | " assert isinstance(states, list) or isinstance(states, tuple), assert_msg\n", 119 | "\n", 120 | " # <= batch_size, hidden_size\n", 121 | " c_i = K.sum(encoder_out_seq * K.expand_dims(inputs, -1), axis=1)\n", 122 | " if verbose:\n", 123 | " print('ci>', c_i.shape)\n", 124 | " return c_i, [c_i]\n", 125 | "\n", 126 | " fake_state_c = K.sum(encoder_out_seq, axis=1)\n", 127 | " fake_state_e = K.sum(encoder_out_seq, axis=2) # <= (batch_size, enc_seq_len, latent_dim\n", 128 | "\n", 129 | " \"\"\" Computing energy outputs \"\"\"\n", 130 | " # e_outputs => (batch_size, de_seq_len, en_seq_len)\n", 131 | " last_out, e_outputs, _ = K.rnn(\n", 132 | " energy_step, decoder_out_seq, [fake_state_e],\n", 133 | " )\n", 134 | "\n", 135 | " \"\"\" Computing context vectors \"\"\"\n", 136 | " last_out, c_outputs, _ = K.rnn(\n", 137 | " context_step, e_outputs, [fake_state_c],\n", 138 | " )\n", 139 | "\n", 140 | " return c_outputs, e_outputs\n", 141 | "\n", 142 | " def compute_output_shape(self, input_shape):\n", 143 | " \"\"\" Outputs produced by the layer \"\"\"\n", 144 | " return [\n", 145 | " tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[1][2])),\n", 146 | " tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[0][1]))\n", 147 | " ]" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 3, 153 | "metadata": { 154 | "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0", 155 | "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a" 156 | }, 157 | "outputs": [], 158 | "source": [ 159 | "import re\n", 160 | "\n", 161 | "lines = open('../input/chatbot-data/cornell movie-dialogs corpus/movie_lines.txt', encoding='utf-8',\n", 162 | " errors='ignore').read().split('\\n')\n", 163 | "\n", 164 | "convers = open('../input/chatbot-data/cornell movie-dialogs corpus/movie_conversations.txt', encoding='utf-8',\n", 165 | " errors='ignore').read().split('\\n')\n" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | 
"execution_count": 4, 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "data": { 175 | "text/plain": [ 176 | "304714" 177 | ] 178 | }, 179 | "execution_count": 4, 180 | "metadata": {}, 181 | "output_type": "execute_result" 182 | } 183 | ], 184 | "source": [ 185 | "len(lines)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "source": [ 193 | "# Data Preprocess" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 5, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "\n", 203 | "exchn = []\n", 204 | "for conver in convers:\n", 205 | " exchn.append(conver.split(' +++$+++ ')[-1][1:-1].replace(\"'\", \" \").replace(\",\",\"\").split())\n", 206 | "\n", 207 | "diag = {}\n", 208 | "for line in lines:\n", 209 | " diag[line.split(' +++$+++ ')[0]] = line.split(' +++$+++ ')[-1]\n", 210 | "\n", 211 | "\n", 212 | "\n", 213 | "## delete\n", 214 | "del(lines, convers, conver, line)\n", 215 | "\n", 216 | "\n", 217 | "\n", 218 | "questions = []\n", 219 | "answers = []\n", 220 | "\n", 221 | "for conver in exchn:\n", 222 | " for i in range(len(conver) - 1):\n", 223 | " questions.append(diag[conver[i]])\n", 224 | " answers.append(diag[conver[i+1]])\n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | "\n", 229 | "## delete\n", 230 | "del(diag, exchn, conver, i)\n", 231 | "\n", 232 | "\n", 233 | "###############################\n", 234 | "# max_len = 13 #\n", 235 | "###############################\n", 236 | "\n", 237 | "sorted_ques = []\n", 238 | "sorted_ans = []\n", 239 | "for i in range(len(questions)):\n", 240 | " if len(questions[i]) < 13:\n", 241 | " sorted_ques.append(questions[i])\n", 242 | " sorted_ans.append(answers[i])\n", 243 | "\n", 244 | "\n", 245 | "\n", 246 | "###############################\n", 247 | "# #\n", 248 | "###############################\n", 249 | "\n", 250 | "\n", 251 | "\n", 252 | "\n", 253 | "def clean_text(txt):\n", 254 | " txt = txt.lower()\n", 255 | " txt = re.sub(r\"i'm\", \"i am\", txt)\n", 256 | " txt = re.sub(r\"he's\", \"he is\", txt)\n", 257 | " txt = re.sub(r\"she's\", \"she is\", txt)\n", 258 | " txt = re.sub(r\"that's\", \"that is\", txt)\n", 259 | " txt = re.sub(r\"what's\", \"what is\", txt)\n", 260 | " txt = re.sub(r\"where's\", \"where is\", txt)\n", 261 | " txt = re.sub(r\"\\'ll\", \" will\", txt)\n", 262 | " txt = re.sub(r\"\\'ve\", \" have\", txt)\n", 263 | " txt = re.sub(r\"\\'re\", \" are\", txt)\n", 264 | " txt = re.sub(r\"\\'d\", \" would\", txt)\n", 265 | " txt = re.sub(r\"won't\", \"will not\", txt)\n", 266 | " txt = re.sub(r\"can't\", \"can not\", txt)\n", 267 | " txt = re.sub(r\"[^\\w\\s]\", \"\", txt)\n", 268 | " return txt\n", 269 | "\n", 270 | "clean_ques = []\n", 271 | "clean_ans = []\n", 272 | "\n", 273 | "for line in sorted_ques:\n", 274 | " clean_ques.append(clean_text(line))\n", 275 | " \n", 276 | "for line in sorted_ans:\n", 277 | " clean_ans.append(clean_text(line))\n", 278 | "\n", 279 | "\n", 280 | "\n", 281 | "## delete\n", 282 | "del(answers, questions, line)\n", 283 | "\n", 284 | "\n", 285 | "###############################\n", 286 | "# #\n", 287 | "###############################\n", 288 | "\n", 289 | "\n", 290 | "for i in range(len(clean_ans)):\n", 291 | " clean_ans[i] = ' '.join(clean_ans[i].split()[:11])\n", 292 | "\n", 293 | "\n", 294 | "\n", 295 | "###############################\n", 296 | "# #\n", 297 | "###############################\n", 298 | "\n", 299 | "del(sorted_ans, sorted_ques)\n", 300 | "\n", 301 | "\n", 302 | "## 
trimming\n", 303 | "clean_ans=clean_ans[:30000]\n", 304 | "clean_ques=clean_ques[:30000]\n", 305 | "## delete\n", 306 | "\n", 307 | "\n", 308 | "### count occurrences ###\n", 309 | "word2count = {}\n", 310 | "\n", 311 | "for line in clean_ques:\n", 312 | " for word in line.split():\n", 313 | " if word not in word2count:\n", 314 | " word2count[word] = 1\n", 315 | " else:\n", 316 | " word2count[word] += 1\n", 317 | "for line in clean_ans:\n", 318 | " for word in line.split():\n", 319 | " if word not in word2count:\n", 320 | " word2count[word] = 1\n", 321 | " else:\n", 322 | " word2count[word] += 1\n", 323 | "\n", 324 | "## delete\n", 325 | "del(word, line)\n", 326 | "\n", 327 | "\n", 328 | "### remove less frequent ###\n", 329 | "thresh = 5\n", 330 | "\n", 331 | "vocab = {}\n", 332 | "word_num = 0\n", 333 | "for word, count in word2count.items():\n", 334 | " if count >= thresh:\n", 335 | " vocab[word] = word_num\n", 336 | " word_num += 1\n", 337 | " \n", 338 | "## delete\n", 339 | "del(word2count, word, count, thresh) \n", 340 | "del(word_num) \n", 341 | "\n", 342 | "\n", 343 | "\n", 344 | "for i in range(len(clean_ans)):\n", 345 | " clean_ans[i] = '<SOS> ' + clean_ans[i] + ' <EOS>'\n", 346 | "\n", 347 | "\n", 348 | "\n", 349 | "tokens = ['<PAD>', '<EOS>', '<OUT>', '<SOS>']\n", 350 | "x = len(vocab)\n", 351 | "for token in tokens:\n", 352 | " vocab[token] = x\n", 353 | " x += 1\n", 354 | " \n", 355 | " \n", 356 | "\n", 357 | "vocab['cameron'] = vocab['<PAD>']\n", 358 | "vocab['<PAD>'] = 0\n", 359 | "\n", 360 | "## delete\n", 361 | "del(token, tokens) \n", 362 | "del(x)\n", 363 | "\n", 364 | "### inv answers dict ###\n", 365 | "inv_vocab = {w:v for v, w in vocab.items()}\n", 366 | "\n", 367 | "\n", 368 | "\n", 369 | "## delete\n", 370 | "del(i)\n", 371 | "\n", 372 | "\n", 373 | "\n", 374 | "encoder_inp = []\n", 375 | "for line in clean_ques:\n", 376 | " lst = []\n", 377 | " for word in line.split():\n", 378 | " if word not in vocab:\n", 379 | " lst.append(vocab['<OUT>'])\n", 380 | " else:\n", 381 | " lst.append(vocab[word])\n", 382 | " \n", 383 | " encoder_inp.append(lst)\n", 384 | "\n", 385 | "decoder_inp = []\n", 386 | "for line in clean_ans:\n", 387 | " lst = []\n", 388 | " for word in line.split():\n", 389 | " if word not in vocab:\n", 390 | " lst.append(vocab['<OUT>'])\n", 391 | " else:\n", 392 | " lst.append(vocab[word]) \n", 393 | " decoder_inp.append(lst)\n", 394 | "\n", 395 | "### delete\n", 396 | "del(clean_ans, clean_ques, line, lst, word)\n", 397 | "\n", 398 | "\n", 399 | "\n", 400 | "\n", 401 | "\n", 402 | "\n", 403 | "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", 404 | "encoder_inp = pad_sequences(encoder_inp, 13, padding='post', truncating='post')\n", 405 | "decoder_inp = pad_sequences(decoder_inp, 13, padding='post', truncating='post')\n", 406 | "\n", 407 | "\n", 408 | "\n", 409 | "\n", 410 | "decoder_final_output = []\n", 411 | "for i in decoder_inp:\n", 412 | " decoder_final_output.append(i[1:]) \n", 413 | "\n", 414 | "decoder_final_output = pad_sequences(decoder_final_output, 13, padding='post', truncating='post')\n", 415 | "\n", 416 | "\n", 417 | "del(i)\n", 418 | "\n" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": 6, 424 | "metadata": {}, 425 | "outputs": [ 426 | { 427 | "name": "stdout", 428 | "output_type": "stream", 429 | "text": [ 430 | "(30000, 13) (30000, 13) (30000, 13) 3027 3027 <PAD>\n" 431 | ] 432 | } 433 | ], 434 | "source": [ 435 | "# decoder_final_output, decoder_final_input, encoder_final, vocab, inv_vocab\n", 436 | "\n", 437 | "VOCAB_SIZE = len(vocab)\n", 438
| "MAX_LEN = 13\n", 439 | "\n", 440 | "print(decoder_final_output.shape, decoder_inp.shape, encoder_inp.shape, len(vocab), len(inv_vocab), inv_vocab[0])" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 7, 446 | "metadata": {}, 447 | "outputs": [ 448 | { 449 | "data": { 450 | "text/plain": [ 451 | "'they'" 452 | ] 453 | }, 454 | "execution_count": 7, 455 | "metadata": {}, 456 | "output_type": "execute_result" 457 | } 458 | ], 459 | "source": [ 460 | "inv_vocab[16]" 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": 8, 466 | "metadata": {}, 467 | "outputs": [], 468 | "source": [ 469 | "#print(len(decoder_final_input), MAX_LEN, VOCAB_SIZE)\n", 470 | "#decoder_final_input[0]\n", 471 | "#decoder_output_data = np.zeros((len(decoder_final_input), MAX_LEN, VOCAB_SIZE), dtype=\"float32\")\n", 472 | "#print(decoder_output_data.shape)\n", 473 | "#decoder_final_input[80]" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": 9, 479 | "metadata": {}, 480 | "outputs": [], 481 | "source": [ 482 | "from tensorflow.keras.utils import to_categorical\n", 483 | "decoder_final_output = to_categorical(decoder_final_output, len(vocab))" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": 10, 489 | "metadata": {}, 490 | "outputs": [ 491 | { 492 | "data": { 493 | "text/plain": [ 494 | "(30000, 13, 3027)" 495 | ] 496 | }, 497 | "execution_count": 10, 498 | "metadata": {}, 499 | "output_type": "execute_result" 500 | } 501 | ], 502 | "source": [ 503 | "decoder_final_output.shape" 504 | ] 505 | }, 506 | { 507 | "cell_type": "markdown", 508 | "execution_count": null, 509 | "metadata": {}, 510 | "source": [ 511 | "# Glove Embedding" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": 11, 517 | "metadata": {}, 518 | "outputs": [ 519 | { 520 | "name": "stdout", 521 | "output_type": "stream", 522 | "text": [ 523 | "Glove Loded!\n" 524 | ] 525 | } 526 | ], 527 | "source": [ 528 | "\n", 529 | "embeddings_index = {}\n", 530 | "with open('../input/glove6b50d/glove.6B.50d.txt', encoding='utf-8') as f:\n", 531 | " for line in f:\n", 532 | " values = line.split()\n", 533 | " word = values[0]\n", 534 | " coefs = np.asarray(values[1:], dtype='float32')\n", 535 | " embeddings_index[word] = coefs\n", 536 | " f.close()\n", 537 | "\n", 538 | "print(\"Glove Loded!\")\n" 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": 12, 544 | "metadata": {}, 545 | "outputs": [], 546 | "source": [ 547 | "\n", 548 | "embedding_dimention = 50\n", 549 | "def embedding_matrix_creater(embedding_dimention, word_index):\n", 550 | " embedding_matrix = np.zeros((len(word_index)+1, embedding_dimention))\n", 551 | " for word, i in word_index.items():\n", 552 | " embedding_vector = embeddings_index.get(word)\n", 553 | " if embedding_vector is not None:\n", 554 | " # words not found in embedding index will be all-zeros.\n", 555 | " embedding_matrix[i] = embedding_vector\n", 556 | " return embedding_matrix\n", 557 | "embedding_matrix = embedding_matrix_creater(50, word_index=vocab) \n" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": 13, 563 | "metadata": {}, 564 | "outputs": [], 565 | "source": [ 566 | "del(embeddings_index)" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": 14, 572 | "metadata": {}, 573 | "outputs": [ 574 | { 575 | "data": { 576 | "text/plain": [ 577 | "(3028, 50)" 578 | ] 579 | }, 580 | "execution_count": 14, 581 | "metadata": {}, 582 | 
"output_type": "execute_result" 583 | } 584 | ], 585 | "source": [ 586 | "embedding_matrix.shape" 587 | ] 588 | }, 589 | { 590 | "cell_type": "code", 591 | "execution_count": 15, 592 | "metadata": {}, 593 | "outputs": [ 594 | { 595 | "data": { 596 | "text/plain": [ 597 | "array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 598 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 599 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])" 600 | ] 601 | }, 602 | "execution_count": 15, 603 | "metadata": {}, 604 | "output_type": "execute_result" 605 | } 606 | ], 607 | "source": [ 608 | "embedding_matrix[0]" 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": 16, 614 | "metadata": {}, 615 | "outputs": [], 616 | "source": [ 617 | "from tensorflow.keras.models import Model\n", 618 | "from tensorflow.keras.layers import Dense, Embedding, LSTM, Input, Bidirectional, Concatenate, Dropout, Attention" 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": 17, 624 | "metadata": {}, 625 | "outputs": [], 626 | "source": [ 627 | "embed = Embedding(VOCAB_SIZE+1, \n", 628 | " 50, \n", 629 | " \n", 630 | " input_length=13,\n", 631 | " trainable=True)\n", 632 | "\n", 633 | "embed.build((None,))\n", 634 | "embed.set_weights([embedding_matrix])\n" 635 | ] 636 | }, 637 | { 638 | "cell_type": "markdown", 639 | "execution_count": null, 640 | "metadata": {}, 641 | "source": [ 642 | "# Model" 643 | ] 644 | }, 645 | { 646 | "cell_type": "code", 647 | "execution_count": 18, 648 | "metadata": {}, 649 | "outputs": [], 650 | "source": [ 651 | "enc_inp = Input(shape=(13, ))" 652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": 19, 657 | "metadata": {}, 658 | "outputs": [ 659 | { 660 | "name": "stdout", 661 | "output_type": "stream", 662 | "text": [ 663 | "Model: \"model\"\n", 664 | "__________________________________________________________________________________________________\n", 665 | "Layer (type) Output Shape Param # Connected to \n", 666 | "==================================================================================================\n", 667 | "input_2 (InputLayer) [(None, 13)] 0 \n", 668 | "__________________________________________________________________________________________________\n", 669 | "input_1 (InputLayer) [(None, 13)] 0 \n", 670 | "__________________________________________________________________________________________________\n", 671 | "embedding (Embedding) (None, 13, 50) 151400 input_1[0][0] \n", 672 | " input_2[0][0] \n", 673 | "__________________________________________________________________________________________________\n", 674 | "bidirectional (Bidirectional) [(None, 13, 800), (N 1443200 embedding[0][0] \n", 675 | "__________________________________________________________________________________________________\n", 676 | "concatenate (Concatenate) (None, 800) 0 bidirectional[0][1] \n", 677 | " bidirectional[0][3] \n", 678 | "__________________________________________________________________________________________________\n", 679 | "concatenate_1 (Concatenate) (None, 800) 0 bidirectional[0][2] \n", 680 | " bidirectional[0][4] \n", 681 | "__________________________________________________________________________________________________\n", 682 | "lstm_1 (LSTM) [(None, 13, 800), (N 2723200 embedding[1][0] \n", 683 | " concatenate[0][0] \n", 684 | " concatenate_1[0][0] \n", 685 | 
"__________________________________________________________________________________________________\n", 686 | "attention_layer (AttentionLayer ((None, 13, 800), (N 1280800 bidirectional[0][0] \n", 687 | " lstm_1[0][0] \n", 688 | "__________________________________________________________________________________________________\n", 689 | "concatenate_2 (Concatenate) (None, 13, 1600) 0 lstm_1[0][0] \n", 690 | " attention_layer[0][0] \n", 691 | "__________________________________________________________________________________________________\n", 692 | "dense (Dense) (None, 13, 3027) 4846227 concatenate_2[0][0] \n", 693 | "==================================================================================================\n", 694 | "Total params: 10,444,827\n", 695 | "Trainable params: 10,444,827\n", 696 | "Non-trainable params: 0\n", 697 | "__________________________________________________________________________________________________\n" 698 | ] 699 | } 700 | ], 701 | "source": [ 702 | "#embed = Embedding(VOCAB_SIZE+1, 50, mask_zero=True, input_length=13)(enc_inp)\n", 703 | "enc_embed = embed(enc_inp)\n", 704 | "enc_lstm = Bidirectional(LSTM(400, return_state=True, dropout=0.05, return_sequences = True))\n", 705 | "\n", 706 | "encoder_outputs, forward_h, forward_c, backward_h, backward_c = enc_lstm(enc_embed)\n", 707 | "\n", 708 | "state_h = Concatenate()([forward_h, backward_h])\n", 709 | "state_c = Concatenate()([forward_c, backward_c])\n", 710 | "\n", 711 | "enc_states = [state_h, state_c]\n", 712 | "\n", 713 | "\n", 714 | "dec_inp = Input(shape=(13, ))\n", 715 | "dec_embed = embed(dec_inp)\n", 716 | "dec_lstm = LSTM(400*2, return_state=True, return_sequences=True, dropout=0.05)\n", 717 | "output, _, _ = dec_lstm(dec_embed, initial_state=enc_states)\n", 718 | "\n", 719 | "# attention\n", 720 | "attn_layer = AttentionLayer()\n", 721 | "attn_op, attn_state = attn_layer([encoder_outputs, output])\n", 722 | "decoder_concat_input = Concatenate(axis=-1)([output, attn_op])\n", 723 | "\n", 724 | "\n", 725 | "dec_dense = Dense(VOCAB_SIZE, activation='softmax')\n", 726 | "final_output = dec_dense(decoder_concat_input)\n", 727 | "\n", 728 | "model = Model([enc_inp, dec_inp], final_output)\n", 729 | "\n", 730 | "model.summary()" 731 | ] 732 | }, 733 | { 734 | "cell_type": "code", 735 | "execution_count": 20, 736 | "metadata": {}, 737 | "outputs": [ 738 | { 739 | "name": "stderr", 740 | "output_type": "stream", 741 | "text": [ 742 | "Using TensorFlow backend.\n" 743 | ] 744 | } 745 | ], 746 | "source": [ 747 | "import keras\n", 748 | "import tensorflow as tf" 749 | ] 750 | }, 751 | { 752 | "cell_type": "code", 753 | "execution_count": 21, 754 | "metadata": {}, 755 | "outputs": [], 756 | "source": [ 757 | "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])" 758 | ] 759 | }, 760 | { 761 | "cell_type": "code", 762 | "execution_count": 22, 763 | "metadata": { 764 | "collapsed": true 765 | }, 766 | "outputs": [ 767 | { 768 | "name": "stdout", 769 | "output_type": "stream", 770 | "text": [ 771 | "1063/1063 [==============================] - 54s 50ms/step - loss: 2.8887 - acc: 0.5147 - val_loss: 2.6757 - val_acc: 0.5333\n" 772 | ] 773 | }, 774 | { 775 | "data": { 776 | "text/plain": [ 777 | "" 778 | ] 779 | }, 780 | "execution_count": 22, 781 | "metadata": {}, 782 | "output_type": "execute_result" 783 | } 784 | ], 785 | "source": [ 786 | "model.fit([encoder_inp, decoder_inp], decoder_final_output, epochs=1, batch_size=24, validation_split=0.15)" 787 | ] 788 | }, 789 | { 790 | 
"cell_type": "markdown", 791 | "execution_count": null, 792 | "metadata": {}, 793 | "source": [ 794 | "# inferece" 795 | ] 796 | }, 797 | { 798 | "cell_type": "code", 799 | "execution_count": 23, 800 | "metadata": {}, 801 | "outputs": [], 802 | "source": [ 803 | "model.save('chatbot.h5')\n", 804 | "model.save_weights('chatbot_weights.h5')" 805 | ] 806 | }, 807 | { 808 | "cell_type": "markdown", 809 | "execution_count": null, 810 | "metadata": {}, 811 | "source": [ 812 | "# Attention Inference\n" 813 | ] 814 | }, 815 | { 816 | "cell_type": "code", 817 | "execution_count": 24, 818 | "metadata": {}, 819 | "outputs": [], 820 | "source": [ 821 | "enc_model = tf.keras.models.Model(enc_inp, [encoder_outputs, enc_states])\n", 822 | "\n", 823 | "\n", 824 | "decoder_state_input_h = tf.keras.layers.Input(shape=( 400 * 2,))\n", 825 | "decoder_state_input_c = tf.keras.layers.Input(shape=( 400 * 2,))\n", 826 | "\n", 827 | "decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]\n", 828 | "\n", 829 | "\n", 830 | "decoder_outputs, state_h, state_c = dec_lstm(dec_embed , initial_state=decoder_states_inputs)\n", 831 | "\n", 832 | "\n", 833 | "decoder_states = [state_h, state_c]\n", 834 | "\n", 835 | "#decoder_output = dec_dense(decoder_outputs)\n", 836 | "\n", 837 | "dec_model = tf.keras.models.Model([dec_inp, decoder_states_inputs],\n", 838 | " [decoder_outputs] + decoder_states)\n" 839 | ] 840 | }, 841 | { 842 | "cell_type": "code", 843 | "execution_count": 25, 844 | "metadata": {}, 845 | "outputs": [ 846 | { 847 | "name": "stdout", 848 | "output_type": "stream", 849 | "text": [ 850 | "##########################################\n", 851 | "# start chatting ver. 1.0 #\n", 852 | "##########################################\n" 853 | ] 854 | }, 855 | { 856 | "ename": "StdinNotImplementedError", 857 | "evalue": "raw_input was called, but this frontend does not support input requests.", 858 | "output_type": "error", 859 | "traceback": [ 860 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 861 | "\u001b[0;31mStdinNotImplementedError\u001b[0m Traceback (most recent call last)", 862 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mprepro1\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'q'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mprepro1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"you : \"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0mprepro\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mprepro1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 863 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36mraw_input\u001b[0;34m(self, prompt)\u001b[0m\n\u001b[1;32m 853\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_allow_stdin\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 854\u001b[0m raise StdinNotImplementedError(\n\u001b[0;32m--> 855\u001b[0;31m \u001b[0;34m\"raw_input was called, but this frontend does not support input 
requests.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 856\u001b[0m )\n\u001b[1;32m 857\u001b[0m return self._input_request(str(prompt),\n", 864 | "\u001b[0;31mStdinNotImplementedError\u001b[0m: raw_input was called, but this frontend does not support input requests." 865 | ] 866 | } 867 | ], 868 | "source": [ 869 | "print(\"##########################################\")\n", 870 | "print(\"# start chatting ver. 1.0 #\")\n", 871 | "print(\"##########################################\")\n", 872 | "\n", 873 | "\n", 874 | "prepro1 = \"\"\n", 875 | "while prepro1 != 'q':\n", 876 | " \n", 877 | " prepro1 = input(\"you : \")\n", 878 | " prepro = [prepro1]\n", 879 | " \n", 880 | " try:\n", 881 | " txt = []\n", 882 | " for x in prepro:\n", 883 | " lst = []\n", 884 | " for y in x.split():\n", 885 | " lst.append(vocab[y])\n", 886 | " txt.append(lst)\n", 887 | " txt = pad_sequences(txt, 13, padding='post')\n", 888 | "\n", 889 | "\n", 890 | " ###\n", 891 | " enc_op, stat = enc_model.predict( txt )\n", 892 | "\n", 893 | " empty_target_seq = np.zeros( ( 1 , 1) )\n", 894 | " empty_target_seq[0, 0] = vocab['']\n", 895 | " stop_condition = False\n", 896 | " decoded_translation = ''\n", 897 | "\n", 898 | "\n", 899 | " while not stop_condition :\n", 900 | "\n", 901 | " dec_outputs , h , c = dec_model.predict([ empty_target_seq ] + stat )\n", 902 | "\n", 903 | " ###\n", 904 | " ###########################\n", 905 | " attn_op, attn_state = attn_layer([enc_op, dec_outputs])\n", 906 | " decoder_concat_input = Concatenate(axis=-1)([dec_outputs, attn_op])\n", 907 | " decoder_concat_input = dec_dense(decoder_concat_input)\n", 908 | " ###########################\n", 909 | "\n", 910 | " sampled_word_index = np.argmax( decoder_concat_input[0, -1, :] )\n", 911 | "\n", 912 | " sampled_word = inv_vocab[sampled_word_index] + ' '\n", 913 | "\n", 914 | " if sampled_word != ' ':\n", 915 | " decoded_translation += sampled_word \n", 916 | "\n", 917 | "\n", 918 | " if sampled_word == ' ' or len(decoded_translation.split()) > 13:\n", 919 | " stop_condition = True\n", 920 | "\n", 921 | " empty_target_seq = np.zeros( ( 1 , 1 ) ) \n", 922 | " empty_target_seq[ 0 , 0 ] = sampled_word_index\n", 923 | " stat = [ h , c ] \n", 924 | " except:\n", 925 | " pass\n", 926 | "\n", 927 | " print(\"chatbot attention : \", decoded_translation )\n", 928 | " print(\"==============================================\")\n", 929 | "\n", 930 | "\n" 931 | ] 932 | }, 933 | { 934 | "cell_type": "code", 935 | "execution_count": null, 936 | "metadata": {}, 937 | "outputs": [], 938 | "source": [] 939 | }, 940 | { 941 | "cell_type": "code", 942 | "execution_count": null, 943 | "metadata": {}, 944 | "outputs": [], 945 | "source": [] 946 | } 947 | ], 948 | "metadata": { 949 | "kernelspec": { 950 | "display_name": "Python 3", 951 | "language": "python", 952 | "name": "python3" 953 | }, 954 | "language_info": { 955 | "codemirror_mode": { 956 | "name": "ipython", 957 | "version": 3 958 | }, 959 | "file_extension": ".py", 960 | "mimetype": "text/x-python", 961 | "name": "python", 962 | "nbconvert_exporter": "python", 963 | "pygments_lexer": "ipython3", 964 | "version": "3.7.6" 965 | } 966 | }, 967 | "nbformat": 4, 968 | "nbformat_minor": 4 969 | } 970 | --------------------------------------------------------------------------------