├── .gitignore ├── README.md ├── data ├── __init__.py ├── cornell_corpus │   └── data.py └── twitter │   ├── data.py │   ├── idx_a.npy │   ├── idx_q.npy │   ├── metadata.pkl │   ├── pull │   └── pull_raw_data ├── main.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | dea 2 | docs/_build 3 | tensorlayer 4 | tensorlayer/__pycache__ 5 | tensorlayer/.DS_Store 6 | .DS_Store 7 | dist 8 | build/ 9 | tensorlayer.egg-info 10 | data/.DS_Store 11 | *.pyc 12 | *.gz 13 | .spyproject/ 14 | .vscode/* 15 | model.npz 16 | env/ 17 | venv/ 18 | .idea/ 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Seq2Seq Chatbot 2 | 3 | This is a 200-line implementation of a Twitter/Cornell-Movie chatbot. Please read the following references before reading the code: 4 | 5 | - [Practical-Seq2Seq](http://suriyadeepan.github.io/2016-12-31-practical-seq2seq/) 6 | - [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) 7 | - [Understanding LSTM Networks](http://colah.github.io/posts/2015-08-Understanding-LSTMs/) (optional) 8 | 9 | ### Prerequisites 10 | 11 | - Python 3.6 12 | - [TensorFlow](https://github.com/tensorflow/tensorflow) >= 2.0 13 | - [TensorLayer](https://github.com/zsdonghao/tensorlayer) >= 2.0 14 | 15 | ### Model 16 | 17 | 18 |
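The model is a stacked-GRU encoder-decoder (sequence-to-sequence) network with a word embedding layer, built with TensorLayer's `Seq2seq` class. A condensed sketch of how `main.py` constructs it (the vocabulary size shown is the value computed for the bundled Twitter corpus: 8002 tokens plus the added `start_id`/`end_id`):

```
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.models.seq2seq import Seq2seq

vocabulary_size = 8004   # 8002 corpus tokens + start_id + end_id (see main.py)
emb_dim = 1024           # embedding size used in main.py

model_ = Seq2seq(
    decoder_seq_length=20,
    cell_enc=tf.keras.layers.GRUCell,   # encoder: 3 stacked GRU layers, 256 units each
    cell_dec=tf.keras.layers.GRUCell,   # decoder: 3 stacked GRU layers, 256 units each
    n_layer=3,
    n_units=256,
    embedding_layer=tl.layers.Embedding(vocabulary_size=vocabulary_size, embedding_size=emb_dim),
)
```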
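Before training, note that the preprocessed Twitter arrays (`idx_q.npy`, `idx_a.npy`, `metadata.pkl`) are already included under `data/twitter/`. To train on the Cornell movie corpus instead, those arrays have to be generated first with the preprocessing script, and `data_corpus` in `main.py` changed to `"cornell_corpus"`; a sketch, assuming the raw corpus files have been placed in `data/cornell_corpus/raw_data/`:

```
cd data/cornell_corpus
python3 data.py      # writes idx_q.npy, idx_a.npy and metadata.pkl next to the script
cd ../..
python3 main.py      # with data_corpus = "cornell_corpus" in main.py
```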
24 | 25 | ### Training 26 | 27 | ``` 28 | python3 main.py 29 | ``` 30 | 31 | 32 | ### Results 33 | 34 | ``` 35 | Query > happy birthday have a nice day 36 | > thank you so much 37 | > thank babe 38 | > thank bro 39 | > thanks so much 40 | > thank babe i appreciate it 41 | Query > donald trump won last nights presidential debate according to snap online polls 42 | > i dont know what the fuck is that 43 | > i think he was a racist 44 | > he is not a racist 45 | > he is a liar 46 | > trump needs to be president 47 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | # from . import twitter 5 | # from . import imagenet_classes 6 | # from . import 7 | -------------------------------------------------------------------------------- /data/cornell_corpus/data.py: -------------------------------------------------------------------------------- 1 | EN_WHITELIST = '0123456789abcdefghijklmnopqrstuvwxyz ' # space is included in whitelist 2 | EN_BLACKLIST = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\'' 3 | 4 | limit = { 5 | 'maxq' : 25, 6 | 'minq' : 2, 7 | 'maxa' : 25, 8 | 'mina' : 2 9 | } 10 | 11 | UNK = 'unk' 12 | VOCAB_SIZE = 8000 13 | 14 | 15 | import random 16 | 17 | import nltk 18 | import itertools 19 | from collections import defaultdict 20 | 21 | import numpy as np 22 | 23 | import pickle 24 | 25 | 26 | 27 | ''' 28 | 1. Read from 'movie-lines.txt' 29 | 2. Create a dictionary with ( key = line_id, value = text ) 30 | ''' 31 | def get_id2line(): 32 | lines=open('raw_data/movie_lines.txt', encoding='utf-8', errors='ignore').read().split('\n') 33 | id2line = {} 34 | for line in lines: 35 | _line = line.split(' +++$+++ ') 36 | if len(_line) == 5: 37 | id2line[_line[0]] = _line[4] 38 | return id2line 39 | 40 | ''' 41 | 1. Read from 'movie_conversations.txt' 42 | 2. Create a list of [list of line_id's] 43 | ''' 44 | def get_conversations(): 45 | conv_lines = open('raw_data/movie_conversations.txt', encoding='utf-8', errors='ignore').read().split('\n') 46 | convs = [ ] 47 | for line in conv_lines[:-1]: 48 | _line = line.split(' +++$+++ ')[-1][1:-1].replace("'","").replace(" ","") 49 | convs.append(_line.split(',')) 50 | return convs 51 | 52 | ''' 53 | 1. Get each conversation 54 | 2. Get each line from conversation 55 | 3. Save each conversation to file 56 | ''' 57 | def extract_conversations(convs,id2line,path=''): 58 | idx = 0 59 | for conv in convs: 60 | f_conv = open(path + str(idx)+'.txt', 'w') 61 | for line_id in conv: 62 | f_conv.write(id2line[line_id]) 63 | f_conv.write('\n') 64 | f_conv.close() 65 | idx += 1 66 | 67 | ''' 68 | Get lists of all conversations as Questions and Answers 69 | 1. [questions] 70 | 2. [answers] 71 | ''' 72 | def gather_dataset(convs, id2line): 73 | questions = []; answers = [] 74 | 75 | for conv in convs: 76 | if len(conv) %2 != 0: 77 | conv = conv[:-1] 78 | for i in range(len(conv)): 79 | if i%2 == 0: 80 | questions.append(id2line[conv[i]]) 81 | else: 82 | answers.append(id2line[conv[i]]) 83 | 84 | return questions, answers 85 | 86 | 87 | ''' 88 | We need 4 files 89 | 1. train.enc : Encoder input for training 90 | 2. train.dec : Decoder input for training 91 | 3. test.enc : Encoder input for testing 92 | 4. 
test.dec : Decoder input for testing 93 | ''' 94 | def prepare_seq2seq_files(questions, answers, path='',TESTSET_SIZE = 30000): 95 | 96 | # open files 97 | train_enc = open(path + 'train.enc','w') 98 | train_dec = open(path + 'train.dec','w') 99 | test_enc = open(path + 'test.enc', 'w') 100 | test_dec = open(path + 'test.dec', 'w') 101 | 102 | # choose 30,000 (TESTSET_SIZE) items to put into testset 103 | test_ids = random.sample([i for i in range(len(questions))],TESTSET_SIZE) 104 | 105 | for i in range(len(questions)): 106 | if i in test_ids: 107 | test_enc.write(questions[i]+'\n') 108 | test_dec.write(answers[i]+ '\n' ) 109 | else: 110 | train_enc.write(questions[i]+'\n') 111 | train_dec.write(answers[i]+ '\n' ) 112 | if i%10000 == 0: 113 | print('\n>> written {} lines'.format(i)) 114 | 115 | # close files 116 | train_enc.close() 117 | train_dec.close() 118 | test_enc.close() 119 | test_dec.close() 120 | 121 | 122 | 123 | ''' 124 | remove any character that isn't in the whitelist 125 | return str(pure en) 126 | 127 | ''' 128 | def filter_line(line, whitelist): 129 | return ''.join([ ch for ch in line if ch in whitelist ]) 130 | 131 | 132 | 133 | ''' 134 | filter too long and too short sequences 135 | return tuple( filtered_q, filtered_a ) 136 | 137 | ''' 138 | def filter_data(qseq, aseq): 139 | filtered_q, filtered_a = [], [] 140 | raw_data_len = len(qseq) 141 | 142 | assert len(qseq) == len(aseq) 143 | 144 | for i in range(raw_data_len): 145 | qlen, alen = len(qseq[i].split(' ')), len(aseq[i].split(' ')) 146 | if qlen >= limit['minq'] and qlen <= limit['maxq']: 147 | if alen >= limit['mina'] and alen <= limit['maxa']: 148 | filtered_q.append(qseq[i]) 149 | filtered_a.append(aseq[i]) 150 | 151 | # print the fraction of the original data, filtered 152 | filt_data_len = len(filtered_q) 153 | filtered = int((raw_data_len - filt_data_len)*100/raw_data_len) 154 | print(str(filtered) + '% filtered from original data') 155 | 156 | return filtered_q, filtered_a 157 | 158 | 159 | ''' 160 | read list of words, create index to word, 161 | word to index dictionaries 162 | return tuple( vocab->(word, count), idx2w, w2idx ) 163 | 164 | ''' 165 | def index_(tokenized_sentences, vocab_size): 166 | # get frequency distribution 167 | freq_dist = nltk.FreqDist(itertools.chain(*tokenized_sentences)) 168 | # get vocabulary of 'vocab_size' most used words 169 | vocab = freq_dist.most_common(vocab_size) 170 | # index2word 171 | index2word = ['_'] + [UNK] + [ x[0] for x in vocab ] 172 | # word2index 173 | word2index = dict([(w,i) for i,w in enumerate(index2word)] ) 174 | return index2word, word2index, freq_dist 175 | 176 | ''' 177 | filter based on number of unknowns (words not in vocabulary) 178 | filter out the worst sentences 179 | 180 | ''' 181 | def filter_unk(qtokenized, atokenized, w2idx): 182 | data_len = len(qtokenized) 183 | 184 | filtered_q, filtered_a = [], [] 185 | 186 | for qline, aline in zip(qtokenized, atokenized): 187 | unk_count_q = len([ w for w in qline if w not in w2idx ]) 188 | unk_count_a = len([ w for w in aline if w not in w2idx ]) 189 | if unk_count_a <= 2: 190 | if unk_count_q > 0: 191 | if unk_count_q/len(qline) > 0.2: 192 | continue  # skip pairs whose question is more than 20% unknown words 193 | filtered_q.append(qline) 194 | filtered_a.append(aline) 195 | 196 | # print the fraction of the original data, filtered 197 | filt_data_len = len(filtered_q) 198 | filtered = int((data_len - filt_data_len)*100/data_len) 199 | print(str(filtered) + '% filtered from original data') 200 | 201 | return filtered_q, filtered_a 202 | 203 | 204 | 205 | 206 |
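# Worked example for the padding step below (illustrative toy values): with a
# vocabulary w2idx = {'_': 0, 'unk': 1, 'hi': 2, 'there': 3} and maxq shortened
# to 5, pad_seq(['hi', 'there'], w2idx, 5) returns [2, 3, 0, 0, 0]; zero_pad()
# stacks one such row per Q/A pair into int32 arrays of shape
# (data_len, limit['maxq']) and (data_len, limit['maxa']).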
''' 207 | create the final dataset : 208 | - convert list of items to arrays of indices 209 | - add zero padding 210 | return ( [array_en([indices]), array_ta([indices]) ) 211 | 212 | ''' 213 | def zero_pad(qtokenized, atokenized, w2idx): 214 | # num of rows 215 | data_len = len(qtokenized) 216 | 217 | # numpy arrays to store indices 218 | idx_q = np.zeros([data_len, limit['maxq']], dtype=np.int32) 219 | idx_a = np.zeros([data_len, limit['maxa']], dtype=np.int32) 220 | 221 | for i in range(data_len): 222 | q_indices = pad_seq(qtokenized[i], w2idx, limit['maxq']) 223 | a_indices = pad_seq(atokenized[i], w2idx, limit['maxa']) 224 | 225 | #print(len(idx_q[i]), len(q_indices)) 226 | #print(len(idx_a[i]), len(a_indices)) 227 | idx_q[i] = np.array(q_indices) 228 | idx_a[i] = np.array(a_indices) 229 | 230 | return idx_q, idx_a 231 | 232 | 233 | ''' 234 | replace words with indices in a sequence 235 | replace with unknown if word not in lookup 236 | return [list of indices] 237 | 238 | ''' 239 | def pad_seq(seq, lookup, maxlen): 240 | indices = [] 241 | for word in seq: 242 | if word in lookup: 243 | indices.append(lookup[word]) 244 | else: 245 | indices.append(lookup[UNK]) 246 | return indices + [0]*(maxlen - len(seq)) 247 | 248 | 249 | 250 | 251 | 252 | def process_data(): 253 | 254 | id2line = get_id2line() 255 | print('>> gathered id2line dictionary.\n') 256 | convs = get_conversations() 257 | print(convs[121:125]) 258 | print('>> gathered conversations.\n') 259 | questions, answers = gather_dataset(convs,id2line) 260 | 261 | # change to lower case (just for en) 262 | questions = [ line.lower() for line in questions ] 263 | answers = [ line.lower() for line in answers ] 264 | 265 | # filter out unnecessary characters 266 | print('\n>> Filter lines') 267 | questions = [ filter_line(line, EN_WHITELIST) for line in questions ] 268 | answers = [ filter_line(line, EN_WHITELIST) for line in answers ] 269 | 270 | # filter out too long or too short sequences 271 | print('\n>> 2nd layer of filtering') 272 | qlines, alines = filter_data(questions, answers) 273 | 274 | for q,a in zip(qlines[141:145], alines[141:145]): 275 | print('q : [{0}]; a : [{1}]'.format(q,a)) 276 | 277 | # convert list of [lines of text] into list of [list of words ] 278 | print('\n>> Segment lines into words') 279 | qtokenized = [ [w.strip() for w in wordlist.split(' ') if w] for wordlist in qlines ] 280 | atokenized = [ [w.strip() for w in wordlist.split(' ') if w] for wordlist in alines ] 281 | print('\n:: Sample from segmented list of words') 282 | 283 | for q,a in zip(qtokenized[141:145], atokenized[141:145]): 284 | print('q : [{0}]; a : [{1}]'.format(q,a)) 285 | 286 | # indexing -> idx2w, w2idx 287 | print('\n >> Index words') 288 | idx2w, w2idx, freq_dist = index_( qtokenized + atokenized, vocab_size=VOCAB_SIZE) 289 | 290 | # filter out sentences with too many unknowns 291 | print('\n >> Filter Unknowns') 292 | qtokenized, atokenized = filter_unk(qtokenized, atokenized, w2idx) 293 | print('\n Final dataset len : ' + str(len(qtokenized))) 294 | 295 | 296 | print('\n >> Zero Padding') 297 | idx_q, idx_a = zero_pad(qtokenized, atokenized, w2idx) 298 | 299 | print('\n >> Save numpy arrays to disk') 300 | # save them 301 | np.save('idx_q.npy', idx_q) 302 | np.save('idx_a.npy', idx_a) 303 | 304 | # let us now save the necessary dictionaries 305 | metadata = { 306 | 'w2idx' : w2idx, 307 | 'idx2w' : idx2w, 308 | 'limit' : limit, 309 | 'freq_dist' : freq_dist 310 | } 311 | 312 | # write to disk : data control dictionaries 313 | with 
open('metadata.pkl', 'wb') as f: 314 | pickle.dump(metadata, f) 315 | 316 | # count of unknowns 317 | unk_count = (idx_q == 1).sum() + (idx_a == 1).sum() 318 | # count of words 319 | word_count = (idx_q > 1).sum() + (idx_a > 1).sum() 320 | 321 | print('% unknown : {0}'.format(100 * (unk_count/word_count))) 322 | print('Dataset count : ' + str(idx_q.shape[0])) 323 | 324 | 325 | #print '>> gathered questions and answers.\n' 326 | #prepare_seq2seq_files(questions,answers) 327 | 328 | 329 | import numpy as np 330 | from random import sample 331 | 332 | ''' 333 | split data into train (70%), test (15%) and valid(15%) 334 | return tuple( (trainX, trainY), (testX,testY), (validX,validY) ) 335 | 336 | ''' 337 | def split_dataset(x, y, ratio = [0.7, 0.15, 0.15] ): 338 | # number of examples 339 | data_len = len(x) 340 | lens = [ int(data_len*item) for item in ratio ] 341 | 342 | trainX, trainY = x[:lens[0]], y[:lens[0]] 343 | testX, testY = x[lens[0]:lens[0]+lens[1]], y[lens[0]:lens[0]+lens[1]] 344 | validX, validY = x[-lens[-1]:], y[-lens[-1]:] 345 | 346 | return (trainX,trainY), (testX,testY), (validX,validY) 347 | 348 | 349 | ''' 350 | generate batches from dataset 351 | yield (x_gen, y_gen) 352 | 353 | TODO : fix needed 354 | 355 | ''' 356 | def batch_gen(x, y, batch_size): 357 | # infinite while 358 | while True: 359 | for i in range(0, len(x), batch_size): 360 | if (i+1)*batch_size < len(x): 361 | yield x[i : (i+1)*batch_size ].T, y[i : (i+1)*batch_size ].T 362 | 363 | ''' 364 | generate batches, by random sampling a bunch of items 365 | yield (x_gen, y_gen) 366 | 367 | ''' 368 | def rand_batch_gen(x, y, batch_size): 369 | while True: 370 | sample_idx = sample(list(np.arange(len(x))), batch_size) 371 | yield x[sample_idx].T, y[sample_idx].T 372 | 373 | #''' 374 | # convert indices of alphabets into a string (word) 375 | # return str(word) 376 | # 377 | #''' 378 | #def decode_word(alpha_seq, idx2alpha): 379 | # return ''.join([ idx2alpha[alpha] for alpha in alpha_seq if alpha ]) 380 | # 381 | # 382 | #''' 383 | # convert indices of phonemes into list of phonemes (as string) 384 | # return str(phoneme_list) 385 | # 386 | #''' 387 | #def decode_phonemes(pho_seq, idx2pho): 388 | # return ' '.join( [ idx2pho[pho] for pho in pho_seq if pho ]) 389 | 390 | 391 | ''' 392 | a generic decode function 393 | inputs : sequence, lookup 394 | 395 | ''' 396 | def decode(sequence, lookup, separator=''): # 0 used for padding, is ignored 397 | return separator.join([ lookup[element] for element in sequence if element ]) 398 | 399 | 400 | 401 | if __name__ == '__main__': 402 | process_data() 403 | 404 | 405 | def load_data(PATH=''): 406 | # read data control dictionaries 407 | with open(PATH + 'metadata.pkl', 'rb') as f: 408 | metadata = pickle.load(f) 409 | # read numpy arrays 410 | idx_q = np.load(PATH + 'idx_q.npy') 411 | idx_a = np.load(PATH + 'idx_a.npy') 412 | return metadata, idx_q, idx_a 413 | -------------------------------------------------------------------------------- /data/twitter/data.py: -------------------------------------------------------------------------------- 1 | EN_WHITELIST = '0123456789abcdefghijklmnopqrstuvwxyz ' # space is included in whitelist 2 | EN_BLACKLIST = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\'' 3 | 4 | FILENAME = 'data/chat.txt' 5 | 6 | limit = { 7 | 'maxq' : 20, 8 | 'minq' : 0, 9 | 'maxa' : 20, 10 | 'mina' : 3 11 | } 12 | 13 | UNK = 'unk' 14 | VOCAB_SIZE = 6000 15 | 16 | import random 17 | import sys 18 | 19 | import nltk 20 | import itertools 21 | from collections import 
defaultdict 22 | 23 | import numpy as np 24 | 25 | import pickle 26 | 27 | 28 | def ddefault(): 29 | return 1 30 | 31 | ''' 32 | read lines from file 33 | return [list of lines] 34 | 35 | ''' 36 | def read_lines(filename): 37 | return open(filename).read().split('\n')[:-1] 38 | 39 | 40 | ''' 41 | split sentences in one line 42 | into multiple lines 43 | return [list of lines] 44 | 45 | ''' 46 | def split_line(line): 47 | return line.split('.') 48 | 49 | 50 | ''' 51 | remove anything that isn't in the vocabulary 52 | return str(pure ta/en) 53 | 54 | ''' 55 | def filter_line(line, whitelist): 56 | return ''.join([ ch for ch in line if ch in whitelist ]) 57 | 58 | 59 | ''' 60 | read list of words, create index to word, 61 | word to index dictionaries 62 | return tuple( vocab->(word, count), idx2w, w2idx ) 63 | 64 | ''' 65 | def index_(tokenized_sentences, vocab_size): 66 | # get frequency distribution 67 | freq_dist = nltk.FreqDist(itertools.chain(*tokenized_sentences)) 68 | # get vocabulary of 'vocab_size' most used words 69 | vocab = freq_dist.most_common(vocab_size) 70 | # index2word 71 | index2word = ['_'] + [UNK] + [ x[0] for x in vocab ] 72 | # word2index 73 | word2index = dict([(w,i) for i,w in enumerate(index2word)] ) 74 | return index2word, word2index, freq_dist 75 | 76 | 77 | ''' 78 | filter too long and too short sequences 79 | return tuple( filtered_ta, filtered_en ) 80 | 81 | ''' 82 | def filter_data(sequences): 83 | filtered_q, filtered_a = [], [] 84 | raw_data_len = len(sequences)//2 85 | 86 | for i in range(0, len(sequences), 2): 87 | qlen, alen = len(sequences[i].split(' ')), len(sequences[i+1].split(' ')) 88 | if qlen >= limit['minq'] and qlen <= limit['maxq']: 89 | if alen >= limit['mina'] and alen <= limit['maxa']: 90 | filtered_q.append(sequences[i]) 91 | filtered_a.append(sequences[i+1]) 92 | 93 | # print the fraction of the original data, filtered 94 | filt_data_len = len(filtered_q) 95 | filtered = int((raw_data_len - filt_data_len)*100/raw_data_len) 96 | print(str(filtered) + '% filtered from original data') 97 | 98 | return filtered_q, filtered_a 99 | 100 | 101 | 102 | 103 | 104 | ''' 105 | create the final dataset : 106 | - convert list of items to arrays of indices 107 | - add zero padding 108 | return ( [array_en([indices]), array_ta([indices]) ) 109 | 110 | ''' 111 | def zero_pad(qtokenized, atokenized, w2idx): 112 | # num of rows 113 | data_len = len(qtokenized) 114 | 115 | # numpy arrays to store indices 116 | idx_q = np.zeros([data_len, limit['maxq']], dtype=np.int32) 117 | idx_a = np.zeros([data_len, limit['maxa']], dtype=np.int32) 118 | 119 | for i in range(data_len): 120 | q_indices = pad_seq(qtokenized[i], w2idx, limit['maxq']) 121 | a_indices = pad_seq(atokenized[i], w2idx, limit['maxa']) 122 | 123 | #print(len(idx_q[i]), len(q_indices)) 124 | #print(len(idx_a[i]), len(a_indices)) 125 | idx_q[i] = np.array(q_indices) 126 | idx_a[i] = np.array(a_indices) 127 | 128 | return idx_q, idx_a 129 | 130 | 131 | ''' 132 | replace words with indices in a sequence 133 | replace with unknown if word not in lookup 134 | return [list of indices] 135 | 136 | ''' 137 | def pad_seq(seq, lookup, maxlen): 138 | indices = [] 139 | for word in seq: 140 | if word in lookup: 141 | indices.append(lookup[word]) 142 | else: 143 | indices.append(lookup[UNK]) 144 | return indices + [0]*(maxlen - len(seq)) 145 | 146 | 147 | def process_data(): 148 | 149 | print('\n>> Read lines from file') 150 | lines = read_lines(filename=FILENAME) 151 | 152 | # change to lower case (just for en) 153 
| lines = [ line.lower() for line in lines ] 154 | 155 | print('\n:: Sample from read(p) lines') 156 | print(lines[121:125]) 157 | 158 | # filter out unnecessary characters 159 | print('\n>> Filter lines') 160 | lines = [ filter_line(line, EN_WHITELIST) for line in lines ] 161 | print(lines[121:125]) 162 | 163 | # filter out too long or too short sequences 164 | print('\n>> 2nd layer of filtering') 165 | qlines, alines = filter_data(lines) 166 | print('\nq : {0} ; a : {1}'.format(qlines[60], alines[60])) 167 | print('\nq : {0} ; a : {1}'.format(qlines[61], alines[61])) 168 | 169 | 170 | # convert list of [lines of text] into list of [list of words ] 171 | print('\n>> Segment lines into words') 172 | qtokenized = [ wordlist.split(' ') for wordlist in qlines ] 173 | atokenized = [ wordlist.split(' ') for wordlist in alines ] 174 | print('\n:: Sample from segmented list of words') 175 | print('\nq : {0} ; a : {1}'.format(qtokenized[60], atokenized[60])) 176 | print('\nq : {0} ; a : {1}'.format(qtokenized[61], atokenized[61])) 177 | 178 | 179 | # indexing -> idx2w, w2idx : en/ta 180 | print('\n >> Index words') 181 | idx2w, w2idx, freq_dist = index_( qtokenized + atokenized, vocab_size=VOCAB_SIZE) 182 | 183 | print('\n >> Zero Padding') 184 | idx_q, idx_a = zero_pad(qtokenized, atokenized, w2idx) 185 | 186 | print('\n >> Save numpy arrays to disk') 187 | # save them 188 | np.save('idx_q.npy', idx_q) 189 | np.save('idx_a.npy', idx_a) 190 | 191 | # let us now save the necessary dictionaries 192 | metadata = { 193 | 'w2idx' : w2idx, 194 | 'idx2w' : idx2w, 195 | 'limit' : limit, 196 | 'freq_dist' : freq_dist 197 | } 198 | 199 | # write to disk : data control dictionaries 200 | with open('metadata.pkl', 'wb') as f: 201 | pickle.dump(metadata, f) 202 | 203 | def load_data(PATH=''): 204 | # read data control dictionaries 205 | try: 206 | with open(PATH + 'metadata.pkl', 'rb') as f: 207 | metadata = pickle.load(f) 208 | except: 209 | metadata = None 210 | # read numpy arrays 211 | idx_q = np.load(PATH + 'idx_q.npy') 212 | idx_a = np.load(PATH + 'idx_a.npy') 213 | return metadata, idx_q, idx_a 214 | 215 | import numpy as np 216 | from random import sample 217 | 218 | ''' 219 | split data into train (70%), test (15%) and valid(15%) 220 | return tuple( (trainX, trainY), (testX,testY), (validX,validY) ) 221 | 222 | ''' 223 | def split_dataset(x, y, ratio = [0.7, 0.15, 0.15] ): 224 | # number of examples 225 | data_len = len(x) 226 | lens = [ int(data_len*item) for item in ratio ] 227 | 228 | trainX, trainY = x[:lens[0]], y[:lens[0]] 229 | testX, testY = x[lens[0]:lens[0]+lens[1]], y[lens[0]:lens[0]+lens[1]] 230 | validX, validY = x[-lens[-1]:], y[-lens[-1]:] 231 | 232 | return (trainX,trainY), (testX,testY), (validX,validY) 233 | 234 | 235 | ''' 236 | generate batches from dataset 237 | yield (x_gen, y_gen) 238 | 239 | TODO : fix needed 240 | 241 | ''' 242 | def batch_gen(x, y, batch_size): 243 | # infinite while 244 | while True: 245 | for i in range(0, len(x), batch_size): 246 | if (i+1)*batch_size < len(x): 247 | yield x[i : (i+1)*batch_size ].T, y[i : (i+1)*batch_size ].T 248 | 249 | ''' 250 | generate batches, by random sampling a bunch of items 251 | yield (x_gen, y_gen) 252 | 253 | ''' 254 | def rand_batch_gen(x, y, batch_size): 255 | while True: 256 | sample_idx = sample(list(np.arange(len(x))), batch_size) 257 | yield x[sample_idx].T, y[sample_idx].T 258 | 259 | #''' 260 | # convert indices of alphabets into a string (word) 261 | # return str(word) 262 | # 263 | #''' 264 | #def 
decode_word(alpha_seq, idx2alpha): 265 | # return ''.join([ idx2alpha[alpha] for alpha in alpha_seq if alpha ]) 266 | # 267 | # 268 | #''' 269 | # convert indices of phonemes into list of phonemes (as string) 270 | # return str(phoneme_list) 271 | # 272 | #''' 273 | #def decode_phonemes(pho_seq, idx2pho): 274 | # return ' '.join( [ idx2pho[pho] for pho in pho_seq if pho ]) 275 | 276 | 277 | ''' 278 | a generic decode function 279 | inputs : sequence, lookup 280 | 281 | ''' 282 | def decode(sequence, lookup, separator=''): # 0 used for padding, is ignored 283 | return separator.join([ lookup[element] for element in sequence if element ]) 284 | 285 | 286 | 287 | if __name__ == '__main__': 288 | process_data() 289 | -------------------------------------------------------------------------------- /data/twitter/idx_a.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/seq2seq-chatbot/3757307595b15e45a8870ffbe7728d72ddca1f96/data/twitter/idx_a.npy -------------------------------------------------------------------------------- /data/twitter/idx_q.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/seq2seq-chatbot/3757307595b15e45a8870ffbe7728d72ddca1f96/data/twitter/idx_q.npy -------------------------------------------------------------------------------- /data/twitter/metadata.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/seq2seq-chatbot/3757307595b15e45a8870ffbe7728d72ddca1f96/data/twitter/metadata.pkl -------------------------------------------------------------------------------- /data/twitter/pull: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget -c 'https://www.dropbox.com/s/tmfwptbs3q180p0/seq2seq.twitter.tar.gz?dl=0' -O seq2seq.twitter.tar.gz 4 | -------------------------------------------------------------------------------- /data/twitter/pull_raw_data: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -c https://raw.githubusercontent.com/Marsan-Ma/chat_corpus/master/twitter_en.txt.gz 3 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import tensorflow as tf 5 | import tensorlayer as tl 6 | import numpy as np 7 | from tensorlayer.cost import cross_entropy_seq, cross_entropy_seq_with_mask 8 | from tqdm import tqdm 9 | from sklearn.utils import shuffle 10 | from data.twitter import data 11 | from tensorlayer.models.seq2seq import Seq2seq 12 | from tensorlayer.models.seq2seq_with_attention import Seq2seqLuongAttention 13 | import os 14 | 15 | 16 | def initial_setup(data_corpus): 17 | metadata, idx_q, idx_a = data.load_data(PATH='data/{}/'.format(data_corpus)) 18 | (trainX, trainY), (testX, testY), (validX, validY) = data.split_dataset(idx_q, idx_a) 19 | trainX = tl.prepro.remove_pad_sequences(trainX.tolist()) 20 | trainY = tl.prepro.remove_pad_sequences(trainY.tolist()) 21 | testX = tl.prepro.remove_pad_sequences(testX.tolist()) 22 | testY = tl.prepro.remove_pad_sequences(testY.tolist()) 23 | validX = tl.prepro.remove_pad_sequences(validX.tolist()) 24 | validY = tl.prepro.remove_pad_sequences(validY.tolist()) 25 | return metadata, trainX, trainY, testX, testY, validX, validY 26 | 27 | 28 | 29 | if __name__ == "__main__": 30 | data_corpus = "twitter" 31 | 32 | #data preprocessing 33 | metadata, trainX, trainY, testX, testY, validX, validY = initial_setup(data_corpus) 34 | 35 | # Parameters 36 | src_len = len(trainX) 37 | tgt_len = len(trainY) 38 | 39 | assert src_len == tgt_len 40 | 41 | batch_size = 32 42 | n_step = src_len // batch_size 43 | src_vocab_size = len(metadata['idx2w']) # 8002 (0~8001) 44 | emb_dim = 1024 45 | 46 | word2idx = metadata['w2idx'] # dict word 2 index 47 | idx2word = metadata['idx2w'] # list index 2 word 48 | 49 | unk_id = word2idx['unk'] # 1 50 | pad_id = word2idx['_'] # 0 51 | 52 | start_id = src_vocab_size # 8002 53 | end_id = src_vocab_size + 1 # 8003 54 | 55 | word2idx.update({'start_id': start_id}) 56 | word2idx.update({'end_id': end_id}) 57 | idx2word = idx2word + ['start_id', 'end_id'] 58 | 59 | src_vocab_size = tgt_vocab_size = src_vocab_size + 2 60 | 61 | num_epochs = 50 62 | vocabulary_size = src_vocab_size 63 | 64 | 65 | 66 | def inference(seed, top_n): 67 | model_.eval() 68 | seed_id = [word2idx.get(w, unk_id) for w in seed.split(" ")] 69 | sentence_id = model_(inputs=[[seed_id]], seq_length=20, start_token=start_id, top_n = top_n) 70 | sentence = [] 71 | for w_id in sentence_id[0]: 72 | w = idx2word[w_id] 73 | if w == 'end_id': 74 | break 75 | sentence = sentence + [w] 76 | return sentence 77 | 78 | decoder_seq_length = 20 79 | model_ = Seq2seq( 80 | decoder_seq_length = decoder_seq_length, 81 | cell_enc=tf.keras.layers.GRUCell, 82 | cell_dec=tf.keras.layers.GRUCell, 83 | n_layer=3, 84 | n_units=256, 85 | embedding_layer=tl.layers.Embedding(vocabulary_size=vocabulary_size, embedding_size=emb_dim), 86 | ) 87 | 88 | 89 | # Uncomment below statements if you have already saved the model 90 | 91 | # load_weights = tl.files.load_npz(name='model.npz') 92 | # tl.files.assign_weights(load_weights, model_) 93 | 94 | optimizer = tf.optimizers.Adam(learning_rate=0.001) 95 | model_.train() 96 | 97 | seeds = ["happy birthday have a nice day", 98 | "donald trump won last nights presidential debate according to snap online polls"] 99 | for epoch in range(num_epochs): 100 | model_.train() 101 | trainX, trainY = shuffle(trainX, trainY, random_state=0) 102 | total_loss, n_iter = 0, 0 103 | for X, Y in tqdm(tl.iterate.minibatches(inputs=trainX, targets=trainY, batch_size=batch_size, shuffle=False), 104 | total=n_step, desc='Epoch[{}/{}]'.format(epoch + 
1, num_epochs), leave=False): 105 | 106 | X = tl.prepro.pad_sequences(X) 107 | _target_seqs = tl.prepro.sequences_add_end_id(Y, end_id=end_id) 108 | _target_seqs = tl.prepro.pad_sequences(_target_seqs, maxlen=decoder_seq_length) 109 | _decode_seqs = tl.prepro.sequences_add_start_id(Y, start_id=start_id, remove_last=False) 110 | _decode_seqs = tl.prepro.pad_sequences(_decode_seqs, maxlen=decoder_seq_length) 111 | _target_mask = tl.prepro.sequences_get_mask(_target_seqs) 112 | 113 | with tf.GradientTape() as tape: 114 | ## compute outputs 115 | output = model_(inputs = [X, _decode_seqs]) 116 | 117 | output = tf.reshape(output, [-1, vocabulary_size]) 118 | ## compute loss and update model 119 | loss = cross_entropy_seq_with_mask(logits=output, target_seqs=_target_seqs, input_mask=_target_mask) 120 | 121 | grad = tape.gradient(loss, model_.all_weights) 122 | optimizer.apply_gradients(zip(grad, model_.all_weights)) 123 | 124 | total_loss += loss 125 | n_iter += 1 126 | 127 | # printing average loss after every epoch 128 | print('Epoch [{}/{}]: loss {:.4f}'.format(epoch + 1, num_epochs, total_loss / n_iter)) 129 | 130 | for seed in seeds: 131 | print("Query >", seed) 132 | top_n = 3 133 | for i in range(top_n): 134 | sentence = inference(seed, top_n) 135 | print(" >", ' '.join(sentence)) 136 | 137 | tl.files.save_npz(model_.all_weights, name='model.npz') 138 | 139 | 140 | 141 | 142 | 143 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn 2 | tensorflow 3 | tensorlayer 4 | numpy 5 | click 6 | tqdm 7 | nltk 8 | --------------------------------------------------------------------------------
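As a quick sanity check of the bundled Twitter data before training, the helper functions defined in `data/twitter/data.py` can be used to load and decode the preprocessed arrays; a minimal sketch, assuming it is run from the repository root:

```
from data.twitter import data

# load the preprocessed arrays and vocabulary shipped in data/twitter/
metadata, idx_q, idx_a = data.load_data(PATH='data/twitter/')
idx2w = metadata['idx2w']        # list mapping index -> word ('_' = padding, 'unk' = unknown)

print(idx_q.shape, idx_a.shape)  # question / answer index arrays, 20 time steps each

# decode the first question/answer pair back to text; padding (index 0) is skipped
print('Q:', data.decode(idx_q[0], idx2w, separator=' '))
print('A:', data.decode(idx_a[0], idx2w, separator=' '))
```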