├── .gitignore
├── README.md
├── data
│   ├── __init__.py
│   ├── cornell_corpus
│   │   └── data.py
│   └── twitter
│       ├── data.py
│       ├── idx_a.npy
│       ├── idx_q.npy
│       ├── metadata.pkl
│       ├── pull
│       └── pull_raw_data
├── main.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | dea
2 | docs/_build
3 | tensorlayer
4 | tensorlayer/__pycache__
5 | tensorlayer/.DS_Store
6 | .DS_Store
7 | dist
8 | build/
9 | tensorlayer.egg-info
10 | data/.DS_Store
11 | *.pyc
12 | *.gz
13 | .spyproject/
14 | .vscode/*
15 | model.npz
16 | env/
17 | venv/
18 | .idea/
19 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Seq2Seq Chatbot
2 |
3 | This is a 200-line implementation of a Twitter/Cornell-Movie chatbot. Please read the following references before you read the code:
4 |
5 | - [Practical-Seq2Seq](http://suriyadeepan.github.io/2016-12-31-practical-seq2seq/)
6 | - [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)
7 | - [Understanding LSTM Networks](http://colah.github.io/posts/2015-08-Understanding-LSTMs/) (optional)
8 |
9 | ### Prerequisites
10 |
11 | - Python 3.6
12 | - [TensorFlow](https://github.com/tensorflow/tensorflow) >= 2.0
13 | - [TensorLayer](https://github.com/zsdonghao/tensorlayer) >= 2.0
14 |
15 | ### Model
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 | ### Training
26 |
27 | ```
28 | python3 main.py
29 | ```
30 |
31 |
32 | ### Results
33 |
34 | ```
35 | Query > happy birthday have a nice day
36 | > thank you so much
37 | > thank babe
38 | > thank bro
39 | > thanks so much
40 | > thank babe i appreciate it
41 | Query > donald trump won last nights presidential debate according to snap online polls
42 | > i dont know what the fuck is that
43 | > i think he was a racist
44 | > he is not a racist
45 | > he is a liar
46 | > trump needs to be president
47 | ```
--------------------------------------------------------------------------------
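A note on reusing a trained model: main.py (below) saves its weights to model.npz after every epoch and keeps two commented-out lines for restoring them. The sketch below is not part of the repository; it rebuilds the same network and loads those weights for inference only, assuming a model.npz produced by a previous run and the bundled data/twitter files.

```
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.models.seq2seq import Seq2seq
from data.twitter import data

# Derive the vocabulary size the way main.py does: the saved word list
# plus the two extra tokens start_id and end_id.
metadata, _, _ = data.load_data(PATH='data/twitter/')
vocabulary_size = len(metadata['idx2w']) + 2

# Same architecture as in main.py, so the saved weights fit.
model_ = Seq2seq(
    decoder_seq_length=20,
    cell_enc=tf.keras.layers.GRUCell,
    cell_dec=tf.keras.layers.GRUCell,
    n_layer=3,
    n_units=256,
    embedding_layer=tl.layers.Embedding(vocabulary_size=vocabulary_size, embedding_size=1024),
)

# The restore lines left commented out in main.py.
load_weights = tl.files.load_npz(name='model.npz')
tl.files.assign_weights(load_weights, model_)
model_.eval()
```
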
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | # from . import twitter
5 | # from . import imagenet_classes
6 | # from . import
7 |
--------------------------------------------------------------------------------
/data/cornell_corpus/data.py:
--------------------------------------------------------------------------------
1 | EN_WHITELIST = '0123456789abcdefghijklmnopqrstuvwxyz ' # space is included in whitelist
2 | EN_BLACKLIST = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\''
3 |
4 | limit = {
5 | 'maxq' : 25,
6 | 'minq' : 2,
7 | 'maxa' : 25,
8 | 'mina' : 2
9 | }
10 |
11 | UNK = 'unk'
12 | VOCAB_SIZE = 8000
13 |
14 |
15 | import random
16 |
17 | import nltk
18 | import itertools
19 | from collections import defaultdict
20 |
21 | import numpy as np
22 |
23 | import pickle
24 |
25 |
26 |
27 | '''
28 | 1. Read from 'movie_lines.txt'
29 | 2. Create a dictionary with ( key = line_id, value = text )
30 | '''
31 | def get_id2line():
32 | lines=open('raw_data/movie_lines.txt', encoding='utf-8', errors='ignore').read().split('\n')
33 | id2line = {}
34 | for line in lines:
35 | _line = line.split(' +++$+++ ')
36 | if len(_line) == 5:
37 | id2line[_line[0]] = _line[4]
38 | return id2line
39 |
40 | '''
41 | 1. Read from 'movie_conversations.txt'
42 | 2. Create a list of [list of line_id's]
43 | '''
44 | def get_conversations():
45 | conv_lines = open('raw_data/movie_conversations.txt', encoding='utf-8', errors='ignore').read().split('\n')
46 | convs = [ ]
47 | for line in conv_lines[:-1]:
48 | _line = line.split(' +++$+++ ')[-1][1:-1].replace("'","").replace(" ","")
49 | convs.append(_line.split(','))
50 | return convs
51 |
52 | '''
53 | 1. Get each conversation
54 | 2. Get each line from conversation
55 | 3. Save each conversation to file
56 | '''
57 | def extract_conversations(convs,id2line,path=''):
58 | idx = 0
59 | for conv in convs:
60 | f_conv = open(path + str(idx)+'.txt', 'w')
61 | for line_id in conv:
62 | f_conv.write(id2line[line_id])
63 | f_conv.write('\n')
64 | f_conv.close()
65 | idx += 1
66 |
67 | '''
68 | Get lists of all conversations as Questions and Answers
69 | 1. [questions]
70 | 2. [answers]
71 | '''
72 | def gather_dataset(convs, id2line):
73 | questions = []; answers = []
74 |
75 | for conv in convs:
76 | if len(conv) %2 != 0:
77 | conv = conv[:-1]
78 | for i in range(len(conv)):
79 | if i%2 == 0:
80 | questions.append(id2line[conv[i]])
81 | else:
82 | answers.append(id2line[conv[i]])
83 |
84 | return questions, answers
85 |
86 |
87 | '''
88 | We need 4 files
89 | 1. train.enc : Encoder input for training
90 | 2. train.dec : Decoder input for training
91 | 3. test.enc : Encoder input for testing
92 | 4. test.dec : Decoder input for testing
93 | '''
94 | def prepare_seq2seq_files(questions, answers, path='',TESTSET_SIZE = 30000):
95 |
96 | # open files
97 | train_enc = open(path + 'train.enc','w')
98 | train_dec = open(path + 'train.dec','w')
99 | test_enc = open(path + 'test.enc', 'w')
100 | test_dec = open(path + 'test.dec', 'w')
101 |
102 | # choose TESTSET_SIZE items to put into the test set; a set makes the membership checks below O(1)
103 | test_ids = set(random.sample(range(len(questions)), TESTSET_SIZE))
104 |
105 | for i in range(len(questions)):
106 | if i in test_ids:
107 | test_enc.write(questions[i]+'\n')
108 | test_dec.write(answers[i]+ '\n' )
109 | else:
110 | train_enc.write(questions[i]+'\n')
111 | train_dec.write(answers[i]+ '\n' )
112 | if i%10000 == 0:
113 | print('\n>> written {} lines'.format(i))
114 |
115 | # close files
116 | train_enc.close()
117 | train_dec.close()
118 | test_enc.close()
119 | test_dec.close()
120 |
121 |
122 |
123 | '''
124 | remove anything that isn't in the vocabulary
125 | return str(pure en)
126 |
127 | '''
128 | def filter_line(line, whitelist):
129 | return ''.join([ ch for ch in line if ch in whitelist ])
130 |
131 |
132 |
133 | '''
134 | filter too long and too short sequences
135 | return tuple( filtered_q, filtered_a )
136 |
137 | '''
138 | def filter_data(qseq, aseq):
139 | filtered_q, filtered_a = [], []
140 | raw_data_len = len(qseq)
141 |
142 | assert len(qseq) == len(aseq)
143 |
144 | for i in range(raw_data_len):
145 | qlen, alen = len(qseq[i].split(' ')), len(aseq[i].split(' '))
146 | if qlen >= limit['minq'] and qlen <= limit['maxq']:
147 | if alen >= limit['mina'] and alen <= limit['maxa']:
148 | filtered_q.append(qseq[i])
149 | filtered_a.append(aseq[i])
150 |
151 | # print the fraction of the original data, filtered
152 | filt_data_len = len(filtered_q)
153 | filtered = int((raw_data_len - filt_data_len)*100/raw_data_len)
154 | print(str(filtered) + '% filtered from original data')
155 |
156 | return filtered_q, filtered_a
157 |
158 |
159 | '''
160 | read list of words, create index to word,
161 | word to index dictionaries
162 | return tuple( idx2w, w2idx, freq_dist )
163 |
164 | '''
165 | def index_(tokenized_sentences, vocab_size):
166 | # get frequency distribution
167 | freq_dist = nltk.FreqDist(itertools.chain(*tokenized_sentences))
168 | # get vocabulary of 'vocab_size' most used words
169 | vocab = freq_dist.most_common(vocab_size)
170 | # index2word
171 | index2word = ['_'] + [UNK] + [ x[0] for x in vocab ]
172 | # word2index
173 | word2index = dict([(w,i) for i,w in enumerate(index2word)] )
174 | return index2word, word2index, freq_dist
175 |
176 | '''
177 | filter based on number of unknowns (words not in vocabulary)
178 | filter out the worst sentences
179 |
180 | '''
181 | def filter_unk(qtokenized, atokenized, w2idx):
182 | data_len = len(qtokenized)
183 |
184 | filtered_q, filtered_a = [], []
185 |
186 | for qline, aline in zip(qtokenized, atokenized):
187 | unk_count_q = len([ w for w in qline if w not in w2idx ])
188 | unk_count_a = len([ w for w in aline if w not in w2idx ])
189 | if unk_count_a <= 2:
190 | if unk_count_q > 0:
191 | if unk_count_q/len(qline) > 0.2:
192 | continue  # skip pairs whose question is mostly unknown words
193 | filtered_q.append(qline)
194 | filtered_a.append(aline)
195 |
196 | # print the fraction of the original data, filtered
197 | filt_data_len = len(filtered_q)
198 | filtered = int((data_len - filt_data_len)*100/data_len)
199 | print(str(filtered) + '% filtered from original data')
200 |
201 | return filtered_q, filtered_a
202 |
203 |
204 |
205 |
206 | '''
207 | create the final dataset :
208 | - convert list of items to arrays of indices
209 | - add zero padding
210 | return ( array_q([indices]), array_a([indices]) )
211 |
212 | '''
213 | def zero_pad(qtokenized, atokenized, w2idx):
214 | # num of rows
215 | data_len = len(qtokenized)
216 |
217 | # numpy arrays to store indices
218 | idx_q = np.zeros([data_len, limit['maxq']], dtype=np.int32)
219 | idx_a = np.zeros([data_len, limit['maxa']], dtype=np.int32)
220 |
221 | for i in range(data_len):
222 | q_indices = pad_seq(qtokenized[i], w2idx, limit['maxq'])
223 | a_indices = pad_seq(atokenized[i], w2idx, limit['maxa'])
224 |
225 | #print(len(idx_q[i]), len(q_indices))
226 | #print(len(idx_a[i]), len(a_indices))
227 | idx_q[i] = np.array(q_indices)
228 | idx_a[i] = np.array(a_indices)
229 |
230 | return idx_q, idx_a
231 |
232 |
233 | '''
234 | replace words with indices in a sequence
235 | replace with unknown if word not in lookup
236 | return [list of indices]
237 |
238 | '''
239 | def pad_seq(seq, lookup, maxlen):
240 | indices = []
241 | for word in seq:
242 | if word in lookup:
243 | indices.append(lookup[word])
244 | else:
245 | indices.append(lookup[UNK])
246 | return indices + [0]*(maxlen - len(seq))
247 |
248 |
249 |
250 |
251 |
252 | def process_data():
253 |
254 | id2line = get_id2line()
255 | print('>> gathered id2line dictionary.\n')
256 | convs = get_conversations()
257 | print(convs[121:125])
258 | print('>> gathered conversations.\n')
259 | questions, answers = gather_dataset(convs,id2line)
260 |
261 | # change to lower case (just for en)
262 | questions = [ line.lower() for line in questions ]
263 | answers = [ line.lower() for line in answers ]
264 |
265 | # filter out unnecessary characters
266 | print('\n>> Filter lines')
267 | questions = [ filter_line(line, EN_WHITELIST) for line in questions ]
268 | answers = [ filter_line(line, EN_WHITELIST) for line in answers ]
269 |
270 | # filter out too long or too short sequences
271 | print('\n>> 2nd layer of filtering')
272 | qlines, alines = filter_data(questions, answers)
273 |
274 | for q,a in zip(qlines[141:145], alines[141:145]):
275 | print('q : [{0}]; a : [{1}]'.format(q,a))
276 |
277 | # convert list of [lines of text] into list of [list of words ]
278 | print('\n>> Segment lines into words')
279 | qtokenized = [ [w.strip() for w in wordlist.split(' ') if w] for wordlist in qlines ]
280 | atokenized = [ [w.strip() for w in wordlist.split(' ') if w] for wordlist in alines ]
281 | print('\n:: Sample from segmented list of words')
282 |
283 | for q,a in zip(qtokenized[141:145], atokenized[141:145]):
284 | print('q : [{0}]; a : [{1}]'.format(q,a))
285 |
286 | # indexing -> idx2w, w2idx
287 | print('\n >> Index words')
288 | idx2w, w2idx, freq_dist = index_( qtokenized + atokenized, vocab_size=VOCAB_SIZE)
289 |
290 | # filter out sentences with too many unknowns
291 | print('\n >> Filter Unknowns')
292 | qtokenized, atokenized = filter_unk(qtokenized, atokenized, w2idx)
293 | print('\n Final dataset len : ' + str(len(qtokenized)))
294 |
295 |
296 | print('\n >> Zero Padding')
297 | idx_q, idx_a = zero_pad(qtokenized, atokenized, w2idx)
298 |
299 | print('\n >> Save numpy arrays to disk')
300 | # save them
301 | np.save('idx_q.npy', idx_q)
302 | np.save('idx_a.npy', idx_a)
303 |
304 | # let us now save the necessary dictionaries
305 | metadata = {
306 | 'w2idx' : w2idx,
307 | 'idx2w' : idx2w,
308 | 'limit' : limit,
309 | 'freq_dist' : freq_dist
310 | }
311 |
312 | # write to disk : data control dictionaries
313 | with open('metadata.pkl', 'wb') as f:
314 | pickle.dump(metadata, f)
315 |
316 | # count of unknowns
317 | unk_count = (idx_q == 1).sum() + (idx_a == 1).sum()
318 | # count of words
319 | word_count = (idx_q > 1).sum() + (idx_a > 1).sum()
320 |
321 | print('% unknown : {0}'.format(100 * (unk_count/word_count)))
322 | print('Dataset count : ' + str(idx_q.shape[0]))
323 |
324 |
325 | #print '>> gathered questions and answers.\n'
326 | #prepare_seq2seq_files(questions,answers)
327 |
328 |
329 | import numpy as np
330 | from random import sample
331 |
332 | '''
333 | split data into train (70%), test (15%) and valid(15%)
334 | return tuple( (trainX, trainY), (testX,testY), (validX,validY) )
335 |
336 | '''
337 | def split_dataset(x, y, ratio = [0.7, 0.15, 0.15] ):
338 | # number of examples
339 | data_len = len(x)
340 | lens = [ int(data_len*item) for item in ratio ]
341 |
342 | trainX, trainY = x[:lens[0]], y[:lens[0]]
343 | testX, testY = x[lens[0]:lens[0]+lens[1]], y[lens[0]:lens[0]+lens[1]]
344 | validX, validY = x[-lens[-1]:], y[-lens[-1]:]
345 |
346 | return (trainX,trainY), (testX,testY), (validX,validY)
347 |
348 |
349 | '''
350 | generate batches from dataset
351 | yield (x_gen, y_gen)
352 |
354 |
355 | '''
356 | def batch_gen(x, y, batch_size):
357 | # infinite while
358 | while True:
359 | for i in range(0, len(x), batch_size):
360 | if i + batch_size <= len(x):
361 | yield x[i : i + batch_size].T, y[i : i + batch_size].T
362 |
363 | '''
364 | generate batches, by random sampling a bunch of items
365 | yield (x_gen, y_gen)
366 |
367 | '''
368 | def rand_batch_gen(x, y, batch_size):
369 | while True:
370 | sample_idx = sample(list(np.arange(len(x))), batch_size)
371 | yield x[sample_idx].T, y[sample_idx].T
372 |
373 | #'''
374 | # convert indices of alphabets into a string (word)
375 | # return str(word)
376 | #
377 | #'''
378 | #def decode_word(alpha_seq, idx2alpha):
379 | # return ''.join([ idx2alpha[alpha] for alpha in alpha_seq if alpha ])
380 | #
381 | #
382 | #'''
383 | # convert indices of phonemes into list of phonemes (as string)
384 | # return str(phoneme_list)
385 | #
386 | #'''
387 | #def decode_phonemes(pho_seq, idx2pho):
388 | # return ' '.join( [ idx2pho[pho] for pho in pho_seq if pho ])
389 |
390 |
391 | '''
392 | a generic decode function
393 | inputs : sequence, lookup
394 |
395 | '''
396 | def decode(sequence, lookup, separator=''): # 0 used for padding, is ignored
397 | return separator.join([ lookup[element] for element in sequence if element ])
398 |
399 |
400 |
401 | if __name__ == '__main__':
402 | process_data()
403 |
404 |
405 | def load_data(PATH=''):
406 | # read data control dictionaries
407 | with open(PATH + 'metadata.pkl', 'rb') as f:
408 | metadata = pickle.load(f)
409 | # read numpy arrays
410 | idx_q = np.load(PATH + 'idx_q.npy')
411 | idx_a = np.load(PATH + 'idx_a.npy')
412 | return metadata, idx_q, idx_a
413 |
--------------------------------------------------------------------------------
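The module above is a standalone preprocessing script for the Cornell Movie-Dialogs corpus: process_data() reads raw_data/movie_lines.txt and raw_data/movie_conversations.txt relative to the working directory and writes idx_q.npy, idx_a.npy and metadata.pkl back to it. A minimal driver sketch, assuming the raw corpus files have been placed in data/cornell_corpus/raw_data/ and the snippet is run from data/cornell_corpus/:

```
# Run from data/cornell_corpus/ so the relative paths inside data.py resolve.
import data as cornell   # this module (data/cornell_corpus/data.py)

cornell.process_data()   # writes idx_q.npy, idx_a.npy, metadata.pkl to the cwd

metadata, idx_q, idx_a = cornell.load_data(PATH='')
(trainX, trainY), (testX, testY), (validX, validY) = cornell.split_dataset(idx_q, idx_a)
print(idx_q.shape, idx_a.shape, len(metadata['idx2w']))
```

main.py loads its arrays from data/<corpus>/, so the generated files can stay where they are and data_corpus switched to "cornell_corpus"; the `from data.twitter import data` import in main.py would also need to point at this module.
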
/data/twitter/data.py:
--------------------------------------------------------------------------------
1 | EN_WHITELIST = '0123456789abcdefghijklmnopqrstuvwxyz ' # space is included in whitelist
2 | EN_BLACKLIST = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\''
3 |
4 | FILENAME = 'data/chat.txt'
5 |
6 | limit = {
7 | 'maxq' : 20,
8 | 'minq' : 0,
9 | 'maxa' : 20,
10 | 'mina' : 3
11 | }
12 |
13 | UNK = 'unk'
14 | VOCAB_SIZE = 6000
15 |
16 | import random
17 | import sys
18 |
19 | import nltk
20 | import itertools
21 | from collections import defaultdict
22 |
23 | import numpy as np
24 |
25 | import pickle
26 |
27 |
28 | def ddefault():
29 | return 1
30 |
31 | '''
32 | read lines from file
33 | return [list of lines]
34 |
35 | '''
36 | def read_lines(filename):
37 | return open(filename).read().split('\n')[:-1]
38 |
39 |
40 | '''
41 | split sentences in one line
42 | into multiple lines
43 | return [list of lines]
44 |
45 | '''
46 | def split_line(line):
47 | return line.split('.')
48 |
49 |
50 | '''
51 | remove anything that isn't in the vocabulary
52 | return str(pure ta/en)
53 |
54 | '''
55 | def filter_line(line, whitelist):
56 | return ''.join([ ch for ch in line if ch in whitelist ])
57 |
58 |
59 | '''
60 | read list of words, create index to word,
61 | word to index dictionaries
62 | return tuple( idx2w, w2idx, freq_dist )
63 |
64 | '''
65 | def index_(tokenized_sentences, vocab_size):
66 | # get frequency distribution
67 | freq_dist = nltk.FreqDist(itertools.chain(*tokenized_sentences))
68 | # get vocabulary of 'vocab_size' most used words
69 | vocab = freq_dist.most_common(vocab_size)
70 | # index2word
71 | index2word = ['_'] + [UNK] + [ x[0] for x in vocab ]
72 | # word2index
73 | word2index = dict([(w,i) for i,w in enumerate(index2word)] )
74 | return index2word, word2index, freq_dist
75 |
76 |
77 | '''
78 | filter too long and too short sequences
79 | return tuple( filtered_q, filtered_a )
80 |
81 | '''
82 | def filter_data(sequences):
83 | filtered_q, filtered_a = [], []
84 | raw_data_len = len(sequences)//2
85 |
86 | for i in range(0, len(sequences) - 1, 2):  # step over (question, answer) pairs; ignore a trailing unpaired line
87 | qlen, alen = len(sequences[i].split(' ')), len(sequences[i+1].split(' '))
88 | if qlen >= limit['minq'] and qlen <= limit['maxq']:
89 | if alen >= limit['mina'] and alen <= limit['maxa']:
90 | filtered_q.append(sequences[i])
91 | filtered_a.append(sequences[i+1])
92 |
93 | # print the fraction of the original data, filtered
94 | filt_data_len = len(filtered_q)
95 | filtered = int((raw_data_len - filt_data_len)*100/raw_data_len)
96 | print(str(filtered) + '% filtered from original data')
97 |
98 | return filtered_q, filtered_a
99 |
100 |
101 |
102 |
103 |
104 | '''
105 | create the final dataset :
106 | - convert list of items to arrays of indices
107 | - add zero padding
108 | return ( array_q([indices]), array_a([indices]) )
109 |
110 | '''
111 | def zero_pad(qtokenized, atokenized, w2idx):
112 | # num of rows
113 | data_len = len(qtokenized)
114 |
115 | # numpy arrays to store indices
116 | idx_q = np.zeros([data_len, limit['maxq']], dtype=np.int32)
117 | idx_a = np.zeros([data_len, limit['maxa']], dtype=np.int32)
118 |
119 | for i in range(data_len):
120 | q_indices = pad_seq(qtokenized[i], w2idx, limit['maxq'])
121 | a_indices = pad_seq(atokenized[i], w2idx, limit['maxa'])
122 |
123 | #print(len(idx_q[i]), len(q_indices))
124 | #print(len(idx_a[i]), len(a_indices))
125 | idx_q[i] = np.array(q_indices)
126 | idx_a[i] = np.array(a_indices)
127 |
128 | return idx_q, idx_a
129 |
130 |
131 | '''
132 | replace words with indices in a sequence
133 | replace with unknown if word not in lookup
134 | return [list of indices]
135 |
136 | '''
137 | def pad_seq(seq, lookup, maxlen):
138 | indices = []
139 | for word in seq:
140 | if word in lookup:
141 | indices.append(lookup[word])
142 | else:
143 | indices.append(lookup[UNK])
144 | return indices + [0]*(maxlen - len(seq))
145 |
146 |
147 | def process_data():
148 |
149 | print('\n>> Read lines from file')
150 | lines = read_lines(filename=FILENAME)
151 |
152 | # change to lower case (just for en)
153 | lines = [ line.lower() for line in lines ]
154 |
155 | print('\n:: Sample from read(p) lines')
156 | print(lines[121:125])
157 |
158 | # filter out unnecessary characters
159 | print('\n>> Filter lines')
160 | lines = [ filter_line(line, EN_WHITELIST) for line in lines ]
161 | print(lines[121:125])
162 |
163 | # filter out too long or too short sequences
164 | print('\n>> 2nd layer of filtering')
165 | qlines, alines = filter_data(lines)
166 | print('\nq : {0} ; a : {1}'.format(qlines[60], alines[60]))
167 | print('\nq : {0} ; a : {1}'.format(qlines[61], alines[61]))
168 |
169 |
170 | # convert list of [lines of text] into list of [list of words ]
171 | print('\n>> Segment lines into words')
172 | qtokenized = [ wordlist.split(' ') for wordlist in qlines ]
173 | atokenized = [ wordlist.split(' ') for wordlist in alines ]
174 | print('\n:: Sample from segmented list of words')
175 | print('\nq : {0} ; a : {1}'.format(qtokenized[60], atokenized[60]))
176 | print('\nq : {0} ; a : {1}'.format(qtokenized[61], atokenized[61]))
177 |
178 |
179 | # indexing -> idx2w, w2idx : en/ta
180 | print('\n >> Index words')
181 | idx2w, w2idx, freq_dist = index_( qtokenized + atokenized, vocab_size=VOCAB_SIZE)
182 |
183 | print('\n >> Zero Padding')
184 | idx_q, idx_a = zero_pad(qtokenized, atokenized, w2idx)
185 |
186 | print('\n >> Save numpy arrays to disk')
187 | # save them
188 | np.save('idx_q.npy', idx_q)
189 | np.save('idx_a.npy', idx_a)
190 |
191 | # let us now save the necessary dictionaries
192 | metadata = {
193 | 'w2idx' : w2idx,
194 | 'idx2w' : idx2w,
195 | 'limit' : limit,
196 | 'freq_dist' : freq_dist
197 | }
198 |
199 | # write to disk : data control dictionaries
200 | with open('metadata.pkl', 'wb') as f:
201 | pickle.dump(metadata, f)
202 |
203 | def load_data(PATH=''):
204 | # read data control dictionaries
205 | try:
206 | with open(PATH + 'metadata.pkl', 'rb') as f:
207 | metadata = pickle.load(f)
208 | except IOError:
209 | metadata = None
210 | # read numpy arrays
211 | idx_q = np.load(PATH + 'idx_q.npy')
212 | idx_a = np.load(PATH + 'idx_a.npy')
213 | return metadata, idx_q, idx_a
214 |
215 | import numpy as np
216 | from random import sample
217 |
218 | '''
219 | split data into train (70%), test (15%) and valid(15%)
220 | return tuple( (trainX, trainY), (testX,testY), (validX,validY) )
221 |
222 | '''
223 | def split_dataset(x, y, ratio = [0.7, 0.15, 0.15] ):
224 | # number of examples
225 | data_len = len(x)
226 | lens = [ int(data_len*item) for item in ratio ]
227 |
228 | trainX, trainY = x[:lens[0]], y[:lens[0]]
229 | testX, testY = x[lens[0]:lens[0]+lens[1]], y[lens[0]:lens[0]+lens[1]]
230 | validX, validY = x[-lens[-1]:], y[-lens[-1]:]
231 |
232 | return (trainX,trainY), (testX,testY), (validX,validY)
233 |
234 |
235 | '''
236 | generate batches from dataset
237 | yield (x_gen, y_gen)
238 |
240 |
241 | '''
242 | def batch_gen(x, y, batch_size):
243 | # infinite while
244 | while True:
245 | for i in range(0, len(x), batch_size):
246 | if i + batch_size <= len(x):
247 | yield x[i : i + batch_size].T, y[i : i + batch_size].T
248 |
249 | '''
250 | generate batches, by random sampling a bunch of items
251 | yield (x_gen, y_gen)
252 |
253 | '''
254 | def rand_batch_gen(x, y, batch_size):
255 | while True:
256 | sample_idx = sample(list(np.arange(len(x))), batch_size)
257 | yield x[sample_idx].T, y[sample_idx].T
258 |
259 | #'''
260 | # convert indices of alphabets into a string (word)
261 | # return str(word)
262 | #
263 | #'''
264 | #def decode_word(alpha_seq, idx2alpha):
265 | # return ''.join([ idx2alpha[alpha] for alpha in alpha_seq if alpha ])
266 | #
267 | #
268 | #'''
269 | # convert indices of phonemes into list of phonemes (as string)
270 | # return str(phoneme_list)
271 | #
272 | #'''
273 | #def decode_phonemes(pho_seq, idx2pho):
274 | # return ' '.join( [ idx2pho[pho] for pho in pho_seq if pho ])
275 |
276 |
277 | '''
278 | a generic decode function
279 | inputs : sequence, lookup
280 |
281 | '''
282 | def decode(sequence, lookup, separator=''): # 0 used for padding, is ignored
283 | return separator.join([ lookup[element] for element in sequence if element ])
284 |
285 |
286 |
287 | if __name__ == '__main__':
288 | process_data()
289 |
--------------------------------------------------------------------------------
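The repository already ships the preprocessed Twitter arrays (idx_q.npy, idx_a.npy, metadata.pkl), so process_data() only needs to be re-run to rebuild them from raw text. A short sketch of consuming the bundled files with the helpers defined above, run from the repository root; the printed shapes are an assumption based on limit['maxq'] = limit['maxa'] = 20:

```
from data.twitter import data

# Load the bundled arrays and split them 70/15/15, as main.py does.
metadata, idx_q, idx_a = data.load_data(PATH='data/twitter/')
(trainX, trainY), (testX, testY), (validX, validY) = data.split_dataset(idx_q, idx_a)

# Draw one random batch; rand_batch_gen yields time-major (seq_len, batch) arrays.
batches = data.rand_batch_gen(trainX, trainY, batch_size=32)
bx, by = next(batches)
print(bx.shape, by.shape)   # expected: (20, 32) and (20, 32)

# Turn the first question of the batch back into words (index 0 is padding, index 1 is 'unk').
print(data.decode(bx.T[0], metadata['idx2w'], separator=' '))
```
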
/data/twitter/idx_a.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorlayer/seq2seq-chatbot/3757307595b15e45a8870ffbe7728d72ddca1f96/data/twitter/idx_a.npy
--------------------------------------------------------------------------------
/data/twitter/idx_q.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorlayer/seq2seq-chatbot/3757307595b15e45a8870ffbe7728d72ddca1f96/data/twitter/idx_q.npy
--------------------------------------------------------------------------------
/data/twitter/metadata.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorlayer/seq2seq-chatbot/3757307595b15e45a8870ffbe7728d72ddca1f96/data/twitter/metadata.pkl
--------------------------------------------------------------------------------
/data/twitter/pull:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | wget -c 'https://www.dropbox.com/s/tmfwptbs3q180p0/seq2seq.twitter.tar.gz?dl=0' -O seq2seq.twitter.tar.gz
4 |
--------------------------------------------------------------------------------
/data/twitter/pull_raw_data:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -c https://raw.githubusercontent.com/Marsan-Ma/chat_corpus/master/twitter_en.txt.gz
3 |
--------------------------------------------------------------------------------
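pull_raw_data only downloads the gzipped corpus; data/twitter/data.py expects plain text at FILENAME = 'data/chat.txt', relative to wherever process_data() is run. A hypothetical helper, written in Python to stay consistent with the rest of the repository, that decompresses the download into that location (both paths are assumptions and may need adjusting):

```
import gzip
import shutil

# Decompress the corpus fetched by pull_raw_data into the path that
# data/twitter/data.py reads (FILENAME = 'data/chat.txt').
with gzip.open('twitter_en.txt.gz', 'rb') as src, open('data/chat.txt', 'wb') as dst:
    shutil.copyfileobj(src, dst)
```
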
/main.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | import tensorflow as tf
5 | import tensorlayer as tl
6 | import numpy as np
7 | from tensorlayer.cost import cross_entropy_seq, cross_entropy_seq_with_mask
8 | from tqdm import tqdm
9 | from sklearn.utils import shuffle
10 | from data.twitter import data
11 | from tensorlayer.models.seq2seq import Seq2seq
12 | from tensorlayer.models.seq2seq_with_attention import Seq2seqLuongAttention
13 | import os
14 |
15 |
16 | def initial_setup(data_corpus):
17 | metadata, idx_q, idx_a = data.load_data(PATH='data/{}/'.format(data_corpus))
18 | (trainX, trainY), (testX, testY), (validX, validY) = data.split_dataset(idx_q, idx_a)
19 | trainX = tl.prepro.remove_pad_sequences(trainX.tolist())
20 | trainY = tl.prepro.remove_pad_sequences(trainY.tolist())
21 | testX = tl.prepro.remove_pad_sequences(testX.tolist())
22 | testY = tl.prepro.remove_pad_sequences(testY.tolist())
23 | validX = tl.prepro.remove_pad_sequences(validX.tolist())
24 | validY = tl.prepro.remove_pad_sequences(validY.tolist())
25 | return metadata, trainX, trainY, testX, testY, validX, validY
26 |
27 |
28 |
29 | if __name__ == "__main__":
30 | data_corpus = "twitter"
31 |
32 | #data preprocessing
33 | metadata, trainX, trainY, testX, testY, validX, validY = initial_setup(data_corpus)
34 |
35 | # Parameters
36 | src_len = len(trainX)
37 | tgt_len = len(trainY)
38 |
39 | assert src_len == tgt_len
40 |
41 | batch_size = 32
42 | n_step = src_len // batch_size
43 | src_vocab_size = len(metadata['idx2w']) # 8002 (0~8001)
44 | emb_dim = 1024
45 |
46 | word2idx = metadata['w2idx'] # dict word 2 index
47 | idx2word = metadata['idx2w'] # list index 2 word
48 |
49 | unk_id = word2idx['unk'] # 1
50 | pad_id = word2idx['_'] # 0
51 |
52 | start_id = src_vocab_size # 8002
53 | end_id = src_vocab_size + 1 # 8003
54 |
55 | word2idx.update({'start_id': start_id})
56 | word2idx.update({'end_id': end_id})
57 | idx2word = idx2word + ['start_id', 'end_id']
58 |
59 | src_vocab_size = tgt_vocab_size = src_vocab_size + 2
60 |
61 | num_epochs = 50
62 | vocabulary_size = src_vocab_size
63 |
64 |
65 |
66 | def inference(seed, top_n):
67 | model_.eval()
68 | seed_id = [word2idx.get(w, unk_id) for w in seed.split(" ")]
69 | sentence_id = model_(inputs=[[seed_id]], seq_length=20, start_token=start_id, top_n = top_n)
70 | sentence = []
71 | for w_id in sentence_id[0]:
72 | w = idx2word[w_id]
73 | if w == 'end_id':
74 | break
75 | sentence = sentence + [w]
76 | return sentence
77 |
78 | decoder_seq_length = 20
79 | model_ = Seq2seq(
80 | decoder_seq_length = decoder_seq_length,
81 | cell_enc=tf.keras.layers.GRUCell,
82 | cell_dec=tf.keras.layers.GRUCell,
83 | n_layer=3,
84 | n_units=256,
85 | embedding_layer=tl.layers.Embedding(vocabulary_size=vocabulary_size, embedding_size=emb_dim),
86 | )
87 |
88 |
89 | # Uncomment below statements if you have already saved the model
90 |
91 | # load_weights = tl.files.load_npz(name='model.npz')
92 | # tl.files.assign_weights(load_weights, model_)
93 |
94 | optimizer = tf.optimizers.Adam(learning_rate=0.001)
95 | model_.train()
96 |
97 | seeds = ["happy birthday have a nice day",
98 | "donald trump won last nights presidential debate according to snap online polls"]
99 | for epoch in range(num_epochs):
100 | model_.train()
101 | trainX, trainY = shuffle(trainX, trainY, random_state=0)
102 | total_loss, n_iter = 0, 0
103 | for X, Y in tqdm(tl.iterate.minibatches(inputs=trainX, targets=trainY, batch_size=batch_size, shuffle=False),
104 | total=n_step, desc='Epoch[{}/{}]'.format(epoch + 1, num_epochs), leave=False):
105 |
106 | X = tl.prepro.pad_sequences(X)
107 | _target_seqs = tl.prepro.sequences_add_end_id(Y, end_id=end_id)
108 | _target_seqs = tl.prepro.pad_sequences(_target_seqs, maxlen=decoder_seq_length)
109 | _decode_seqs = tl.prepro.sequences_add_start_id(Y, start_id=start_id, remove_last=False)
110 | _decode_seqs = tl.prepro.pad_sequences(_decode_seqs, maxlen=decoder_seq_length)
111 | _target_mask = tl.prepro.sequences_get_mask(_target_seqs)
112 |
113 | with tf.GradientTape() as tape:
114 | ## compute outputs
115 | output = model_(inputs = [X, _decode_seqs])
116 |
117 | output = tf.reshape(output, [-1, vocabulary_size])
118 | ## compute loss and update model
119 | loss = cross_entropy_seq_with_mask(logits=output, target_seqs=_target_seqs, input_mask=_target_mask)
120 |
121 | grad = tape.gradient(loss, model_.all_weights)
122 | optimizer.apply_gradients(zip(grad, model_.all_weights))
123 |
124 | total_loss += loss
125 | n_iter += 1
126 |
127 | # printing average loss after every epoch
128 | print('Epoch [{}/{}]: loss {:.4f}'.format(epoch + 1, num_epochs, total_loss / n_iter))
129 |
130 | for seed in seeds:
131 | print("Query >", seed)
132 | top_n = 3
133 | for i in range(top_n):
134 | sentence = inference(seed, top_n)
135 | print(" >", ' '.join(sentence))
136 |
137 | tl.files.save_npz(model_.all_weights, name='model.npz')
138 |
139 |
140 |
141 |
142 |
143 |
--------------------------------------------------------------------------------
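main.py only prints replies to its two hard-coded seed queries after each epoch. A hypothetical extension, not part of the script, that could be appended after the training loop (indented to sit inside the same `if __name__ == "__main__":` block) to query the bot interactively via the inference() helper defined above:

```
# Hypothetical interactive prompt; relies on inference(), word2idx and the
# trained model_ already being in scope inside main.py's __main__ block.
while True:
    try:
        query = input('Query > ').strip().lower()
    except EOFError:
        break
    if not query:
        continue
    top_n = 3
    for _ in range(top_n):
        print(' >', ' '.join(inference(query, top_n)))
```
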
/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit-learn
2 | tensorflow
3 | tensorlayer
4 | numpy
5 | click
6 | tqdm
7 | nltk
8 |
--------------------------------------------------------------------------------