├── requirements.txt ├── .gitignore ├── scripts ├── prepare_Twitter-Dialogs_Corpus.sh └── prepare_Cornell_Movie-Dialogs_Corpus.sh ├── data └── tiny_processed_data │ ├── test_ids.enc │ ├── test_ids.dec │ ├── train_ids.dec │ └── train_ids.enc ├── utils.py ├── config ├── all-dialogs.yml ├── check_tiny.yml ├── twitter-dialogs.yml └── cornell-movie-dialogs.yml ├── hook.py ├── main.py ├── chat.py ├── model.py ├── seq2seq_attention ├── encoder.py ├── __init__.py └── decoder.py ├── README.md └── data_loader.py /requirements.txt: -------------------------------------------------------------------------------- 1 | hb-config 2 | nltk 3 | tqdm 4 | requests -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | logs/ 3 | processed_cornell_movie_dialogs_data/ 4 | -------------------------------------------------------------------------------- /scripts/prepare_Twitter-Dialogs_Corpus.sh: -------------------------------------------------------------------------------- 1 | mkdir data 2 | cd data 3 | 4 | wget https://github.com/Marsan-Ma/chat_corpus/raw/master/twitter_en.txt.gz 5 | gunzip twitter_en.txt.gz 6 | -------------------------------------------------------------------------------- /scripts/prepare_Cornell_Movie-Dialogs_Corpus.sh: -------------------------------------------------------------------------------- 1 | 2 | mkdir data 3 | cd data 4 | 5 | wget http://www.mpi-sws.org/~cristian/data/cornell_movie_dialogs_corpus.zip 6 | unzip cornell_movie_dialogs_corpus.zip 7 | mv cornell\ movie-dialogs\ corpus cornell_movie_dialogs_corpus 8 | 9 | cd .. 
def get_rev_vocab(vocab):
    """Invert a token -> id vocabulary into an id -> token mapping.

    Args:
        vocab: dict mapping token to id, or None.

    Returns:
        dict mapping id to token, or None when `vocab` is None.
    """
    if vocab is None:
        return None
    return {idx: key for key, idx in vocab.items()}


def send_message_to_slack(config_name):
    """Announce that training has finished, via Slack or stdout.

    The message is posted to `Config.slack.webhook_url`. When no webhook is
    configured (empty or None) the message is printed instead.

    main.py registers this function with `atexit`, so it must neither block
    forever nor raise: the request is bounded by a timeout and network
    failures fall back to printing the message.

    Args:
        config_name: name of the config used for the run (shown in the text).
    """
    project_name = os.path.basename(os.path.abspath("."))

    data = {
        "text": f"The learning is finished with *{project_name}* Project using `{config_name}` config."
    }

    webhook_url = Config.slack.webhook_url
    if not webhook_url:
        print(data["text"])
        return

    try:
        # requests has no default timeout; without one a stalled connection
        # would keep the interpreter from exiting.
        requests.post(webhook_url, data=json.dumps(data), timeout=10)
    except requests.RequestException as err:
        print(f"Could not notify Slack ({err})")
        print(data["text"])
batch_size: 2 17 | num_layers: 1 18 | num_units: 16 19 | embed_dim: 16 20 | embed_share: false 21 | cell_type: lstm 22 | dropout: 0.2 23 | encoder_type: bi 24 | attention_mechanism: bahdanau 25 | 26 | train: 27 | learning_rate: 0.001 28 | sampling_probability: 0.25 29 | 30 | train_steps: 20000 31 | model_dir: 'logs/check_tiny' 32 | 33 | save_checkpoints_steps: 1000 34 | loss_hook_n_iter: 1 35 | check_hook_n_iter: 10 36 | min_eval_frequency: 10 37 | 38 | print_verbose: True 39 | debug: False 40 | 41 | predict: 42 | beam_width: 0 43 | length_penalty_weight: 1.0 44 | 45 | slack: 46 | webhook_url: "" 47 | -------------------------------------------------------------------------------- /config/twitter-dialogs.yml: -------------------------------------------------------------------------------- 1 | data: 2 | type: twitter 3 | base_path: 'data/' 4 | line_fname: 'twitter_en.txt' 5 | processed_path: 'processed_twitter_dialogs_data' 6 | 7 | word_threshold: 2 8 | max_seq_length: 200 9 | sentence_diff: 0.33 10 | testset_size: 25000 11 | 12 | PAD_ID: 0 13 | UNK_ID: 1 14 | START_ID: 2 15 | EOS_ID: 3 16 | 17 | model: 18 | batch_size: 32 19 | num_layers: 3 20 | num_units: 512 21 | embed_dim: 256 22 | embed_share: true 23 | cell_type: gru 24 | dropout: 0.2 25 | encoder_type: bi 26 | attention_mechanism: normed_bahdanau 27 | 28 | train: 29 | learning_rate: 0.001 30 | sampling_probability: 0.4 31 | 32 | train_steps: 100000 33 | model_dir: 'logs/twitter_dialogs' 34 | 35 | save_checkpoints_steps: 1000 36 | loss_hook_n_iter: 1000 37 | check_hook_n_iter: 1000 38 | min_eval_frequency: 1000 39 | 40 | print_verbose: True 41 | debug: False 42 | 43 | predict: 44 | beam_width: 0 45 | length_penalty_weight: 1.0 46 | 47 | slack: 48 | webhook_url: "" 49 | -------------------------------------------------------------------------------- /config/cornell-movie-dialogs.yml: -------------------------------------------------------------------------------- 1 | data: 2 | base_path: 
def print_variables(variables, rev_vocab=None, every_n_iter=100):
    """Build a LoggingTensorHook that prints the given tensors.

    Args:
        variables: tensor names to log.
        rev_vocab: optional id -> token dict; when given, logged ids are
            rendered as tokens.
        every_n_iter: log once every N steps.

    Returns:
        tf.train.LoggingTensorHook using `format_variable` as its formatter.
    """
    return tf.train.LoggingTensorHook(
        variables,
        every_n_iter=every_n_iter,
        formatter=format_variable(variables, rev_vocab=rev_vocab))


def format_variable(keys, rev_vocab=None):
    """Create the formatter callable used by `print_variables`.

    Args:
        keys: names to look up in the values dict, in output order.
        rev_vocab: optional id -> token dict. When None the raw values are
            printed; otherwise ids are translated to tokens.

    Returns:
        A function mapping a `{key: value}` dict to one string of
        "key = value" entries joined by "\\n - ".
    """

    def to_str(sequence):
        # Arrays are id sequences: drop padding and join the tokens.
        if isinstance(sequence, np.ndarray) and sequence.ndim > 0:
            tokens = [
                rev_vocab.get(x, '') for x in sequence if x != Config.data.PAD_ID]
            return ' '.join(tokens)
        # Scalars are single ids. Unknown ids render as '' like in the array
        # branch instead of raising KeyError inside the logging hook.
        return rev_vocab.get(int(sequence), '')

    def format_values(values):
        if rev_vocab is None:
            result = [f"{key} = {values[key]}" for key in keys]
        else:
            result = [f"{key} = {to_str(values[key])}" for key in keys]
        # No blanket try/except here: a formatter that silently returns None
        # hides real errors and logs nothing useful.
        return '\n - '.join(result)

    return format_values
def experiment_fn(run_config, params):
    """Assemble the train/eval Experiment consumed by learn_runner.

    Args:
        run_config: run configuration handed to the Estimator.
        params: hyper-parameters handed to the Estimator.

    Returns:
        tf.contrib.learn.Experiment wired with input functions and hooks.
    """
    model = Conversation()
    estimator = tf.estimator.Estimator(
        model_fn=model.model_fn,
        model_dir=Config.train.model_dir,
        params=params,
        config=run_config)

    # Publish the vocabulary size on the shared Config.
    vocab = data_loader.load_vocab("vocab")
    Config.data.vocab_size = len(vocab)

    train_X, test_X, train_y, test_y = data_loader.make_train_and_test_set()

    batch_size = Config.model.batch_size
    train_input_fn, train_input_hook = data_loader.make_batch(
        (train_X, train_y), batch_size=batch_size)
    test_input_fn, test_input_hook = data_loader.make_batch(
        (test_X, test_y), batch_size=batch_size, scope="test")

    train_monitors = [train_input_hook]
    eval_hooks = [test_input_hook]

    if Config.train.print_verbose:
        # Periodically print the first encoder/decoder/prediction sample.
        sample_printer = hook.print_variables(
            variables=['train/enc_0', 'train/dec_0', 'train/pred_0'],
            rev_vocab=utils.get_rev_vocab(vocab),
            every_n_iter=Config.train.check_hook_n_iter)
        train_monitors.append(sample_printer)

    if Config.train.debug:
        # Each phase gets its own CLI debugger hook instance.
        train_monitors.append(tf_debug.LocalCLIDebugHook())
        eval_hooks.append(tf_debug.LocalCLIDebugHook())

    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=test_input_fn,
        train_steps=Config.train.train_steps,
        min_eval_frequency=Config.train.min_eval_frequency,
        train_monitors=train_monitors,
        eval_hooks=eval_hooks,
        eval_delay_secs=0
    )


def main(mode):
    """Run `experiment_fn` under learn_runner with the given schedule.

    Args:
        mode: schedule name passed to learn_runner (the `--mode` argument).
    """
    run_config = tf.contrib.learn.RunConfig(
        model_dir=Config.train.model_dir,
        save_checkpoints_steps=Config.train.save_checkpoints_steps)
    hparams = tf.contrib.training.HParams(**Config.model.to_dict())

    tf.contrib.learn.learn_runner.run(
        experiment_fn=experiment_fn,
        run_config=run_config,
        schedule=mode,
        hparams=hparams
    )
print("Config Description") 93 | for key, value in Config.description.items(): 94 | print(f" - {key}: {value}") 95 | 96 | # After terminated Notification to Slack 97 | atexit.register(utils.send_message_to_slack, config_name=args.config) 98 | 99 | main(args.mode) 100 | -------------------------------------------------------------------------------- /chat.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | 3 | import argparse 4 | import os 5 | import sys 6 | 7 | from hbconfig import Config 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | import data_loader 12 | from model import Conversation 13 | import utils 14 | 15 | 16 | def chat(ids, vocab): 17 | 18 | X = np.array(data_loader._pad_input(ids, Config.data.max_seq_length), dtype=np.int32) 19 | X = np.reshape(X, (1, Config.data.max_seq_length)) 20 | 21 | predict_input_fn = tf.estimator.inputs.numpy_input_fn( 22 | x={"input_data": X}, 23 | num_epochs=1, 24 | shuffle=False) 25 | 26 | estimator = _make_estimator() 27 | result = estimator.predict(input_fn=predict_input_fn) 28 | 29 | prediction = next(result)["prediction"] 30 | 31 | beam_width = Config.predict.get('beam_width', 0) 32 | if beam_width > 0: 33 | 34 | def select_by_score(predictions): 35 | p_list = list(predictions) 36 | 37 | scores = [] 38 | for p in p_list: 39 | score = 0 40 | 41 | unknown_count = len(list(filter(lambda x: x == -1, p))) 42 | score -= 2 * unknown_count 43 | 44 | eos_except_last_count = len(list(filter(lambda x: x == Config.data.EOS_ID, p[:-1]))) 45 | score -= 2 * eos_except_last_count 46 | 47 | distinct_id_count = len(list(set(p))) 48 | score += 1 * distinct_id_count 49 | 50 | if eos_except_last_count == 0 and p[-1] == Config.data.EOS_ID: 51 | score += 5 52 | 53 | scores.append(score) 54 | 55 | max_score_index = scores.index(max(scores)) 56 | return predictions[max_score_index] 57 | 58 | prediction = select_by_score(prediction) 59 | 60 | rev_vocab = 
utils.get_rev_vocab(vocab) 61 | def to_str(sequence): 62 | tokens = [ 63 | rev_vocab.get(x, '') for x in sequence if x != Config.data.PAD_ID] 64 | return ' '.join(tokens) 65 | 66 | return to_str(prediction) 67 | 68 | 69 | def _make_estimator(): 70 | params = tf.contrib.training.HParams(**Config.model.to_dict()) 71 | # Using CPU 72 | run_config = tf.contrib.learn.RunConfig( 73 | model_dir=Config.train.model_dir, 74 | session_config=tf.ConfigProto( 75 | device_count={'GPU': 0} 76 | )) 77 | 78 | conversation = Conversation() 79 | return tf.estimator.Estimator( 80 | model_fn=conversation.model_fn, 81 | model_dir=Config.train.model_dir, 82 | params=params, 83 | config=run_config) 84 | 85 | 86 | def _get_user_input(): 87 | """ Get user's input, which will be transformed into encoder input later """ 88 | print("> ", end="") 89 | sys.stdout.flush() 90 | return sys.stdin.readline() 91 | 92 | 93 | def main(): 94 | vocab = data_loader.load_vocab("vocab") 95 | Config.data.vocab_size = len(vocab) 96 | 97 | while True: 98 | sentence = _get_user_input().lower() 99 | ids = data_loader.sentence2id(vocab, sentence) 100 | ids += [Config.data.START_ID] 101 | 102 | if len(ids) > Config.data.max_seq_length: 103 | print(f"Max length I can handle is: {Config.data.max_seq_length}") 104 | continue 105 | 106 | answer = chat(ids, vocab) 107 | print(answer) 108 | 109 | 110 | if __name__ == '__main__': 111 | 112 | parser = argparse.ArgumentParser( 113 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 114 | parser.add_argument('--config', type=str, default='config', 115 | help='config file name') 116 | args = parser.parse_args() 117 | 118 | Config(args.config) 119 | Config.train.batch_size = 1 120 | 121 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 122 | tf.logging.set_verbosity(tf.logging.ERROR) 123 | 124 | main() 125 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 
class Conversation:
    """Estimator `model_fn` wrapper around the seq2seq_attention graph.

    `model_fn` wires features and labels into `seq2seq_attention.Graph` and
    returns an `EstimatorSpec` carrying loss, train op, metrics and
    predictions for the requested mode.
    """

    def __init__(self):
        pass

    def model_fn(self, mode, features, labels, params):
        """Build the graph for `mode` and return its EstimatorSpec.

        Args:
            mode: a tf.estimator.ModeKeys value.
            features: encoder input ids, or a dict holding them under
                "input_data".
            labels: decoder input ids (used in TRAIN and EVAL only).
            params: hyper-parameters supplied by the Estimator (not read here).
        """
        self.dtype = tf.float32
        self.mode = mode

        self.loss, self.train_op, self.metrics, self.predictions = None, None, None, None
        self.weight_masks = None
        self._init_placeholder(features, labels)
        self.build_graph()

        # train mode: required loss and train_op
        # eval mode: required loss
        # predict mode: required predictions

        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=self.loss,
            train_op=self.train_op,
            eval_metric_ops=self.metrics,
            predictions={"prediction": self.predictions})

    def _init_placeholder(self, features, labels):
        """Store encoder/decoder inputs and derive the training targets."""
        self.encoder_inputs = features
        if isinstance(features, dict):
            self.encoder_inputs = features["input_data"]

        batch_size = tf.shape(self.encoder_inputs)[0]

        if self.mode == tf.estimator.ModeKeys.TRAIN or self.mode == tf.estimator.ModeKeys.EVAL:
            self.decoder_inputs = labels
            decoder_input_shift_1 = tf.slice(self.decoder_inputs, [0, 1],
                                             [batch_size, Config.data.max_seq_length-1])
            pad_tokens = tf.zeros([batch_size, 1], dtype=tf.int32)

            # Targets are the decoder inputs shifted LEFT by one step: the
            # first column is dropped and a zero column is appended at the end.
            self.targets = tf.concat([decoder_input_shift_1, pad_tokens], axis=1)
        else:
            self.decoder_inputs = None

    def build_graph(self):
        """Build the seq2seq graph, plus loss/optimizer/metrics outside PREDICT."""
        graph = seq2seq_attention.Graph(mode=self.mode)
        graph.build(encoder_inputs=self.encoder_inputs,
                    decoder_inputs=self.decoder_inputs)

        self.predictions = graph.predictions
        if self.mode != tf.estimator.ModeKeys.PREDICT:
            # Kept so the metrics can use the same mask as the loss.
            self.weight_masks = graph.weight_masks
            self._build_loss(graph.logits, graph.weight_masks)
            self._build_optimizer()
            self._build_metric()

    def _build_loss(self, logits, weight_masks):
        """Masked sequence cross-entropy between `logits` and the targets."""
        self.loss = tf.contrib.seq2seq.sequence_loss(
            logits=logits,
            targets=self.targets,
            weights=weight_masks,
            name="loss")

    def _build_optimizer(self):
        """Adam optimizer on the loss, with loss/learning-rate summaries."""
        self.train_op = tf.contrib.layers.optimize_loss(
            self.loss, tf.train.get_global_step(),
            optimizer='Adam',
            learning_rate=Config.train.learning_rate,
            summaries=['loss', 'learning_rate'],
            name="train_op")

    @staticmethod
    def _prepare_bleu_inputs(labels, predictions, eos_id, pad_id):
        """Turn id matrices (lists of lists) into corpus_bleu arguments.

        Each prediction is cut right after its first `eos_id`; each label has
        its `pad_id` entries removed and is wrapped as a single reference.

        Returns:
            (references, hypotheses) in the layout corpus_bleu expects.
        """
        hypotheses = []
        for prediction in predictions:
            if eos_id in prediction:
                prediction = prediction[:prediction.index(eos_id) + 1]
            hypotheses.append(list(prediction))

        references = [
            [[w_id for w_id in label if w_id != pad_id]]
            for label in labels]
        return references, hypotheses

    def _build_metric(self):
        """Register the accuracy and BLEU evaluation metrics."""

        def bleu_score(labels, predictions):

            def _nltk_bleu_score(labels, predictions):
                references, hypotheses = self._prepare_bleu_inputs(
                    labels.tolist(), predictions.tolist(),
                    Config.data.EOS_ID, Config.data.PAD_ID)
                return float(nltk.translate.bleu_score.corpus_bleu(references, hypotheses))

            score = tf.py_func(_nltk_bleu_score, (labels, predictions), tf.float64)
            return tf.metrics.mean(score * 100)

        self.metrics = {
            # Weight by the loss mask so padded positions do not count
            # towards the accuracy.
            "accuracy": tf.metrics.accuracy(
                self.targets, self.predictions,
                weights=getattr(self, "weight_masks", None)),
            "bleu": bleu_score(self.targets, self.predictions)
        }
def __init__(self, encoder_type="uni", num_layers=4,
             cell_type="GRU", num_units=512, dropout=0.8,
             dtype=tf.float32):
    """Constructs an 'Encoder' instance.
    * Args:
        encoder_type: rnn encoder_type (uni, bi)
        num_layers: number of RNN cells stacked sequentially.
        cell_type: RNN cell types (lstm, gru, layer_norm_lstm, nas),
            case-insensitive
        num_units: the number of units in cell
        dropout: dropout probability; cells are wrapped with
            input_keep_prob = 1 - dropout
        dtype: the dtype of the input
    * Returns:
        Encoder instance
    """

    self.encoder_type = encoder_type
    self.num_layers = num_layers
    # The RNN_*_CELL constants are lowercase; normalise so the default
    # "GRU" (and any other capitalisation) selects a cell instead of
    # raising "Unknown rnn cell type".
    self.cell_type = cell_type.lower() if isinstance(cell_type, str) else cell_type
    self.num_units = num_units
    self.dropout = dropout
    self.dtype = dtype

def build(self, input_vector, sequence_length, scope=None):
    """Create the RNN cells and run the encoder over `input_vector`.

    * Args:
        input_vector: RNN input vectors.
        sequence_length: batch element's sequence length
        scope: optional variable scope forwarded to the RNN builder.
    * Returns:
        whatever unidirectional_rnn / bidirectional_rnn returns for the
        configured encoder_type.
    * Raises:
        ValueError: if encoder_type is neither "uni" nor "bi".
    """
    if self.encoder_type == self.UNI_ENCODER_TYPE:
        self.cells = self._create_rnn_cells()

        return self.unidirectional_rnn(input_vector, sequence_length, scope=scope)

    if self.encoder_type == self.BI_ENCODER_TYPE:
        # Each direction gets half of the configured layers (at least one).
        # Computed locally so self.num_layers is not overwritten and a
        # second build() call does not halve it again.
        layers_per_direction = int(self.num_layers / 2) or 1

        self.cells_fw = [self._rnn_single_cell() for _ in range(layers_per_direction)]
        self.cells_bw = [self._rnn_single_cell() for _ in range(layers_per_direction)]

        return self.bidirectional_rnn(input_vector, sequence_length, scope=scope)

    raise ValueError(f"Unknown encoder_type {self.encoder_type}")
def bidirectional_rnn(self, input_vector, sequence_length, scope=None):
    """Run the stacked bidirectional RNN built from cells_fw / cells_bw.

    * Args:
        input_vector: RNN input vectors.
        sequence_length: batch element's sequence length
        scope: optional variable scope.
    * Returns:
        (outputs, encoder_final_state); the final state is the last layer's
        forward and backward states concatenated along axis 1.
    """
    outputs, output_state_fw, output_state_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
        self.cells_fw,
        self.cells_bw,
        input_vector,
        sequence_length=sequence_length,
        dtype=self.dtype,
        scope=scope)

    last_fw, last_bw = output_state_fw[-1], output_state_bw[-1]

    # Dispatch on the state structure rather than on the cell-type string:
    # layer_norm_lstm cells also carry an LSTMStateTuple, which must be
    # merged field by field instead of being concatenated as a raw tuple.
    if isinstance(last_fw, tf.contrib.rnn.LSTMStateTuple):
        encoder_final_state_c = tf.concat((last_fw.c, last_bw.c), axis=1)
        encoder_final_state_h = tf.concat((last_fw.h, last_bw.h), axis=1)
        encoder_final_state = tf.contrib.rnn.LSTMStateTuple(
            c=encoder_final_state_c, h=encoder_final_state_h)
    else:
        encoder_final_state = tf.concat((last_fw, last_bw), axis=1)

    return outputs, encoder_final_state
def _rnn_single_cell(self):
    """Constructs one RNN cell of the configured cell_type.

    * Returns:
        the cell, wrapped in a DropoutWrapper when self.dropout > 0.
    * Raises:
        ValueError: if cell_type is not one of the RNN_*_CELL constants.
    """

    if self.cell_type == self.RNN_GRU_CELL:
        single_cell = tf.contrib.rnn.GRUCell(
            self.num_units,
            reuse=tf.get_variable_scope().reuse)
    elif self.cell_type == self.RNN_LSTM_CELL:
        single_cell = tf.contrib.rnn.BasicLSTMCell(
            self.num_units,
            forget_bias=1.0,
            reuse=tf.get_variable_scope().reuse)
    elif self.cell_type == self.RNN_LAYER_NORM_LSTM_CELL:
        single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
            self.num_units,
            forget_bias=1.0,
            layer_norm=True,
            reuse=tf.get_variable_scope().reuse)
    elif self.cell_type == self.RNN_NAS_CELL:
        # "nas" previously built a LayerNormBasicLSTMCell, i.e. the same
        # cell as "layer_norm_lstm"; use the actual NAS cell.
        single_cell = tf.contrib.rnn.NASCell(
            self.num_units)
    else:
        raise ValueError(f"Unknown rnn cell type. {self.cell_type}")

    # NOTE(review): dropout is applied whenever it is configured; this class
    # does not know the train/eval/predict mode - confirm that is intended.
    if self.dropout > 0.0:
        single_cell = tf.contrib.rnn.DropoutWrapper(
            cell=single_cell, input_keep_prob=(1.0 - self.dropout))

    return single_cell
def _build_encoder(self):
    """Build the encoder and store its outputs and final state.

    Sets self.encoder_outputs and self.encoder_final_state. In PREDICT mode
    with beam search (beam_width > 0) the encoder outputs and input lengths
    are tiled beam_width times.
    """
    with tf.variable_scope('encoder'):
        encoder = Encoder(
            encoder_type=Config.model.encoder_type,
            num_layers=Config.model.num_layers,
            cell_type=Config.model.cell_type,
            num_units=Config.model.num_units,
            dropout=Config.model.dropout)

        self.encoder_outputs, self.encoder_final_state = encoder.build(
            input_vector=self.encoder_emb_inp,
            sequence_length=self.encoder_input_lengths)

        if self.mode == tf.estimator.ModeKeys.PREDICT and self.beam_width > 0:
            # NOTE(review): encoder_final_state is not tiled here; presumably
            # the Decoder handles that for beam search - confirm in decoder.py.
            self.encoder_outputs = tf.contrib.seq2seq.tile_batch(
                self.encoder_outputs, self.beam_width)
            self.encoder_input_lengths = tf.contrib.seq2seq.tile_batch(
                self.encoder_input_lengths, self.beam_width)

def _build_decoder(self):
    """Build the attentional decoder and derive logits, predictions and masks.

    PREDICT mode only sets self.decoder_logits / self.predictions and
    returns early. TRAIN and EVAL additionally set self.logits (padded or
    cut to max_seq_length), self.predictions (argmax of the logits) and
    self.weight_masks.
    """

    batch_size = tf.shape(self.encoder_inputs)[0]

    with tf.variable_scope('decoder'):

        decoder = Decoder(
            cell_type=Config.model.cell_type,
            dropout=Config.model.dropout,
            encoder_type=Config.model.encoder_type,
            num_layers=Config.model.num_layers,
            num_units=Config.model.num_units,
            sampling_probability=Config.train.sampling_probability,
            mode=self.mode,
            dtype=self.dtype)

        decoder.set_attention_then_project(
            attention_mechanism=Config.model.attention_mechanism,
            beam_width=self.beam_width,
            memory=self.encoder_outputs,
            memory_sequence_length=self.encoder_input_lengths,
            vocab_size=Config.data.vocab_size)
        decoder.set_initial_state(batch_size, self.encoder_final_state)

        decoder_outputs = decoder.build(
            inputs=self.decoder_emb_inp,
            sequence_length=self.decoder_input_lengths,
            embedding=self.embedding_decoder,
            start_tokens=tf.fill([batch_size], Config.data.START_ID),
            end_token=Config.data.EOS_ID,
            length_penalty_weight=Config.predict.length_penalty_weight)

    if self.mode == tf.estimator.ModeKeys.TRAIN:
        self.decoder_logits = decoder_outputs.rnn_output
    else:
        if self.mode == tf.estimator.ModeKeys.PREDICT and self.beam_width > 0:
            # Beam search path: no logits are read from the decoder output,
            # so a no_op stands in for them.
            self.decoder_logits = tf.no_op()
            self.predictions = decoder_outputs.predicted_ids
        else:
            self.decoder_logits = decoder_outputs.rnn_output
            self.predictions = decoder_outputs.sample_id

    if self.mode == tf.estimator.ModeKeys.PREDICT:
        # PREDICT mode does not need a loss
        return

    decoder_output_length = tf.shape(self.decoder_logits)[1]

    def concat_zero_padding():
        # Pad the time axis with all-zero logits up to max_seq_length.
        pad_num = Config.data.max_seq_length - decoder_output_length
        zero_padding = tf.zeros(
            [batch_size, pad_num, Config.data.vocab_size],
            dtype=self.dtype)

        return tf.concat([self.decoder_logits, zero_padding], axis=1)

    def slice_to_max_len():
        # Cut the time axis down to max_seq_length.
        return tf.slice(self.decoder_logits,
                        [0, 0, 0],
                        [batch_size, Config.data.max_seq_length, Config.data.vocab_size])

    # decoder output sometimes exceeds max_seq_length
    self.logits = tf.cond(decoder_output_length < Config.data.max_seq_length,
                          concat_zero_padding,
                          slice_to_max_len)
    self.predictions = tf.argmax(self.logits, axis=2)

    # 1.0 for the first decoder_input_lengths steps of each row, 0.0 after.
    self.weight_masks = tf.sequence_mask(
        lengths=self.decoder_input_lengths,
        maxlen=Config.data.max_seq_length,
        dtype=self.dtype, name='masks')

    if self.mode == tf.estimator.ModeKeys.TRAIN:
        self.train_predictions = tf.argmax(self.logits, axis=2)
        # named tensor with the first sample's predicted ids, for printing
        # training data
        tf.identity(tf.argmax(self.decoder_logits[0], axis=1), name='train/pred_0')
38 | ├── config # Config files (.yml, .json) used with hb-config 39 | ├── data/ # dataset path 40 | ├── scripts # download dataset using shell scripts 41 | ├── seq2seq_attention # seq2seq_attention architecture graphs (from input to logits) 42 | ├── __init__.py # Graph 43 | ├── encoder.py # Encoder 44 | ├── decoder.py # Decoder 45 | ├── data_loader.py # raw_data -> processed_data -> generate_batch (using Dataset) 46 | ├── hook.py # training or test hook feature (e.g. print_variables) 47 | ├── main.py # define experiment_fn 48 | └── model.py # define EstimatorSpec 49 | 50 | Reference : [hb-config](https://github.com/hb-research/hb-config), [Dataset](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator), [experiments_fn](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/Experiment), [EstimatorSpec](https://www.tensorflow.org/api_docs/python/tf/estimator/EstimatorSpec) 51 | 52 | 53 | ## Todo 54 | 55 | - Build a Korean dialog corpus dataset like [Cornell_Movie-Dialogs_Corpus](https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html) 56 | - Implement CopyNet 57 | - [Incorporating Copying Mechanism in Sequence-to-Sequence Learning](https://arxiv.org/abs/1603.06393) by J. Gu, 2016. 58 | - Apply [hb-research/notes - Neural Text Generation: A Practical Guide](https://github.com/hb-research/notes/blob/master/notes/neural_text_generation.md) 59 | 60 | ## Config 61 | 62 | Controls the whole **experimental environment**.
63 | 64 | example: cornell-movie-dialogs.yml 65 | 66 | ```yml 67 | data: 68 | base_path: 'data/cornell_movie_dialogs_corpus/' 69 | conversation_fname: 'movie_conversations.txt' 70 | line_fname: 'movie_lines.txt' 71 | processed_path: 'processed_cornell_movie_dialogs_data' 72 | word_threshold: 2 73 | max_seq_length: 200 74 | sentence_diff: 0.33 # (Filtering with input and output sentence diff) 75 | testset_size: 25000 76 | 77 | PAD_ID: 0 78 | UNK_ID: 1 79 | START_ID: 2 80 | EOS_ID: 3 81 | 82 | model: 83 | batch_size: 32 84 | num_layers: 4 85 | num_units: 512 86 | embed_dim: 256 87 | embed_share: true # (true or false) 88 | cell_type: gru # (lstm, gru, layer_norm_lstm, nas) 89 | dropout: 0.2 90 | encoder_type: bi # (uni / bi) 91 | attention_mechanism: normed_bahdanau # (bahdanau, normed_bahdanau, luong, scaled_luong) 92 | 93 | train: 94 | learning_rate: 0.001 95 | sampling_probability: 0.25 # (Scheduled Sampling) 96 | 97 | train_steps: 100000 98 | model_dir: 'logs/cornell_movie_dialogs' 99 | 100 | save_checkpoints_steps: 1000 101 | loss_hook_n_iter: 1000 102 | check_hook_n_iter: 1000 103 | min_eval_frequency: 1000 104 | 105 | print_verbose: True 106 | debug: False 107 | 108 | predict: 109 | beam_width: 5 # (0: GreedyEmbeddingHelper, 1>=: BeamSearchDecoder) 110 | length_penalty_weight: 1.0 111 | 112 | slack: 113 | webhook_url: "" # after training notify you using slack-webhook 114 | ``` 115 | 116 | 117 | ## Usage 118 | 119 | Install requirements. 120 | 121 | ```pip install -r requirements.txt``` 122 | 123 | First, check if the model is valid. 124 | 125 | ```python main.py --config check_tiny --mode train``` 126 | 127 | Then, download [Cornell_Movie-Dialogs_Corpus](https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html) and train it. 
128 | 129 | ``` 130 | sh scripts/prepare_Cornell_Movie-Dialogs_Corpus.sh 131 | python data_loader.py --config cornell-movie-dialogs # pre-processing data 132 | python main.py --config cornell-movie-dialogs --mode train_and_evaluate 133 | ``` 134 | 135 | ### Experiment modes 136 | 137 | :white_check_mark: : Working 138 | :white_medium_small_square: : Not tested yet. 139 | 140 | - :white_check_mark: `evaluate` : Evaluate on the evaluation data. 141 | - :white_medium_small_square: `extend_train_hooks` : Extends the hooks for training. 142 | - :white_medium_small_square: `reset_export_strategies` : Resets the export strategies with the new_export_strategies. 143 | - :white_medium_small_square: `run_std_server` : Starts a TensorFlow server and joins the serving thread. 144 | - :white_medium_small_square: `test` : Tests training, evaluating and exporting the estimator for a single step. 145 | - :white_check_mark: `train` : Fit the estimator using the training data. 146 | - :white_check_mark: `train_and_evaluate` : Interleaves training and evaluation. 147 | 148 | --- 149 | 150 | After training, start chatting. 151 | 152 | ```python chat.py --config cornell-movie-dialogs``` 153 | 154 | 155 | ### Tensorboard 156 | 157 | ```tensorboard --logdir logs``` 158 | 159 | 160 | ## Conversation Example 161 | 162 | \> (Human) 163 | 164 | ``` 165 | > hi. 166 | i ' m sorry . <\s> 167 | > why? 168 | i ' m not interested . 169 | > hmm.. are you human? 170 | nnnnn . <\s> 171 | > you are you? 172 | i ' m not going to marry you . <\s> 173 | > what? me also! 174 | i ' m sorry . <\s> 175 | > anyway, what is your name? 176 | i ' m not going to kill you . <\s> 177 | > i'm asking your name. 178 | i ' m sorry . <\s> 179 | > yeah, so.. where are you from? 180 | i ' m not sure . <\s> 181 | > Have you lost your memory? 182 | i ' m not sure i ' m not going to you . <\s> 183 | > Have you ever been to Florida? 184 | i ' m not sure i ' m not a little . <\s> 185 | > there there.
186 | no , no , no , no , 187 | > Are you alright? 188 | i ' m not going to die . <\s> 189 | > alsoulutely, you are very healthy. 190 | i don ' t know . <\s> 191 | > you know that? you are robot. 192 | i ' m not sure . <\s> 193 | > i don ' t know what to say. 194 | i ' m not sure i ' m not a little stupid . <\s> 195 | > what do you like to do in your free time? 196 | i ' m not sure i ' m not gonna take you a little more . <\s> 197 | > are you in love? 198 | i ' m not sure i ' m not stupid . <\s> 199 | > i know. 200 | i ' m not going to let you 201 | > what's your name? 202 | i ' m not sure i ' m not a good guy . <\s> 203 | > oh god.. 204 | i ' m not going to do that . <\s> 205 | > let's talk about something else. 206 | i ' m not sure . <\s> 207 | > i have to figure out what is the problem. 208 | i ' m not sure . <\s> 209 | > because of you are stupid. 210 | i ' m not going to get you . <\s> 211 | > florida? 212 | he is a homosexual - - 213 | > really? 214 | i ' m not sure i 215 | > can you tell me about more detail about him? 216 | i ' m not a cop . <\s> 217 | > okay. 218 | i ' m not going to 219 | > you always say 'i....' 220 | i ' m not sure i ' m not a murderer . <\s> 221 | > hmm..... 222 | i ' m not sure i ' m not a little popular . <\s> 223 | > you are not popular. 224 | i ' m not sure i ' m not a girl . 225 | > are you boy? 226 | i ' m not sure . <\s> 227 | > are you girl? 228 | i ' m not sure . <\s> 229 | > who are you? 230 | i ' m sorry . <\s> 231 | > i'm sorry. 232 | i ' m not crowding . <\s> 233 | > just kidding. 234 | i ' m going to be a writer 235 | > great! 236 | i don ' t know . 237 | > what do you mean? 238 | i ' m sorry , i ' m sorry . <\s> 239 | > omg 240 | no . <\s> 241 | > florida. 242 | he hasn ' t got any 243 | > music is my life. 244 | you ' re a fool , you ' re a man ! 245 | > wow! 246 | i ' m not kidding . 247 | > sorry. 248 | i ' m not kidding . 249 | > yeah, i know it. 250 | i ' m not sure . <\s> 251 | > california? 252 | what ? 
<\s> 253 | > haha. 254 | it ' s a real competition 255 | ``` 256 | 257 | 258 | ## Reference 259 | 260 | - [stanford-tensorflow-tutorials](https://github.com/chiphuyen/stanford-tensorflow-tutorials/tree/master/assignments/chatbot) by Chip Huyen 261 | - [TensorFlow Neural Machine Translation Tutorial](https://github.com/tensorflow/nmt) - Tensorflow Official 262 | - [Deep Learning for Chatbots, Part 1 – Introduction](http://www.wildml.com/2016/04/deep-learning-for-chatbots-part-1-introduction/) - WildML 263 | - [hb-research/notes - Neural Text Generation: A Practical Guide](https://github.com/hb-research/notes/blob/master/notes/neural_text_generation.md) 264 | 265 | ## Author 266 | 267 | [Dongjun Lee](https://github.com/DongjunLee) (humanbrain.djlee@gmail.com) 268 | 269 | ### Contributors 270 | 271 | - [junbeomlee](https://github.com/junbeomlee) 272 | -------------------------------------------------------------------------------- /seq2seq_attention/decoder.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | 4 | 5 | __all__ = [ 6 | "Attention", "Decoder" 7 | ] 8 | 9 | 10 | class Attention: 11 | """Attention class""" 12 | 13 | BAHDANAU_MECHANISM = "bahdanau" 14 | NORMED_BAHDANAU_MECHANISM = "normed_bahdanau" 15 | 16 | LUONG_MECHANISM = "luong" 17 | SCALED_LUONG_MECHANISM = "scaled_luong" 18 | 19 | def __init__(self, 20 | attention_mechanism="bahdanau", 21 | encoder_type="bi", 22 | num_units=512, 23 | memory=None, 24 | memory_sequence_length=None): 25 | 26 | assert memory is not None 27 | assert memory_sequence_length is not None 28 | 29 | self.attention_mechanism = attention_mechanism 30 | self.encoder_type = encoder_type 31 | self.num_units = num_units 32 | self.memory = memory 33 | self.memory_sequence_length = memory_sequence_length 34 | 35 | def wrap(self, decoder_cell, alignment_history=True): 36 | with tf.variable_scope("attention") as scope: 37 | attention_layer_size = self.num_units 38 | 
class Attention:
    """Builds an attention mechanism and wraps a decoder cell with it."""

    BAHDANAU_MECHANISM = "bahdanau"
    NORMED_BAHDANAU_MECHANISM = "normed_bahdanau"

    LUONG_MECHANISM = "luong"
    SCALED_LUONG_MECHANISM = "scaled_luong"

    # (mechanism name, tf.contrib.seq2seq class name, extra keyword arguments)
    _MECHANISM_SPECS = (
        (BAHDANAU_MECHANISM, "BahdanauAttention", {}),
        (NORMED_BAHDANAU_MECHANISM, "BahdanauAttention", {"normalize": True}),
        (LUONG_MECHANISM, "LuongAttention", {}),
        (SCALED_LUONG_MECHANISM, "LuongAttention", {"scale": True}),
    )

    def __init__(self,
                 attention_mechanism="bahdanau",
                 encoder_type="bi",
                 num_units=512,
                 memory=None,
                 memory_sequence_length=None):
        """Store the attention settings.

        `memory` and `memory_sequence_length` are mandatory even though they
        default to None; both are asserted to be set.
        """
        assert memory is not None
        assert memory_sequence_length is not None

        self.attention_mechanism = attention_mechanism
        self.encoder_type = encoder_type
        self.num_units = num_units
        self.memory = memory
        self.memory_sequence_length = memory_sequence_length

    def wrap(self, decoder_cell, alignment_history=True):
        """Return `decoder_cell` wrapped in an AttentionWrapper.

        The mechanism is created inside the "attention" variable scope and
        the attention layer size equals `num_units`.
        """
        with tf.variable_scope("attention"):
            mechanism = self._create_mechanism()
            wrapper_name = f"{self.attention_mechanism}-mechanism"
            return tf.contrib.seq2seq.AttentionWrapper(
                decoder_cell,
                mechanism,
                attention_layer_size=self.num_units,
                alignment_history=alignment_history,
                name=wrapper_name)

    def _create_mechanism(self):
        """Instantiate the configured mechanism; raise ValueError if unknown."""
        for name, class_name, extra_kwargs in self._MECHANISM_SPECS:
            if self.attention_mechanism == name:
                mechanism_cls = getattr(tf.contrib.seq2seq, class_name)
                return mechanism_cls(
                    self.num_units,
                    self.memory,
                    memory_sequence_length=self.memory_sequence_length,
                    **extra_kwargs)

        raise ValueError(f"Unknown attention mechanism {self.attention_mechanism}")
class Decoder:
    """Attentional RNN decoder for training, greedy and beam-search decoding.

    Call order expected by the methods below:
    `set_attention_then_project` -> `set_initial_state` -> `build`.
    """

    UNI_ENCODER_TYPE = "uni"
    BI_ENCODER_TYPE = "bi"

    RNN_GRU_CELL = "gru"
    RNN_LSTM_CELL = "lstm"
    RNN_LAYER_NORM_LSTM_CELL = "layer_norm_lstm"
    RNN_NAS_CELL = "nas"

    def __init__(self,
                 cell_type="lstm",
                 dropout=0.8,
                 encoder_type="uni",
                 num_layers=None,
                 num_units=None,
                 sampling_probability=0.4,
                 mode=tf.estimator.ModeKeys.TRAIN,
                 dtype=tf.float32):
        """Store the hyper-parameters.

        With a bidirectional encoder (`encoder_type == "bi"`) the unit count
        is doubled and the layer count halved (never below one).
        `dropout` is a drop probability (keep probability is `1 - dropout`).
        """
        self.cell_type = cell_type
        self.dropout = dropout
        self.encoder_type = encoder_type
        self.num_layers = num_layers
        self.num_units = num_units
        self.sampling_probability = sampling_probability

        if encoder_type == self.BI_ENCODER_TYPE:
            self.num_units *= 2
            self.num_layers = max(1, self.num_layers // 2)
        self.mode = mode
        self.dtype = dtype

    def set_attention_then_project(self,
                                   attention_mechanism="bahdanau",
                                   beam_width=0,
                                   memory=None,
                                   memory_sequence_length=None,
                                   vocab_size=None):
        """Create the RNN cells, wrap them with attention, project to vocab.

        Sets `self.beam_width`, `self.out_cell` and
        `self.maximum_iterations` (twice the longest memory sequence).
        """
        self.beam_width = beam_width

        cells = self._create_rnn_cells()

        attention = Attention(
            attention_mechanism=attention_mechanism,
            encoder_type=self.encoder_type,
            num_units=self.num_units,
            memory=memory,
            memory_sequence_length=memory_sequence_length)
        # Alignment history is only requested for greedy prediction.
        alignment_history = (self.mode == tf.estimator.ModeKeys.PREDICT
                             and self.beam_width == 0)

        attn_cell = attention.wrap(cells, alignment_history=alignment_history)
        self.out_cell = tf.contrib.rnn.OutputProjectionWrapper(
            attn_cell, vocab_size)

        self.maximum_iterations = tf.round(tf.reduce_max(memory_sequence_length) * 2)

    def set_initial_state(self, batch_size, encoder_final_state):
        """Initialise the decoder state from the encoder's final state.

        For beam-search prediction the encoder state is tiled `beam_width`
        times and the zero state is built for `batch_size * beam_width`.
        """
        if self.mode == tf.estimator.ModeKeys.PREDICT and self.beam_width > 0:
            cell_state = tf.contrib.seq2seq.tile_batch(encoder_final_state, self.beam_width)
            state_batch_size = batch_size * self.beam_width
        else:
            cell_state = encoder_final_state
            state_batch_size = batch_size

        zero_state = self.out_cell.zero_state(state_batch_size, self.dtype)
        self.decoder_initial_state = zero_state.clone(cell_state=cell_state)

    def build(self,
              inputs=None,
              sequence_length=None,
              embedding=None,
              start_tokens=None,
              end_token=None,
              length_penalty_weight=1.0):
        """Run the decoder and return the `dynamic_decode` outputs.

        TRAIN uses scheduled sampling over `inputs`; PREDICT with
        `beam_width > 0` uses beam search; every other case decodes greedily
        from `start_tokens`.
        """
        if self.mode == tf.estimator.ModeKeys.TRAIN:
            assert inputs is not None
            assert sequence_length is not None

            helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
                inputs=inputs,
                sequence_length=sequence_length,
                embedding=embedding,
                sampling_probability=self.sampling_probability)

            return self._basic_decoder(helper)

        assert embedding is not None
        assert start_tokens is not None
        assert end_token is not None

        if self.mode == tf.estimator.ModeKeys.PREDICT and self.beam_width > 0:
            return self._beam_search_decoder(
                embedding, start_tokens, end_token, length_penalty_weight)

        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embedding=embedding,
            start_tokens=start_tokens,
            end_token=end_token)
        return self._basic_decoder(helper)

    def _basic_decoder(self, helper):
        """Decode with a BasicDecoder driven by `helper`."""
        decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=self.out_cell,
            helper=helper,
            initial_state=self.decoder_initial_state)

        if self.mode == tf.estimator.ModeKeys.TRAIN:
            # Training length is bounded by the helper's sequence_length.
            outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=decoder,
                output_time_major=False,
                impute_finished=True,
                swap_memory=True)
        else:
            outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=decoder,
                output_time_major=False,
                impute_finished=True,
                maximum_iterations=self.maximum_iterations)

        return outputs

    def _beam_search_decoder(self, embedding, start_tokens, end_token, length_penalty_weight):
        """Decode with a BeamSearchDecoder of width `self.beam_width`."""
        decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell=self.out_cell,
            embedding=embedding,
            start_tokens=start_tokens,
            end_token=end_token,
            initial_state=self.decoder_initial_state,
            beam_width=self.beam_width,
            length_penalty_weight=length_penalty_weight)

        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder=decoder,
            output_time_major=False,
            impute_finished=False,
            maximum_iterations=self.maximum_iterations)
        return outputs

    def _create_rnn_cells(self):
        """Construct the decoder RNN with `self.num_layers` layers.

        * Returns:
            the single cell when there is one layer, otherwise a
            MultiRNNCell stacking all layers
        """
        stacked_rnn = [self._rnn_single_cell() for _ in range(self.num_layers)]

        if self.num_layers == 1:
            return stacked_rnn[0]
        return tf.nn.rnn_cell.MultiRNNCell(
            cells=stacked_rnn,
            state_is_tuple=True)

    def _rnn_single_cell(self):
        """Construct one RNN cell of `self.cell_type`.

        Raises ValueError for an unknown cell type. Input dropout is added
        only in TRAIN mode.
        """
        if self.cell_type == self.RNN_GRU_CELL:
            single_cell = tf.contrib.rnn.GRUCell(
                self.num_units,
                reuse=tf.get_variable_scope().reuse)
        elif self.cell_type == self.RNN_LSTM_CELL:
            single_cell = tf.contrib.rnn.BasicLSTMCell(
                self.num_units,
                forget_bias=1.0,
                reuse=tf.get_variable_scope().reuse)
        elif self.cell_type == self.RNN_LAYER_NORM_LSTM_CELL:
            single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
                self.num_units,
                forget_bias=1.0,
                layer_norm=True,
                reuse=tf.get_variable_scope().reuse)
        elif self.cell_type == self.RNN_NAS_CELL:
            # Previously this branch built a LayerNormBasicLSTMCell, so "nas"
            # silently behaved like "layer_norm_lstm".
            single_cell = tf.contrib.rnn.NASCell(
                self.num_units)
        else:
            raise ValueError(f"Unknown rnn cell type. {self.cell_type}")

        # Dropout is a training-time regulariser; applying it while
        # evaluating or predicting makes the outputs noisy.
        if self.dropout > 0.0 and self.mode == tf.estimator.ModeKeys.TRAIN:
            single_cell = tf.contrib.rnn.DropoutWrapper(
                cell=single_cell, input_keep_prob=(1.0 - self.dropout))

        return single_cell
{self.cell_type}") 264 | 265 | if self.dropout > 0.0: 266 | single_cell = tf.contrib.rnn.DropoutWrapper( 267 | cell=single_cell, input_keep_prob=(1.0 - self.dropout)) 268 | 269 | return single_cell 270 | 271 | 272 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | 4 | import argparse 5 | import os 6 | import random 7 | import re 8 | 9 | from nltk.tokenize import TweetTokenizer 10 | from hbconfig import Config 11 | import numpy as np 12 | import tensorflow as tf 13 | from tqdm import tqdm 14 | 15 | 16 | 17 | tokenizer = TweetTokenizer() 18 | 19 | def get_lines(): 20 | id2line = {} 21 | file_path = os.path.join(Config.data.base_path, Config.data.line_fname) 22 | with open(file_path, 'rb') as f: 23 | lines = f.readlines() 24 | for line in lines: 25 | parts = line.decode('iso-8859-1').split(' +++$+++ ') 26 | if len(parts) == 5: 27 | if parts[4][-1] == '\n': 28 | parts[4] = parts[4][:-1] 29 | id2line[parts[0]] = parts[4] 30 | return id2line 31 | 32 | 33 | def get_convos(): 34 | """ Get conversations from the raw data """ 35 | file_path = os.path.join(Config.data.base_path, Config.data.conversation_fname) 36 | convos = [] 37 | with open(file_path, 'rb') as f: 38 | for line in f.readlines(): 39 | parts = line.decode('iso-8859-1').split(' +++$+++ ') 40 | if len(parts) == 4: 41 | convo = [] 42 | for line in parts[3][1:-2].split(', '): 43 | convo.append(line[1:-1]) 44 | convos.append(convo) 45 | 46 | return convos 47 | 48 | 49 | def cornell_question_answers(id2line, convos): 50 | """ Divide the dataset into two sets: questions and answers. 
""" 51 | questions, answers = [], [] 52 | for convo in convos: 53 | for index, line in enumerate(convo[:-1]): 54 | questions.append(id2line[convo[index]]) 55 | answers.append(id2line[convo[index + 1]]) 56 | assert len(questions) == len(answers) 57 | return questions, answers 58 | 59 | 60 | def twitter_question_answers(): 61 | """ Divide the dataset into two sets: questions and answers. """ 62 | file_path = os.path.join(Config.data.base_path, Config.data.line_fname) 63 | 64 | twitter_corpus = [] 65 | with open(file_path, 'rb') as f: 66 | for line in f.readlines(): 67 | line = line.decode('utf-8') 68 | 69 | if line[-1] == '\n': 70 | twitter_corpus.append(line[:-1].lower()) 71 | else: 72 | twitter_corpus.append(line.lower()) 73 | 74 | questions = twitter_corpus[::2] # even is question 75 | answers = twitter_corpus[1::2] # odd is answer 76 | 77 | assert len(questions) == len(answers) 78 | return questions, answers 79 | 80 | 81 | def prepare_dataset(questions, answers): 82 | # create path to store all the train & test encoder & decoder 83 | make_dir(Config.data.base_path + Config.data.processed_path) 84 | 85 | # random convos to create the test set 86 | test_ids = random.sample([i for i in range(len(questions))], Config.data.testset_size) 87 | 88 | filenames = ['train.enc', 'train.dec', 'test.enc', 'test.dec'] 89 | files = [] 90 | for filename in filenames: 91 | files.append(open(os.path.join(Config.data.base_path, Config.data.processed_path, filename), 'wb')) 92 | 93 | for i in tqdm(range(len(questions))): 94 | 95 | question = questions[i] 96 | answer = answers[i] 97 | 98 | if i in test_ids: 99 | files[2].write((question + "\n").encode('utf-8').lower()) 100 | files[3].write((answer + '\n').encode('utf-8').lower()) 101 | else: 102 | files[0].write((question + '\n').encode('utf-8').lower()) 103 | files[1].write((answer + '\n').encode('utf-8').lower()) 104 | 105 | for file in files: 106 | file.close() 107 | 108 | 109 | def make_dir(path): 110 | """ Create a directory if 
# Punctuation that becomes a token of its own. The hyphen is escaped: written
# bare between ' and < it formed a range that also matched digits, '/', '*'
# and '+', so numbers were split into single digits.
_WORD_SPLIT = re.compile(r"""([.,!?"'\-<>:;)(])""")
_DIGIT_RE = re.compile(r"\d")


def basic_tokenizer(line, normalize_digits=True):
    """ A basic tokenizer to tokenize text into tokens.
    Feel free to change this to suit your need.

    Square brackets are removed, the text is lower-cased and split on
    whitespace and on the punctuation in `_WORD_SPLIT`. With
    `normalize_digits` every digit is replaced by '#'.
    """
    # NOTE(review): the flattened source also had two `re.sub('', '', line)`
    # calls here. They are no-ops as written and look like tag patterns that
    # were lost in extraction -- confirm against the original file.
    line = line.replace('[', '').replace(']', '')

    words = []
    for fragment in line.strip().lower().split():
        for token in _WORD_SPLIT.split(fragment):
            if not token:
                continue
            if normalize_digits:
                token = _DIGIT_RE.sub('#', token)
            words.append(token)
    return words


def build_vocab(in_fname, out_fname, normalize_digits=True):
    """Count token frequencies in two processed files and write `vocab`.

    Four special entries are written first, then every token whose count is
    at least `Config.data.word_threshold`, most frequent first.
    `normalize_digits` is accepted but not used.
    """
    print("Count each vocab frequency ...")

    vocab = {}

    def count_vocab(fname):
        with open(fname, 'rb') as f:
            for line in tqdm(f.readlines()):
                for token in tokenizer.tokenize(line.decode('utf-8')):
                    vocab[token] = vocab.get(token, 0) + 1

    data_dir = os.path.join(Config.data.base_path, Config.data.processed_path)

    count_vocab(os.path.join(data_dir, in_fname))
    count_vocab(os.path.join(data_dir, out_fname))

    print("total vocab size:", len(vocab))
    sorted_vocab = sorted(vocab, key=vocab.get, reverse=True)

    dest_path = os.path.join(data_dir, 'vocab')
    with open(dest_path, 'wb') as f:
        # NOTE(review): the first three special entries are empty strings in
        # this copy of the source; the example config names ids 0-3 PAD_ID,
        # UNK_ID, START_ID and EOS_ID -- confirm the literal token strings.
        f.write(('' + '\n').encode('utf-8'))
        f.write(('' + '\n').encode('utf-8'))
        f.write(('' + '\n').encode('utf-8'))
        f.write(('<\\s>' + '\n').encode('utf-8'))
        for word in tqdm(sorted_vocab):
            # sorted_vocab is in descending count order, so stop at the
            # first word below the threshold.
            if vocab[word] < Config.data.word_threshold:
                break

            f.write((word + '\n').encode('utf-8'))
def load_vocab(vocab_fname):
    """Read the vocab file and return a dict mapping each word to its line index."""
    print("load vocab ...")
    vocab_path = os.path.join(Config.data.base_path, Config.data.processed_path, vocab_fname)
    with open(vocab_path, 'rb') as f:
        words = f.read().decode('utf-8').splitlines()
    print("vocab size:", len(words))
    return {word: i for i, word in enumerate(words)}


def sentence2id(vocab, line):
    """Tokenize `line` and map every token to its id.

    Tokens missing from `vocab` fall back to the id stored under the
    fallback key.
    NOTE(review): the fallback key is an empty string in this copy of the
    source -- confirm the literal against the original file.
    """
    return [vocab.get(token, vocab['']) for token in tokenizer.tokenize(line)]


def token2id(data, mode):
    """ Convert all the tokens in the data into their corresponding
    index in the vocabulary.

    Reads `<data>.<mode>` and writes `<data>_ids.<mode>`, one space-separated
    id sequence per line. In 'dec' mode every sequence is wrapped in a start
    id and the id of '<\\s>'.
    """
    vocab_path = 'vocab'
    in_path = data + '.' + mode
    out_path = data + '_ids.' + mode

    vocab = load_vocab(vocab_path)
    data_dir = os.path.join(Config.data.base_path, Config.data.processed_path)

    # Both files are opened with `with` so they are closed (and the output
    # flushed) even when a lookup fails half-way through.
    with open(os.path.join(data_dir, in_path), 'rb') as in_file:
        lines = in_file.read().decode('utf-8').splitlines()

    with open(os.path.join(data_dir, out_path), 'wb') as out_file:
        for line in tqdm(lines):
            # only the decoder side gets the start / end markers
            ids = [vocab['']] if mode == 'dec' else []

            ids.extend(sentence2id(vocab, line))
            if mode == 'dec':
                ids.append(vocab['<\\s>'])

            out_file.write(b' '.join(str(id_).encode('cp1252') for id_ in ids) + b'\n')


def prepare_raw_data():
    """Build question/answer pairs for the configured corpus and split them.

    `Config.data.type` selects "cornell-movie" (default), "twitter" or
    "all". For "all" the corpus paths in Config are overwritten with fixed
    values; `base_path` is left at "data/" when `prepare_dataset` runs.
    Raises ValueError for any other type.
    """
    print('Preparing raw data into train set and test set ...')

    data_type = Config.data.get('type', 'cornell-movie')
    if data_type == "cornell-movie":
        id2line = get_lines()
        convos = get_convos()
        questions, answers = cornell_question_answers(id2line, convos)
    elif data_type == "twitter":
        questions, answers = twitter_question_answers()
    elif data_type == "all":
        # cornell-movie
        Config.data.base_path = "data/cornell_movie_dialogs_corpus/"
        Config.data.line_fname = "movie_lines.txt"
        Config.data.conversation_fname = "movie_conversations.txt"

        id2line = get_lines()
        convos = get_convos()
        co_questions, co_answers = cornell_question_answers(id2line, convos)

        # twitter
        Config.data.base_path = "data/"
        Config.data.line_fname = "twitter_en.txt"

        tw_questions, tw_answers = twitter_question_answers()

        questions = co_questions + tw_questions
        answers = co_answers + tw_answers
    else:
        raise ValueError(f"Unknown data_type, {data_type}")

    prepare_dataset(questions, answers)
def process_data():
    """Build the vocabulary, then convert all four text files to id files."""
    print('Preparing data to be model-ready ...')

    build_vocab('train.enc', 'train.dec')

    token2id('train', 'enc')
    token2id('train', 'dec')
    token2id('test', 'enc')
    token2id('test', 'dec')


def _sort_by_total_length(X, y):
    """Return (X, y) as tuples of rows ordered by unpadded length.

    The key is the number of non-PAD ids in the input plus the number in
    the output. The sort is stable, so equally long pairs keep their order.
    Empty inputs are returned unchanged.
    """
    if len(X) == 0:
        return X, y

    pad_id = Config.data.PAD_ID
    X, y = np.asarray(X), np.asarray(y)
    lengths = (np.count_nonzero(X != pad_id, axis=1)
               + np.count_nonzero(y != pad_id, axis=1))
    order = np.argsort(lengths, kind='mergesort')
    return tuple(X[order]), tuple(y[order])


def make_train_and_test_set(shuffle=True, bucket=True):
    """Load the id files and return `(train_X, test_X, train_y, test_y)`.

    With `shuffle` each set is permuted randomly. With `bucket` each set is
    then sorted by the combined unpadded length of input and output and
    returned as tuples of rows.
    """
    print("make Training data and Test data Start....")

    train_X, train_y = load_data('train_ids.enc', 'train_ids.dec')
    test_X, test_y = load_data('test_ids.enc', 'test_ids.dec')

    assert len(train_X) == len(train_y)
    assert len(test_X) == len(test_y)

    print(f"train data count : {len(train_X)}")
    print(f"test data count : {len(test_X)}")

    if shuffle:
        print("shuffle dataset ...")
        train_p = np.random.permutation(len(train_y))
        test_p = np.random.permutation(len(test_y))

        train_X, train_y = train_X[train_p], train_y[train_p]
        test_X, test_y = test_X[test_p], test_y[test_p]

    if bucket:
        print("sorted by inputs length and outputs length ...")
        # The rows are already padded to max_seq_length, so len(row) is the
        # same for every pair (and the old key also used len([x[1]]), which
        # is always 1). Count the non-PAD ids instead.
        train_X, train_y = _sort_by_total_length(train_X, train_y)
        test_X, test_y = _sort_by_total_length(test_X, test_y)

    return train_X, test_X, train_y, test_y
def load_data(enc_fname, dec_fname):
    """Load paired id files and return padded int32 arrays `(enc, dec)`.

    A pair is kept only when
    * both sides are non-empty,
    * the encoder side has at most `max_seq_length` ids and the decoder side
      strictly fewer, and
    * `|len(dec) - len(enc)| / (len(enc) + len(dec))` is below
      `Config.data.sentence_diff`.
    Kept sequences are padded with PAD_ID to `max_seq_length`.
    """
    data_dir = os.path.join(Config.data.base_path, Config.data.processed_path)
    max_len = Config.data.max_seq_length

    enc_data, dec_data = [], []
    # `with` closes both handles; they were previously left open.
    with open(os.path.join(data_dir, enc_fname), 'r') as enc_input_data, \
            open(os.path.join(data_dir, dec_fname), 'r') as dec_input_data:
        for e_line, d_line in tqdm(zip(enc_input_data, dec_input_data)):
            e_ids = [int(id_) for id_ in e_line.split()]
            d_ids = [int(id_) for id_ in d_line.split()]

            if not e_ids or not d_ids:
                continue

            # NOTE(review): the decoder bound is strict while the encoder
            # bound is not; kept as in the original -- confirm it is intended.
            if len(e_ids) > max_len or len(d_ids) >= max_len:
                continue

            length_diff = abs(len(d_ids) - len(e_ids)) / (len(e_ids) + len(d_ids))
            if length_diff < Config.data.sentence_diff:
                enc_data.append(_pad_input(e_ids, max_len))
                dec_data.append(_pad_input(d_ids, max_len))

    print(f"load data from {enc_fname}, {dec_fname}...")
    return np.array(enc_data, dtype=np.int32), np.array(dec_data, dtype=np.int32)


def _pad_input(input_, size):
    """Return `input_` right-padded with PAD_ID to `size` items.

    A list already `size` long or longer is returned as an unpadded copy.
    """
    return input_ + [Config.data.PAD_ID] * (size - len(input_))


def set_max_seq_length(dataset_fnames):
    """Raise `Config.data.max_seq_length` to the longest sequence found.

    Starts from the configured value (10 when unset) and scans every id file
    in `dataset_fnames`; the value is never lowered.
    """
    max_seq_length = Config.data.get('max_seq_length', 10)
    data_dir = os.path.join(Config.data.base_path, Config.data.processed_path)

    for fname in dataset_fnames:
        with open(os.path.join(data_dir, fname), 'r') as input_data:
            for line in input_data:
                # int() keeps the original validation of every id
                ids = [int(id_) for id_ in line.split()]
                max_seq_length = max(max_seq_length, len(ids))

    Config.data.max_seq_length = max_seq_length
    print(f"Setting max_seq_length to Config : {max_seq_length}")
def make_batch(data, buffer_size=10000, batch_size=64, scope="train"):
    """Build an input function and the hook that initialises its iterator.

    * Args:
        data: pair `(X, y)` fed into the dataset placeholders
        buffer_size: accepted but unused (the shuffle call is commented out)
        batch_size: number of examples per batch
        scope: name scope; "train" repeats the data forever, any other
            value yields a single epoch
    * Returns:
        `(train_inputs, iterator_initializer_hook)`
    """

    class IteratorInitializerHook(tf.train.SessionRunHook):
        """Hook to initialise data iterator after Session is created."""

        def __init__(self):
            super(IteratorInitializerHook, self).__init__()
            # Filled in by train_inputs() once the iterator exists.
            self.iterator_initializer_func = None

        def after_create_session(self, session, coord):
            """Initialise the iterator after the session has been created."""
            self.iterator_initializer_func(session)


    def get_inputs():
        """Create the hook and the input function that shares it."""

        iterator_initializer_hook = IteratorInitializerHook()

        def train_inputs():
            """Build the dataset graph and return the next `(X, y)` batch tensors."""
            with tf.name_scope(scope):

                X, y = data

                # Define placeholders
                input_placeholder = tf.placeholder(
                    tf.int32, [None, Config.data.max_seq_length])
                output_placeholder = tf.placeholder(
                    tf.int32, [None, Config.data.max_seq_length])

                # Build dataset iterator
                dataset = tf.data.Dataset.from_tensor_slices(
                    (input_placeholder, output_placeholder))

                if scope == "train":
                    dataset = dataset.repeat(None)  # Infinite iterations
                else:
                    dataset = dataset.repeat(1)  # 1 Epoch
                # dataset = dataset.shuffle(buffer_size=buffer_size)
                dataset = dataset.batch(batch_size)

                iterator = dataset.make_initializable_iterator()
                next_X, next_y = iterator.get_next()

                # Named ops exposing the first example of each batch.
                tf.identity(next_X[0], 'enc_0')
                tf.identity(next_y[0], 'dec_0')

                # Set runhook to initialize iterator
                iterator_initializer_hook.iterator_initializer_func = \
                    lambda sess: sess.run(
                        iterator.initializer,
                        feed_dict={input_placeholder: X,
                                   output_placeholder: y})

                # Return batched (features, labels)
                return next_X, next_y

        # Return function and hook
        return train_inputs, iterator_initializer_hook

    return get_inputs()


if __name__ == '__main__':

    # Script entry: load the named config, then run both preprocessing steps.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--config', type=str, default='config',
                        help='config file name')
    args = parser.parse_args()

    Config(args.config)

    prepare_raw_data()
    process_data()