├── dic.pkl
├── inv.pkl
├── example_chat.JPG
├── README.md
├── AttentionLayer.py
├── chatbot.py
├── LICENSE
└── seq2seq-chatbot-keras-with-attention.ipynb
/dic.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pawandeep-prog/keras-seq2seq-chatbot-with-attention/HEAD/dic.pkl
--------------------------------------------------------------------------------
/inv.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pawandeep-prog/keras-seq2seq-chatbot-with-attention/HEAD/inv.pkl
--------------------------------------------------------------------------------
/example_chat.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pawandeep-prog/keras-seq2seq-chatbot-with-attention/HEAD/example_chat.JPG
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # keras-seq2seq-chatbot-with-attention
2 | A seq2seq encoder-decoder chatbot built with Keras, using an attention layer
3 |
4 |
5 | ## Files
6 | - chatbot.py :- runs the chatbot locally using the saved model
7 | - ipynb file :- an all-in-one notebook; you only need the datasets listed below to run it, hopefully with no errors.
8 |   It also saves the trained model in h5 format
9 |
10 |
11 | Thanks to thushv89 (https://github.com/thushv89/attention_keras) for the attention layer; all credit for the AttentionLayer class goes to him.
12 |
13 | ## How to run
14 |
15 | - Run on kaggle : https://www.kaggle.com/programminghut/seq2seq-chatbot-keras-with-attention
16 | - run chatbot.py only after you have run the ipynb file, since it is the notebook that saves the model
17 |
18 |
19 | ## Datasets used
20 |
21 | - glove6b 50d : https://www.kaggle.com/watts2/glove6b50dtxt
22 | - cornell movie : https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html
23 |
24 |
25 | chatbot.py is the Python file to run locally using the saved model.
26 |
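For reference, here is a minimal sketch (an illustration, not a file in this repo) of what chatbot.py does at startup: dic.pkl maps each word to an integer id, and inv.pkl is the inverse mapping used to turn predicted ids back into words.

```python
import pickle

# dic.pkl : word -> id, inv.pkl : id -> word (chatbot.py loads them the same way)
with open('dic.pkl', 'rb') as f:
    vocab = pickle.load(f)
with open('inv.pkl', 'rb') as f:
    inv_vocab = pickle.load(f)

print(len(vocab), "words in vocabulary")
```

With chatbot.h5, dic.pkl and inv.pkl in the same folder, start chatting with `python chatbot.py` and type q to quit.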
27 | ## Sample chat
28 |
29 | ![example chat](example_chat.JPG)
30 | If you still have queries, you can reach me here:
31 |
32 | facebook : https://m.facebook.com/proogramminghub
33 | twitter : https://twitter.com/programming_hut
34 | github : https://github.com/Pawandeep-prog
35 | discord : https://discord.gg/G5Cunyg
36 | linkedin : https://www.linkedin.com/in/programminghut
37 | youtube : https://www.youtube.com/c/programminghutofficial
38 |
39 |
--------------------------------------------------------------------------------
/AttentionLayer.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras.layers import Layer
3 | from tensorflow.keras import backend as K
5 |
6 |
7 | class AttentionLayer(Layer):
8 | """
9 | This class implements Bahdanau attention (https://arxiv.org/pdf/1409.0473.pdf).
10 | There are three sets of weights introduced W_a, U_a, and V_a
11 | """
12 |
13 | def __init__(self, **kwargs):
14 | super(AttentionLayer, self).__init__(**kwargs)
15 |
16 | def build(self, input_shape):
17 | assert isinstance(input_shape, list)
18 | # Create a trainable weight variable for this layer.
19 |
20 | self.W_a = self.add_weight(name='W_a',
21 | shape=tf.TensorShape((input_shape[0][2], input_shape[0][2])),
22 | initializer='uniform',
23 | trainable=True)
24 | self.U_a = self.add_weight(name='U_a',
25 | shape=tf.TensorShape((input_shape[1][2], input_shape[0][2])),
26 | initializer='uniform',
27 | trainable=True)
28 | self.V_a = self.add_weight(name='V_a',
29 | shape=tf.TensorShape((input_shape[0][2], 1)),
30 | initializer='uniform',
31 | trainable=True)
32 |
33 | super(AttentionLayer, self).build(input_shape) # Be sure to call this at the end
34 |
35 | def call(self, inputs, verbose=False):
36 | """
37 | inputs: [encoder_output_sequence, decoder_output_sequence]
38 | """
39 | assert type(inputs) == list
40 | encoder_out_seq, decoder_out_seq = inputs
41 | if verbose:
42 | print('encoder_out_seq>', encoder_out_seq.shape)
43 | print('decoder_out_seq>', decoder_out_seq.shape)
44 |
45 | def energy_step(inputs, states):
46 | """ Step function for computing energy for a single decoder state
47 |             inputs: (batch_size, de_in_dim)
48 |             states: [(batch_size, en_seq_len)]
49 | """
50 |
51 | assert_msg = "States must be an iterable. Got {} of type {}".format(states, type(states))
52 | assert isinstance(states, list) or isinstance(states, tuple), assert_msg
53 |
54 | """ Some parameters required for shaping tensors"""
55 | en_seq_len, en_hidden = encoder_out_seq.shape[1], encoder_out_seq.shape[2]
56 | de_hidden = inputs.shape[-1]
57 |
58 | """ Computing S.Wa where S=[s0, s1, ..., si]"""
59 | # <= batch size * en_seq_len * latent_dim
60 | W_a_dot_s = K.dot(encoder_out_seq, self.W_a)
61 |
62 | """ Computing hj.Ua """
63 | U_a_dot_h = K.expand_dims(K.dot(inputs, self.U_a), 1) # <= batch_size, 1, latent_dim
64 | if verbose:
65 | print('Ua.h>', U_a_dot_h.shape)
66 |
67 | """ tanh(S.Wa + hj.Ua) """
68 |             # <= batch_size, en_seq_len, latent_dim
69 | Ws_plus_Uh = K.tanh(W_a_dot_s + U_a_dot_h)
70 | if verbose:
71 | print('Ws+Uh>', Ws_plus_Uh.shape)
72 |
73 | """ softmax(va.tanh(S.Wa + hj.Ua)) """
74 | # <= batch_size, en_seq_len
75 | e_i = K.squeeze(K.dot(Ws_plus_Uh, self.V_a), axis=-1)
76 | # <= batch_size, en_seq_len
77 | e_i = K.softmax(e_i)
78 |
79 | if verbose:
80 | print('ei>', e_i.shape)
81 |
82 | return e_i, [e_i]
83 |
84 | def context_step(inputs, states):
85 | """ Step function for computing ci using ei """
86 |
87 | assert_msg = "States must be an iterable. Got {} of type {}".format(states, type(states))
88 | assert isinstance(states, list) or isinstance(states, tuple), assert_msg
89 |
90 | # <= batch_size, hidden_size
91 | c_i = K.sum(encoder_out_seq * K.expand_dims(inputs, -1), axis=1)
92 | if verbose:
93 | print('ci>', c_i.shape)
94 | return c_i, [c_i]
95 |
96 | fake_state_c = K.sum(encoder_out_seq, axis=1)
97 |         fake_state_e = K.sum(encoder_out_seq, axis=2)  # <= (batch_size, en_seq_len)
98 |
99 | """ Computing energy outputs """
100 | # e_outputs => (batch_size, de_seq_len, en_seq_len)
101 | last_out, e_outputs, _ = K.rnn(
102 | energy_step, decoder_out_seq, [fake_state_e],
103 | )
104 |
105 | """ Computing context vectors """
106 | last_out, c_outputs, _ = K.rnn(
107 | context_step, e_outputs, [fake_state_c],
108 | )
109 |
110 | return c_outputs, e_outputs
111 |
112 | def compute_output_shape(self, input_shape):
113 | """ Outputs produced by the layer """
114 | return [
115 | tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[1][2])),
116 | tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[0][1]))
117 | ]
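118 |
119 |
120 | # --- Minimal shape check (an added sketch, not part of the upstream layer) ---
121 | # Feeds dummy encoder/decoder sequences through the layer and prints the shapes
122 | # of the returned context vectors and attention energies.
123 | if __name__ == "__main__":
124 |     from tensorflow.keras.layers import Input
125 |
126 |     enc_seq = Input(shape=(13, 800))  # (batch, en_seq_len, en_hidden)
127 |     dec_seq = Input(shape=(13, 800))  # (batch, de_seq_len, de_hidden)
128 |     context, energies = AttentionLayer()([enc_seq, dec_seq])
129 |     print(context.shape)   # (None, 13, 800): one context vector per decoder step
130 |     print(energies.shape)  # (None, 13, 13): attention weights over encoder steps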
--------------------------------------------------------------------------------
/chatbot.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.models import load_model, Model
2 | from tensorflow.keras.layers import Input, Concatenate
7 | import pickle
8 | import numpy as np
9 | import re
10 | from AttentionLayer import AttentionLayer
11 |
12 | with open('dic.pkl', 'rb') as f:
13 | vocab = pickle.load(f)
14 | with open('inv.pkl', 'rb') as f:
15 | inv_vocab = pickle.load(f)
16 |
17 |
18 | def clean_text(txt):
19 | txt = txt.lower()
20 | txt = re.sub(r"i'm", "i am", txt)
21 | txt = re.sub(r"he's", "he is", txt)
22 | txt = re.sub(r"she's", "she is", txt)
23 | txt = re.sub(r"that's", "that is", txt)
24 | txt = re.sub(r"what's", "what is", txt)
25 | txt = re.sub(r"where's", "where is", txt)
26 | txt = re.sub(r"\'ll", " will", txt)
27 | txt = re.sub(r"\'ve", " have", txt)
28 | txt = re.sub(r"\'re", " are", txt)
29 | txt = re.sub(r"\'d", " would", txt)
30 | txt = re.sub(r"won't", "will not", txt)
31 | txt = re.sub(r"can't", "can not", txt)
32 | txt = re.sub(r"[^\w\s]", "", txt)
33 | return txt
34 |
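# For instance, clean_text("What's up?!") returns "what is up".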
35 |
36 |
37 | # pass the AttentionLayer class so load_model can rebuild the saved custom layer
38 |
39 | model = load_model('chatbot.h5', custom_objects={'AttentionLayer': AttentionLayer})
40 |
41 |
42 |
43 | encoder_inputs = model.layers[0].input
44 | embed = model.layers[2]
45 | enc_embed = embed(encoder_inputs)
46 | encoder_layer = model.layers[3]
47 |
48 | encoder_outputs, fstate_h, fstate_c, bstate_h, bstate_c = encoder_layer(enc_embed)
49 |
50 | h = Concatenate()([fstate_h, bstate_h])
51 | c = Concatenate()([fstate_c, bstate_c])
52 | encoder_states = [h, c]
53 |
54 | enc_model = Model(encoder_inputs, [encoder_outputs, encoder_states])
57 |
58 |
59 | latent_dim = 800
60 |
61 | decoder_inputs = model.layers[1].input
62 | decoder_lstm = model.layers[6]
64 | decoder_state_input_h = Input(shape=(latent_dim,), name='input_3')
65 | decoder_state_input_c = Input(shape=(latent_dim,), name='input_4')
66 |
67 | decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
68 |
69 | dec_embed = embed(decoder_inputs)
70 |
71 | decoder_outputs, state_h, state_c = decoder_lstm(dec_embed, initial_state=decoder_states_inputs)
72 | decoder_states = [state_h, state_c]
73 |
74 | dec_model = Model([decoder_inputs, decoder_states_inputs], [decoder_outputs] + decoder_states)
75 |
76 | dec_dense = model.layers[-1]
77 | attn_layer = model.layers[7]
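# Note: these layer indices (layers[3] = bidirectional encoder, layers[6] = decoder
# LSTM, layers[7] = attention, layers[-1] = final dense softmax) assume the layer
# order of the model saved by the accompanying notebook; if the model is rebuilt
# differently, check model.summary() and adjust the indices accordingly.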
78 |
79 | from tensorflow.keras.preprocessing.sequence import pad_sequences
80 | print("##########################################")
81 | print("# start chatting ver. 1.0 #")
82 | print("##########################################")
83 |
84 |
85 | prepro1 = ""
86 | while prepro1 != 'q':
87 |
88 | prepro1 = input("you : ")
89 | try:
90 | prepro1 = clean_text(prepro1)
91 | prepro = [prepro1]
92 |
93 | txt = []
94 | for x in prepro:
95 | lst = []
96 | for y in x.split():
97 | try:
98 | lst.append(vocab[y])
99 | except:
100 |                     lst.append(vocab['<OUT>'])  # out-of-vocabulary fallback token
101 | txt.append(lst)
102 | txt = pad_sequences(txt, 13, padding='post')
103 |
104 |
105 | ###
106 | enc_op, stat = enc_model.predict( txt )
107 |
108 | empty_target_seq = np.zeros( ( 1 , 1) )
109 |         empty_target_seq[0, 0] = vocab['<SOS>']  # seed decoding with the start token
110 | stop_condition = False
111 | decoded_translation = ''
112 |
113 |
114 | while not stop_condition :
115 |
116 | dec_outputs , h , c = dec_model.predict([ empty_target_seq ] + stat )
117 |
118 | ###
119 | ###########################
120 | attn_op, attn_state = attn_layer([enc_op, dec_outputs])
121 | decoder_concat_input = Concatenate(axis=-1)([dec_outputs, attn_op])
122 | decoder_concat_input = dec_dense(decoder_concat_input)
123 | ###########################
124 |
125 | sampled_word_index = np.argmax( decoder_concat_input[0, -1, :] )
126 |
127 | sampled_word = inv_vocab[sampled_word_index] + ' '
128 |
129 |             if sampled_word != '<EOS> ':
130 | decoded_translation += sampled_word
131 |
132 |
133 |             if sampled_word == '<EOS> ' or len(decoded_translation.split()) > 13:
134 | stop_condition = True
135 |
136 | empty_target_seq = np.zeros( ( 1 , 1 ) )
137 | empty_target_seq[ 0 , 0 ] = sampled_word_index
138 | stat = [ h , c ]
139 |
140 | print("chatbot attention : ", decoded_translation )
141 | print("==============================================")
142 |
143 | except:
144 |         print("sorry, I didn't get you, please type again :(")
145 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/seq2seq-chatbot-keras-with-attention.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
8 | "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5"
9 | },
10 | "outputs": [],
11 | "source": [
12 | "import numpy as np \n",
13 | "import pandas as pd \n",
14 | "import os"
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
20 | "metadata": {},
21 | "source": [
22 | "# Attention Class"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 2,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "import tensorflow as tf\n",
32 |     "from tensorflow.keras.layers import Layer\n",
33 |     "from tensorflow.keras import backend as K\n",
35 | "\n",
36 | "\n",
37 | "class AttentionLayer(Layer):\n",
38 | " \"\"\"\n",
39 | " This class implements Bahdanau attention (https://arxiv.org/pdf/1409.0473.pdf).\n",
40 | " There are three sets of weights introduced W_a, U_a, and V_a\n",
41 | " \"\"\"\n",
42 | "\n",
43 | " def __init__(self, **kwargs):\n",
44 | " super(AttentionLayer, self).__init__(**kwargs)\n",
45 | "\n",
46 | " def build(self, input_shape):\n",
47 | " assert isinstance(input_shape, list)\n",
48 | " # Create a trainable weight variable for this layer.\n",
49 | "\n",
50 | " self.W_a = self.add_weight(name='W_a',\n",
51 | " shape=tf.TensorShape((input_shape[0][2], input_shape[0][2])),\n",
52 | " initializer='uniform',\n",
53 | " trainable=True)\n",
54 | " self.U_a = self.add_weight(name='U_a',\n",
55 | " shape=tf.TensorShape((input_shape[1][2], input_shape[0][2])),\n",
56 | " initializer='uniform',\n",
57 | " trainable=True)\n",
58 | " self.V_a = self.add_weight(name='V_a',\n",
59 | " shape=tf.TensorShape((input_shape[0][2], 1)),\n",
60 | " initializer='uniform',\n",
61 | " trainable=True)\n",
62 | "\n",
63 | " super(AttentionLayer, self).build(input_shape) # Be sure to call this at the end\n",
64 | "\n",
65 | " def call(self, inputs, verbose=False):\n",
66 | " \"\"\"\n",
67 | " inputs: [encoder_output_sequence, decoder_output_sequence]\n",
68 | " \"\"\"\n",
69 | " assert type(inputs) == list\n",
70 | " encoder_out_seq, decoder_out_seq = inputs\n",
71 | " if verbose:\n",
72 | " print('encoder_out_seq>', encoder_out_seq.shape)\n",
73 | " print('decoder_out_seq>', decoder_out_seq.shape)\n",
74 | "\n",
75 | " def energy_step(inputs, states):\n",
76 | " \"\"\" Step function for computing energy for a single decoder state\n",
77 |     "            inputs: (batch_size, de_in_dim)\n",
78 |     "            states: [(batch_size, en_seq_len)]\n",
79 | " \"\"\"\n",
80 | "\n",
81 | " assert_msg = \"States must be an iterable. Got {} of type {}\".format(states, type(states))\n",
82 | " assert isinstance(states, list) or isinstance(states, tuple), assert_msg\n",
83 | "\n",
84 | " \"\"\" Some parameters required for shaping tensors\"\"\"\n",
85 | " en_seq_len, en_hidden = encoder_out_seq.shape[1], encoder_out_seq.shape[2]\n",
86 | " de_hidden = inputs.shape[-1]\n",
87 | "\n",
88 | " \"\"\" Computing S.Wa where S=[s0, s1, ..., si]\"\"\"\n",
89 | " # <= batch size * en_seq_len * latent_dim\n",
90 | " W_a_dot_s = K.dot(encoder_out_seq, self.W_a)\n",
91 | "\n",
92 | " \"\"\" Computing hj.Ua \"\"\"\n",
93 | " U_a_dot_h = K.expand_dims(K.dot(inputs, self.U_a), 1) # <= batch_size, 1, latent_dim\n",
94 | " if verbose:\n",
95 | " print('Ua.h>', U_a_dot_h.shape)\n",
96 | "\n",
97 | " \"\"\" tanh(S.Wa + hj.Ua) \"\"\"\n",
98 |     "            # <= batch_size, en_seq_len, latent_dim\n",
99 | " Ws_plus_Uh = K.tanh(W_a_dot_s + U_a_dot_h)\n",
100 | " if verbose:\n",
101 | " print('Ws+Uh>', Ws_plus_Uh.shape)\n",
102 | "\n",
103 | " \"\"\" softmax(va.tanh(S.Wa + hj.Ua)) \"\"\"\n",
104 | " # <= batch_size, en_seq_len\n",
105 | " e_i = K.squeeze(K.dot(Ws_plus_Uh, self.V_a), axis=-1)\n",
106 | " # <= batch_size, en_seq_len\n",
107 | " e_i = K.softmax(e_i)\n",
108 | "\n",
109 | " if verbose:\n",
110 | " print('ei>', e_i.shape)\n",
111 | "\n",
112 | " return e_i, [e_i]\n",
113 | "\n",
114 | " def context_step(inputs, states):\n",
115 | " \"\"\" Step function for computing ci using ei \"\"\"\n",
116 | "\n",
117 | " assert_msg = \"States must be an iterable. Got {} of type {}\".format(states, type(states))\n",
118 | " assert isinstance(states, list) or isinstance(states, tuple), assert_msg\n",
119 | "\n",
120 | " # <= batch_size, hidden_size\n",
121 | " c_i = K.sum(encoder_out_seq * K.expand_dims(inputs, -1), axis=1)\n",
122 | " if verbose:\n",
123 | " print('ci>', c_i.shape)\n",
124 | " return c_i, [c_i]\n",
125 | "\n",
126 | " fake_state_c = K.sum(encoder_out_seq, axis=1)\n",
127 |     "        fake_state_e = K.sum(encoder_out_seq, axis=2)  # <= (batch_size, en_seq_len)\n",
128 | "\n",
129 | " \"\"\" Computing energy outputs \"\"\"\n",
130 | " # e_outputs => (batch_size, de_seq_len, en_seq_len)\n",
131 | " last_out, e_outputs, _ = K.rnn(\n",
132 | " energy_step, decoder_out_seq, [fake_state_e],\n",
133 | " )\n",
134 | "\n",
135 | " \"\"\" Computing context vectors \"\"\"\n",
136 | " last_out, c_outputs, _ = K.rnn(\n",
137 | " context_step, e_outputs, [fake_state_c],\n",
138 | " )\n",
139 | "\n",
140 | " return c_outputs, e_outputs\n",
141 | "\n",
142 | " def compute_output_shape(self, input_shape):\n",
143 | " \"\"\" Outputs produced by the layer \"\"\"\n",
144 | " return [\n",
145 | " tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[1][2])),\n",
146 | " tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[0][1]))\n",
147 | " ]"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 3,
153 | "metadata": {
154 | "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
155 | "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a"
156 | },
157 | "outputs": [],
158 | "source": [
159 | "import re\n",
160 | "\n",
161 | "lines = open('../input/chatbot-data/cornell movie-dialogs corpus/movie_lines.txt', encoding='utf-8',\n",
162 | " errors='ignore').read().split('\\n')\n",
163 | "\n",
164 | "convers = open('../input/chatbot-data/cornell movie-dialogs corpus/movie_conversations.txt', encoding='utf-8',\n",
165 | " errors='ignore').read().split('\\n')\n"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": 4,
171 | "metadata": {},
172 | "outputs": [
173 | {
174 | "data": {
175 | "text/plain": [
176 | "304714"
177 | ]
178 | },
179 | "execution_count": 4,
180 | "metadata": {},
181 | "output_type": "execute_result"
182 | }
183 | ],
184 | "source": [
185 | "len(lines)"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
191 | "metadata": {},
192 | "source": [
193 | "# Data Preprocess"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 5,
199 | "metadata": {},
200 | "outputs": [],
201 | "source": [
202 | "\n",
203 | "exchn = []\n",
204 | "for conver in convers:\n",
205 | " exchn.append(conver.split(' +++$+++ ')[-1][1:-1].replace(\"'\", \" \").replace(\",\",\"\").split())\n",
206 | "\n",
207 | "diag = {}\n",
208 | "for line in lines:\n",
209 | " diag[line.split(' +++$+++ ')[0]] = line.split(' +++$+++ ')[-1]\n",
210 | "\n",
211 | "\n",
212 | "\n",
213 | "## delete\n",
214 | "del(lines, convers, conver, line)\n",
215 | "\n",
216 | "\n",
217 | "\n",
218 | "questions = []\n",
219 | "answers = []\n",
220 | "\n",
221 | "for conver in exchn:\n",
222 | " for i in range(len(conver) - 1):\n",
223 | " questions.append(diag[conver[i]])\n",
224 | " answers.append(diag[conver[i+1]])\n",
225 | " \n",
226 | " \n",
227 | " \n",
228 | "\n",
229 | "## delete\n",
230 | "del(diag, exchn, conver, i)\n",
231 | "\n",
232 | "\n",
233 | "###############################\n",
234 | "# max_len = 13 #\n",
235 | "###############################\n",
236 | "\n",
237 | "sorted_ques = []\n",
238 | "sorted_ans = []\n",
239 | "for i in range(len(questions)):\n",
240 | " if len(questions[i]) < 13:\n",
241 | " sorted_ques.append(questions[i])\n",
242 | " sorted_ans.append(answers[i])\n",
243 | "\n",
244 | "\n",
245 | "\n",
246 | "###############################\n",
247 | "# #\n",
248 | "###############################\n",
249 | "\n",
250 | "\n",
251 | "\n",
252 | "\n",
253 | "def clean_text(txt):\n",
254 | " txt = txt.lower()\n",
255 | " txt = re.sub(r\"i'm\", \"i am\", txt)\n",
256 | " txt = re.sub(r\"he's\", \"he is\", txt)\n",
257 | " txt = re.sub(r\"she's\", \"she is\", txt)\n",
258 | " txt = re.sub(r\"that's\", \"that is\", txt)\n",
259 | " txt = re.sub(r\"what's\", \"what is\", txt)\n",
260 | " txt = re.sub(r\"where's\", \"where is\", txt)\n",
261 | " txt = re.sub(r\"\\'ll\", \" will\", txt)\n",
262 | " txt = re.sub(r\"\\'ve\", \" have\", txt)\n",
263 | " txt = re.sub(r\"\\'re\", \" are\", txt)\n",
264 | " txt = re.sub(r\"\\'d\", \" would\", txt)\n",
265 | " txt = re.sub(r\"won't\", \"will not\", txt)\n",
266 | " txt = re.sub(r\"can't\", \"can not\", txt)\n",
267 | " txt = re.sub(r\"[^\\w\\s]\", \"\", txt)\n",
268 | " return txt\n",
269 | "\n",
270 | "clean_ques = []\n",
271 | "clean_ans = []\n",
272 | "\n",
273 | "for line in sorted_ques:\n",
274 | " clean_ques.append(clean_text(line))\n",
275 | " \n",
276 | "for line in sorted_ans:\n",
277 | " clean_ans.append(clean_text(line))\n",
278 | "\n",
279 | "\n",
280 | "\n",
281 | "## delete\n",
282 | "del(answers, questions, line)\n",
283 | "\n",
284 | "\n",
285 | "###############################\n",
286 | "# #\n",
287 | "###############################\n",
288 | "\n",
289 | "\n",
290 | "for i in range(len(clean_ans)):\n",
291 | " clean_ans[i] = ' '.join(clean_ans[i].split()[:11])\n",
292 | "\n",
293 | "\n",
294 | "\n",
295 | "###############################\n",
296 | "# #\n",
297 | "###############################\n",
298 | "\n",
299 | "del(sorted_ans, sorted_ques)\n",
300 | "\n",
301 | "\n",
302 | "## trimming\n",
303 | "clean_ans=clean_ans[:30000]\n",
304 | "clean_ques=clean_ques[:30000]\n",
305 | "## delete\n",
306 | "\n",
307 | "\n",
308 | "### count occurences ###\n",
309 | "word2count = {}\n",
310 | "\n",
311 | "for line in clean_ques:\n",
312 | " for word in line.split():\n",
313 | " if word not in word2count:\n",
314 | " word2count[word] = 1\n",
315 | " else:\n",
316 | " word2count[word] += 1\n",
317 | "for line in clean_ans:\n",
318 | " for word in line.split():\n",
319 | " if word not in word2count:\n",
320 | " word2count[word] = 1\n",
321 | " else:\n",
322 | " word2count[word] += 1\n",
323 | "\n",
324 | "## delete\n",
325 | "del(word, line)\n",
326 | "\n",
327 | "\n",
328 | "### remove less frequent ###\n",
329 | "thresh = 5\n",
330 | "\n",
331 | "vocab = {}\n",
332 | "word_num = 0\n",
333 | "for word, count in word2count.items():\n",
334 | " if count >= thresh:\n",
335 | " vocab[word] = word_num\n",
336 | " word_num += 1\n",
337 | " \n",
338 | "## delete\n",
339 | "del(word2count, word, count, thresh) \n",
340 | "del(word_num) \n",
341 | "\n",
342 | "\n",
343 | "\n",
344 | "for i in range(len(clean_ans)):\n",
345 |     "    clean_ans[i] = '<SOS> ' + clean_ans[i] + ' <EOS>'\n",
346 | "\n",
347 | "\n",
348 | "\n",
349 |     "tokens = ['<PAD>', '<EOS>', '<OUT>', '<SOS>']\n",
350 | "x = len(vocab)\n",
351 | "for token in tokens:\n",
352 | " vocab[token] = x\n",
353 | " x += 1\n",
354 | " \n",
355 | " \n",
356 | "\n",
357 |     "vocab['cameron'] = vocab['<PAD>']\n",
358 |     "vocab['<PAD>'] = 0\n",
359 | "\n",
360 | "## delete\n",
361 | "del(token, tokens) \n",
362 | "del(x)\n",
363 | "\n",
364 | "### inv answers dict ###\n",
365 | "inv_vocab = {w:v for v, w in vocab.items()}\n",
366 | "\n",
367 | "\n",
368 | "\n",
369 | "## delete\n",
370 | "del(i)\n",
371 | "\n",
372 | "\n",
373 | "\n",
374 | "encoder_inp = []\n",
375 | "for line in clean_ques:\n",
376 | " lst = []\n",
377 | " for word in line.split():\n",
378 | " if word not in vocab:\n",
379 |     "            lst.append(vocab['<OUT>'])\n",
380 | " else:\n",
381 | " lst.append(vocab[word])\n",
382 | " \n",
383 | " encoder_inp.append(lst)\n",
384 | "\n",
385 | "decoder_inp = []\n",
386 | "for line in clean_ans:\n",
387 | " lst = []\n",
388 | " for word in line.split():\n",
389 | " if word not in vocab:\n",
390 |     "            lst.append(vocab['<OUT>'])\n",
391 | " else:\n",
392 | " lst.append(vocab[word]) \n",
393 | " decoder_inp.append(lst)\n",
394 | "\n",
395 | "### delete\n",
396 | "del(clean_ans, clean_ques, line, lst, word)\n",
397 | "\n",
398 | "\n",
399 | "\n",
400 | "\n",
401 | "\n",
402 | "\n",
403 | "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
404 | "encoder_inp = pad_sequences(encoder_inp, 13, padding='post', truncating='post')\n",
405 | "decoder_inp = pad_sequences(decoder_inp, 13, padding='post', truncating='post')\n",
406 | "\n",
407 | "\n",
408 | "\n",
409 | "\n",
410 | "decoder_final_output = []\n",
411 | "for i in decoder_inp:\n",
412 | " decoder_final_output.append(i[1:]) \n",
413 | "\n",
414 | "decoder_final_output = pad_sequences(decoder_final_output, 13, padding='post', truncating='post')\n",
415 | "\n",
416 | "\n",
417 | "del(i)\n",
418 | "\n"
419 | ]
420 | },
421 | {
422 | "cell_type": "code",
423 | "execution_count": 6,
424 | "metadata": {},
425 | "outputs": [
426 | {
427 | "name": "stdout",
428 | "output_type": "stream",
429 | "text": [
430 |     "(30000, 13) (30000, 13) (30000, 13) 3027 3027 <PAD>\n"
431 | ]
432 | }
433 | ],
434 | "source": [
435 | "# decoder_final_output, decoder_final_input, encoder_final, vocab, inv_vocab\n",
436 | "\n",
437 | "VOCAB_SIZE = len(vocab)\n",
438 | "MAX_LEN = 13\n",
439 | "\n",
440 | "print(decoder_final_output.shape, decoder_inp.shape, encoder_inp.shape, len(vocab), len(inv_vocab), inv_vocab[0])"
441 | ]
442 | },
443 | {
444 | "cell_type": "code",
445 | "execution_count": 7,
446 | "metadata": {},
447 | "outputs": [
448 | {
449 | "data": {
450 | "text/plain": [
451 | "'they'"
452 | ]
453 | },
454 | "execution_count": 7,
455 | "metadata": {},
456 | "output_type": "execute_result"
457 | }
458 | ],
459 | "source": [
460 | "inv_vocab[16]"
461 | ]
462 | },
463 | {
464 | "cell_type": "code",
465 | "execution_count": 8,
466 | "metadata": {},
467 | "outputs": [],
468 | "source": [
469 | "#print(len(decoder_final_input), MAX_LEN, VOCAB_SIZE)\n",
470 | "#decoder_final_input[0]\n",
471 | "#decoder_output_data = np.zeros((len(decoder_final_input), MAX_LEN, VOCAB_SIZE), dtype=\"float32\")\n",
472 | "#print(decoder_output_data.shape)\n",
473 | "#decoder_final_input[80]"
474 | ]
475 | },
476 | {
477 | "cell_type": "code",
478 | "execution_count": 9,
479 | "metadata": {},
480 | "outputs": [],
481 | "source": [
482 | "from tensorflow.keras.utils import to_categorical\n",
483 | "decoder_final_output = to_categorical(decoder_final_output, len(vocab))"
484 | ]
485 | },
486 | {
487 | "cell_type": "code",
488 | "execution_count": 10,
489 | "metadata": {},
490 | "outputs": [
491 | {
492 | "data": {
493 | "text/plain": [
494 | "(30000, 13, 3027)"
495 | ]
496 | },
497 | "execution_count": 10,
498 | "metadata": {},
499 | "output_type": "execute_result"
500 | }
501 | ],
502 | "source": [
503 | "decoder_final_output.shape"
504 | ]
505 | },
506 | {
507 | "cell_type": "markdown",
509 | "metadata": {},
510 | "source": [
511 | "# Glove Embedding"
512 | ]
513 | },
514 | {
515 | "cell_type": "code",
516 | "execution_count": 11,
517 | "metadata": {},
518 | "outputs": [
519 | {
520 | "name": "stdout",
521 | "output_type": "stream",
522 | "text": [
523 |     "Glove Loaded!\n"
524 | ]
525 | }
526 | ],
527 | "source": [
528 | "\n",
529 | "embeddings_index = {}\n",
530 | "with open('../input/glove6b50d/glove.6B.50d.txt', encoding='utf-8') as f:\n",
531 | " for line in f:\n",
532 | " values = line.split()\n",
533 | " word = values[0]\n",
534 | " coefs = np.asarray(values[1:], dtype='float32')\n",
535 | " embeddings_index[word] = coefs\n",
537 | "\n",
538 |     "print(\"Glove Loaded!\")\n"
539 | ]
540 | },
541 | {
542 | "cell_type": "code",
543 | "execution_count": 12,
544 | "metadata": {},
545 | "outputs": [],
546 | "source": [
547 | "\n",
548 |     "embedding_dimension = 50\n",
549 |     "def embedding_matrix_creator(embedding_dimension, word_index):\n",
550 |     "    embedding_matrix = np.zeros((len(word_index)+1, embedding_dimension))\n",
551 |     "    for word, i in word_index.items():\n",
552 |     "        embedding_vector = embeddings_index.get(word)\n",
553 |     "        if embedding_vector is not None:\n",
554 |     "            # words not found in the embedding index will be all-zeros\n",
555 |     "            embedding_matrix[i] = embedding_vector\n",
556 |     "    return embedding_matrix\n",
557 |     "embedding_matrix = embedding_matrix_creator(embedding_dimension, word_index=vocab)\n"
558 | ]
559 | },
560 | {
561 | "cell_type": "code",
562 | "execution_count": 13,
563 | "metadata": {},
564 | "outputs": [],
565 | "source": [
566 | "del(embeddings_index)"
567 | ]
568 | },
569 | {
570 | "cell_type": "code",
571 | "execution_count": 14,
572 | "metadata": {},
573 | "outputs": [
574 | {
575 | "data": {
576 | "text/plain": [
577 | "(3028, 50)"
578 | ]
579 | },
580 | "execution_count": 14,
581 | "metadata": {},
582 | "output_type": "execute_result"
583 | }
584 | ],
585 | "source": [
586 | "embedding_matrix.shape"
587 | ]
588 | },
589 | {
590 | "cell_type": "code",
591 | "execution_count": 15,
592 | "metadata": {},
593 | "outputs": [
594 | {
595 | "data": {
596 | "text/plain": [
597 | "array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
598 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
599 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])"
600 | ]
601 | },
602 | "execution_count": 15,
603 | "metadata": {},
604 | "output_type": "execute_result"
605 | }
606 | ],
607 | "source": [
608 | "embedding_matrix[0]"
609 | ]
610 | },
611 | {
612 | "cell_type": "code",
613 | "execution_count": 16,
614 | "metadata": {},
615 | "outputs": [],
616 | "source": [
617 | "from tensorflow.keras.models import Model\n",
618 | "from tensorflow.keras.layers import Dense, Embedding, LSTM, Input, Bidirectional, Concatenate, Dropout, Attention"
619 | ]
620 | },
621 | {
622 | "cell_type": "code",
623 | "execution_count": 17,
624 | "metadata": {},
625 | "outputs": [],
626 | "source": [
627 | "embed = Embedding(VOCAB_SIZE+1, \n",
628 | " 50, \n",
630 | " input_length=13,\n",
631 | " trainable=True)\n",
632 | "\n",
633 | "embed.build((None,))\n",
634 | "embed.set_weights([embedding_matrix])\n"
635 | ]
636 | },
637 | {
638 | "cell_type": "markdown",
640 | "metadata": {},
641 | "source": [
642 | "# Model"
643 | ]
644 | },
645 | {
646 | "cell_type": "code",
647 | "execution_count": 18,
648 | "metadata": {},
649 | "outputs": [],
650 | "source": [
651 | "enc_inp = Input(shape=(13, ))"
652 | ]
653 | },
654 | {
655 | "cell_type": "code",
656 | "execution_count": 19,
657 | "metadata": {},
658 | "outputs": [
659 | {
660 | "name": "stdout",
661 | "output_type": "stream",
662 | "text": [
663 | "Model: \"model\"\n",
664 | "__________________________________________________________________________________________________\n",
665 | "Layer (type) Output Shape Param # Connected to \n",
666 | "==================================================================================================\n",
667 | "input_2 (InputLayer) [(None, 13)] 0 \n",
668 | "__________________________________________________________________________________________________\n",
669 | "input_1 (InputLayer) [(None, 13)] 0 \n",
670 | "__________________________________________________________________________________________________\n",
671 | "embedding (Embedding) (None, 13, 50) 151400 input_1[0][0] \n",
672 | " input_2[0][0] \n",
673 | "__________________________________________________________________________________________________\n",
674 | "bidirectional (Bidirectional) [(None, 13, 800), (N 1443200 embedding[0][0] \n",
675 | "__________________________________________________________________________________________________\n",
676 | "concatenate (Concatenate) (None, 800) 0 bidirectional[0][1] \n",
677 | " bidirectional[0][3] \n",
678 | "__________________________________________________________________________________________________\n",
679 | "concatenate_1 (Concatenate) (None, 800) 0 bidirectional[0][2] \n",
680 | " bidirectional[0][4] \n",
681 | "__________________________________________________________________________________________________\n",
682 | "lstm_1 (LSTM) [(None, 13, 800), (N 2723200 embedding[1][0] \n",
683 | " concatenate[0][0] \n",
684 | " concatenate_1[0][0] \n",
685 | "__________________________________________________________________________________________________\n",
686 | "attention_layer (AttentionLayer ((None, 13, 800), (N 1280800 bidirectional[0][0] \n",
687 | " lstm_1[0][0] \n",
688 | "__________________________________________________________________________________________________\n",
689 | "concatenate_2 (Concatenate) (None, 13, 1600) 0 lstm_1[0][0] \n",
690 | " attention_layer[0][0] \n",
691 | "__________________________________________________________________________________________________\n",
692 | "dense (Dense) (None, 13, 3027) 4846227 concatenate_2[0][0] \n",
693 | "==================================================================================================\n",
694 | "Total params: 10,444,827\n",
695 | "Trainable params: 10,444,827\n",
696 | "Non-trainable params: 0\n",
697 | "__________________________________________________________________________________________________\n"
698 | ]
699 | }
700 | ],
701 | "source": [
702 | "#embed = Embedding(VOCAB_SIZE+1, 50, mask_zero=True, input_length=13)(enc_inp)\n",
703 | "enc_embed = embed(enc_inp)\n",
704 | "enc_lstm = Bidirectional(LSTM(400, return_state=True, dropout=0.05, return_sequences = True))\n",
705 | "\n",
706 | "encoder_outputs, forward_h, forward_c, backward_h, backward_c = enc_lstm(enc_embed)\n",
707 | "\n",
708 | "state_h = Concatenate()([forward_h, backward_h])\n",
709 | "state_c = Concatenate()([forward_c, backward_c])\n",
710 | "\n",
711 | "enc_states = [state_h, state_c]\n",
712 | "\n",
713 | "\n",
714 | "dec_inp = Input(shape=(13, ))\n",
715 | "dec_embed = embed(dec_inp)\n",
716 | "dec_lstm = LSTM(400*2, return_state=True, return_sequences=True, dropout=0.05)\n",
717 | "output, _, _ = dec_lstm(dec_embed, initial_state=enc_states)\n",
718 | "\n",
719 | "# attention\n",
720 | "attn_layer = AttentionLayer()\n",
721 | "attn_op, attn_state = attn_layer([encoder_outputs, output])\n",
722 | "decoder_concat_input = Concatenate(axis=-1)([output, attn_op])\n",
723 | "\n",
724 | "\n",
725 | "dec_dense = Dense(VOCAB_SIZE, activation='softmax')\n",
726 | "final_output = dec_dense(decoder_concat_input)\n",
727 | "\n",
728 | "model = Model([enc_inp, dec_inp], final_output)\n",
729 | "\n",
730 | "model.summary()"
731 | ]
732 | },
733 | {
734 | "cell_type": "code",
735 | "execution_count": 20,
736 | "metadata": {},
737 | "outputs": [
738 | {
739 | "name": "stderr",
740 | "output_type": "stream",
741 | "text": [
742 | "Using TensorFlow backend.\n"
743 | ]
744 | }
745 | ],
746 | "source": [
747 | "import keras\n",
748 | "import tensorflow as tf"
749 | ]
750 | },
751 | {
752 | "cell_type": "code",
753 | "execution_count": 21,
754 | "metadata": {},
755 | "outputs": [],
756 | "source": [
757 | "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])"
758 | ]
759 | },
760 | {
761 | "cell_type": "code",
762 | "execution_count": 22,
763 | "metadata": {
764 | "collapsed": true
765 | },
766 | "outputs": [
767 | {
768 | "name": "stdout",
769 | "output_type": "stream",
770 | "text": [
771 | "1063/1063 [==============================] - 54s 50ms/step - loss: 2.8887 - acc: 0.5147 - val_loss: 2.6757 - val_acc: 0.5333\n"
772 | ]
773 | },
774 | {
775 | "data": {
776 | "text/plain": [
777 | ""
778 | ]
779 | },
780 | "execution_count": 22,
781 | "metadata": {},
782 | "output_type": "execute_result"
783 | }
784 | ],
785 | "source": [
786 | "model.fit([encoder_inp, decoder_inp], decoder_final_output, epochs=1, batch_size=24, validation_split=0.15)"
787 | ]
788 | },
789 | {
790 | "cell_type": "markdown",
792 | "metadata": {},
793 | "source": [
794 |     "# Inference"
795 | ]
796 | },
797 | {
798 | "cell_type": "code",
799 | "execution_count": 23,
800 | "metadata": {},
801 | "outputs": [],
802 | "source": [
803 | "model.save('chatbot.h5')\n",
804 | "model.save_weights('chatbot_weights.h5')"
805 | ]
806 | },
807 | {
808 | "cell_type": "markdown",
810 | "metadata": {},
811 | "source": [
812 | "# Attention Inference\n"
813 | ]
814 | },
815 | {
816 | "cell_type": "code",
817 | "execution_count": 24,
818 | "metadata": {},
819 | "outputs": [],
820 | "source": [
821 | "enc_model = tf.keras.models.Model(enc_inp, [encoder_outputs, enc_states])\n",
822 | "\n",
823 | "\n",
824 | "decoder_state_input_h = tf.keras.layers.Input(shape=( 400 * 2,))\n",
825 | "decoder_state_input_c = tf.keras.layers.Input(shape=( 400 * 2,))\n",
826 | "\n",
827 | "decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]\n",
828 | "\n",
829 | "\n",
830 | "decoder_outputs, state_h, state_c = dec_lstm(dec_embed , initial_state=decoder_states_inputs)\n",
831 | "\n",
832 | "\n",
833 | "decoder_states = [state_h, state_c]\n",
834 | "\n",
835 | "#decoder_output = dec_dense(decoder_outputs)\n",
836 | "\n",
837 | "dec_model = tf.keras.models.Model([dec_inp, decoder_states_inputs],\n",
838 | " [decoder_outputs] + decoder_states)\n"
839 | ]
840 | },
841 | {
842 | "cell_type": "code",
843 | "execution_count": 25,
844 | "metadata": {},
845 | "outputs": [
846 | {
847 | "name": "stdout",
848 | "output_type": "stream",
849 | "text": [
850 | "##########################################\n",
851 | "# start chatting ver. 1.0 #\n",
852 | "##########################################\n"
853 | ]
854 | },
855 | {
856 | "ename": "StdinNotImplementedError",
857 | "evalue": "raw_input was called, but this frontend does not support input requests.",
858 | "output_type": "error",
859 | "traceback": [
860 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
861 | "\u001b[0;31mStdinNotImplementedError\u001b[0m Traceback (most recent call last)",
862 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mprepro1\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'q'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mprepro1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"you : \"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0mprepro\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mprepro1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
863 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36mraw_input\u001b[0;34m(self, prompt)\u001b[0m\n\u001b[1;32m 853\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_allow_stdin\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 854\u001b[0m raise StdinNotImplementedError(\n\u001b[0;32m--> 855\u001b[0;31m \u001b[0;34m\"raw_input was called, but this frontend does not support input requests.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 856\u001b[0m )\n\u001b[1;32m 857\u001b[0m return self._input_request(str(prompt),\n",
864 | "\u001b[0;31mStdinNotImplementedError\u001b[0m: raw_input was called, but this frontend does not support input requests."
865 | ]
866 | }
867 | ],
868 | "source": [
869 | "print(\"##########################################\")\n",
870 | "print(\"# start chatting ver. 1.0 #\")\n",
871 | "print(\"##########################################\")\n",
872 | "\n",
873 | "\n",
874 | "prepro1 = \"\"\n",
875 | "while prepro1 != 'q':\n",
876 | " \n",
877 | " prepro1 = input(\"you : \")\n",
878 | " prepro = [prepro1]\n",
879 | " \n",
880 | " try:\n",
881 | " txt = []\n",
882 | " for x in prepro:\n",
883 | " lst = []\n",
884 | " for y in x.split():\n",
885 | " lst.append(vocab[y])\n",
886 | " txt.append(lst)\n",
887 | " txt = pad_sequences(txt, 13, padding='post')\n",
888 | "\n",
889 | "\n",
890 | " ###\n",
891 | " enc_op, stat = enc_model.predict( txt )\n",
892 | "\n",
893 | " empty_target_seq = np.zeros( ( 1 , 1) )\n",
894 |     "        empty_target_seq[0, 0] = vocab['<SOS>']\n",
895 | " stop_condition = False\n",
896 | " decoded_translation = ''\n",
897 | "\n",
898 | "\n",
899 | " while not stop_condition :\n",
900 | "\n",
901 | " dec_outputs , h , c = dec_model.predict([ empty_target_seq ] + stat )\n",
902 | "\n",
903 | " ###\n",
904 | " ###########################\n",
905 | " attn_op, attn_state = attn_layer([enc_op, dec_outputs])\n",
906 | " decoder_concat_input = Concatenate(axis=-1)([dec_outputs, attn_op])\n",
907 | " decoder_concat_input = dec_dense(decoder_concat_input)\n",
908 | " ###########################\n",
909 | "\n",
910 | " sampled_word_index = np.argmax( decoder_concat_input[0, -1, :] )\n",
911 | "\n",
912 | " sampled_word = inv_vocab[sampled_word_index] + ' '\n",
913 | "\n",
914 |     "            if sampled_word != '<EOS> ':\n",
915 | " decoded_translation += sampled_word \n",
916 | "\n",
917 | "\n",
918 |     "            if sampled_word == '<EOS> ' or len(decoded_translation.split()) > 13:\n",
919 | " stop_condition = True\n",
920 | "\n",
921 | " empty_target_seq = np.zeros( ( 1 , 1 ) ) \n",
922 | " empty_target_seq[ 0 , 0 ] = sampled_word_index\n",
923 | " stat = [ h , c ] \n",
924 | " except:\n",
925 | " pass\n",
926 | "\n",
927 | " print(\"chatbot attention : \", decoded_translation )\n",
928 | " print(\"==============================================\")\n",
929 | "\n",
930 | "\n"
931 | ]
932 | }
947 | ],
948 | "metadata": {
949 | "kernelspec": {
950 | "display_name": "Python 3",
951 | "language": "python",
952 | "name": "python3"
953 | },
954 | "language_info": {
955 | "codemirror_mode": {
956 | "name": "ipython",
957 | "version": 3
958 | },
959 | "file_extension": ".py",
960 | "mimetype": "text/x-python",
961 | "name": "python",
962 | "nbconvert_exporter": "python",
963 | "pygments_lexer": "ipython3",
964 | "version": "3.7.6"
965 | }
966 | },
967 | "nbformat": 4,
968 | "nbformat_minor": 4
969 | }
970 |
--------------------------------------------------------------------------------