├── README.md
├── __pycache__
│   ├── chat_command_handler.cpython-36.pyc
│   ├── chat_settings.cpython-36.pyc
│   ├── chatbot_model.cpython-36.pyc
│   ├── general_utils.cpython-36.pyc
│   ├── hparams.cpython-36.pyc
│   └── vocabulary.cpython-36.pyc
├── chat.py
├── chat_command_handler.py
├── chat_settings.py
├── chat_ui.html
├── chat_web.py
├── chatbot_model.py
├── chatbot_screenshots
│   ├── seq1.png
│   ├── seq10.png
│   ├── seq11.png
│   ├── seq2.png
│   ├── seq3.png
│   ├── seq4.png
│   ├── seq5.png
│   ├── seq6.png
│   ├── seq7.png
│   ├── seq8.png
│   └── seq9.png
├── dataset.py
├── dataset_readers
│   ├── cornell_dataset_reader.py
│   ├── csv_dataset_reader.py
│   ├── dataset_reader.py
│   └── dataset_reader_factory.py
├── datasets
│   ├── cornell_movie_dialog
│   │   ├── README.md
│   │   ├── movie_conversations.txt
│   │   ├── movie_lines.txt
│   │   ├── train_with_dependency_based_embeddings.bat
│   │   ├── train_with_dependency_based_embeddings_decoder_only.bat
│   │   ├── train_with_dependency_based_embeddings_encoder_only.bat
│   │   ├── train_with_nnlm_en_embeddings.bat
│   │   ├── train_with_nnlm_en_embeddings_decoder_only.bat
│   │   ├── train_with_nnlm_en_embeddings_encoder_only.bat
│   │   ├── train_with_random_embeddings.bat
│   │   ├── train_with_word2vec_wikipedia_embeddings.bat
│   │   ├── train_with_word2vec_wikipedia_embeddings_decoder_only.bat
│   │   └── train_with_word2vec_wikipedia_embeddings_encoder_only.bat
│   └── csv
│       ├── README.md
│       ├── csv_data.csv
│       ├── train_with_dependency_based_embeddings.bat
│       ├── train_with_dependency_based_embeddings_decoder_only.bat
│       ├── train_with_dependency_based_embeddings_encoder_only.bat
│       ├── train_with_nnlm_en_embeddings.bat
│       ├── train_with_nnlm_en_embeddings_decoder_only.bat
│       ├── train_with_nnlm_en_embeddings_encoder_only.bat
│       ├── train_with_random_embeddings.bat
│       ├── train_with_word2vec_wikipedia_embeddings.bat
│       ├── train_with_word2vec_wikipedia_embeddings_decoder_only.bat
│       └── train_with_word2vec_wikipedia_embeddings_encoder_only.bat
├── general_utils.py
├── hparams.json
├── hparams.py
├── models
│   └── cornell_movie_dialog
│       └── README.md
├── roadmap.md
├── train.py
├── train_console_helper.py
├── training_stats.py
├── vocabulary.py
└── vocabulary_importers
    ├── __pycache__
    │   └── vocabulary_importer.cpython-36.pyc
    ├── checkpoint_vocabulary_importer.py
    ├── dependency_based_vocabulary_importer.py
    ├── flatfile_vocabulary_importer.py
    ├── nnlm_en_vocabulary_importer.py
    ├── vocabulary_importer.py
    ├── vocabulary_importer_factory.py
    └── word2vec_wikipedia_vocabulary_importer.py

/README.md:
--------------------------------------------------------------------------------
1 | # seq2seq-chatbot
2 | A sequence-to-sequence (seq2seq) chatbot implementation with TensorFlow.
3 | 
4 | ## Chatting with a trained model
5 | 
6 | For console chat:
7 | 1. Run `chat_console_best_weights_training.bat` or `chat_console_best_weights_validation.bat`
8 | 
9 | For web chat:
10 | 1. Run `chat_web_best_weights_training.bat` or `chat_web_best_weights_validation.bat`
11 | 
12 | 2. Open a browser to the URL indicated by the server console, followed by `/chat_ui.html`. This is typically: [http://localhost:8080/chat_ui.html](http://localhost:8080/chat_ui.html)
13 | 
14 | ### To chat with a trained model from a python console (e.g. IPython):
15 | 
16 | 1. Set console working directory to the **seq2seq-chatbot** directory. This directory should have the **models** and **datasets** directories directly within it.
17 | 
18 | 2. Run chat.py with the model checkpoint path:
19 | ```shell
20 | run chat.py models\cornell_movie_dialog\trained_model\best_weights_training.ckpt
21 | ```
22 | 
23 | ## Training a model
24 | To train a model from a python console:
25 | 
26 | 1. To train a new model, run train.py with the dataset path:
27 | ```shell
28 | run train.py --datasetdir=datasets\dataset_name
29 | ```
30 | 
31 | Or to resume training an existing model, run train.py with the model checkpoint path:
32 | ```shell
33 | run train.py --checkpointfile=models\dataset_name\model_name\checkpoint.ckpt
34 | ```
35 | ## Visualizing a model in TensorBoard
36 | 
37 | To start TensorBoard from a terminal:
38 | ```shell
39 | tensorboard --logdir=model_dir
40 | ```
41 | ## Dependencies
42 | The following python packages are used in seq2seq-chatbot:
43 | (excluding packages that come with Anaconda)
44 | 
45 | - [TensorFlow](https://www.tensorflow.org/)
46 | ```shell
47 | pip install --upgrade tensorflow
48 | ```
49 | For GPU support: [(See here for full GPU install instructions including CUDA and cuDNN)](https://www.tensorflow.org/install/)
50 | ```shell
51 | pip install --upgrade tensorflow-gpu
52 | ```
53 | 
54 | - [jsonpickle](https://jsonpickle.github.io/)
55 | ```shell
56 | pip install --upgrade jsonpickle
57 | ```
58 | 
59 | - [flask 0.12.4](http://flask.pocoo.org/) and [flask-restful](https://flask-restful.readthedocs.io/en/latest/) (required to run the web interface)
60 | ```shell
61 | pip install flask==0.12.4
62 | pip install --upgrade flask-restful
63 | ```
64 | 
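65 | 
66 | ## Calling the web chat API directly
67 | 
68 | Note: chat_web.py also imports `flask_cors`, which is not listed above; it installs with `pip install flask-cors`.
69 | 
70 | Once the web server is running, the endpoint used by the chat UI can also be called from any HTTP client. A minimal sketch, assuming the server is listening on port 8080 and the `requests` package is installed (it is not a project dependency):
71 | ```python
72 | import requests
73 | from urllib.parse import quote
74 | 
75 | # The question travels URL-encoded in the path: GET /chat/<string:question>
76 | question = "How are you?"
77 | response = requests.get("http://localhost:8080/chat/" + quote(question))
78 | response.raise_for_status()
79 | 
80 | # flask-restful serializes the returned answer string as JSON
81 | print("ChatBot: {0}".format(response.json()))
82 | ```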
--------------------------------------------------------------------------------
/chat.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for chatting with a trained chatbot model
3 | """
4 | import datetime
5 | from os import path
6 | 
7 | import general_utils
8 | import chat_command_handler
9 | from chat_settings import ChatSettings
10 | from chatbot_model import ChatbotModel
11 | from vocabulary import Vocabulary
12 | 
13 | #Read the hyperparameters and configure paths
14 | _, model_dir, hparams, checkpoint, _, _ = general_utils.initialize_session("chat")
15 | 
16 | #Load the vocabulary
17 | print()
18 | print("Loading vocabulary...")
19 | if hparams.model_hparams.share_embedding:
20 |     shared_vocab_filepath = path.join(model_dir, Vocabulary.SHARED_VOCAB_FILENAME)
21 |     input_vocabulary = Vocabulary.load(shared_vocab_filepath)
22 |     output_vocabulary = input_vocabulary
23 | else:
24 |     input_vocab_filepath = path.join(model_dir, Vocabulary.INPUT_VOCAB_FILENAME)
25 |     input_vocabulary = Vocabulary.load(input_vocab_filepath)
26 |     output_vocab_filepath = path.join(model_dir, Vocabulary.OUTPUT_VOCAB_FILENAME)
27 |     output_vocabulary = Vocabulary.load(output_vocab_filepath)
28 | 
29 | # Setting up the chat
30 | chatlog_filepath = path.join(model_dir, "chat_logs", "chatlog_{0}.txt".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S")))
31 | chat_settings = ChatSettings(hparams.model_hparams, hparams.inference_hparams)
32 | terminate_chat = False
33 | reload_model = False
34 | while not terminate_chat:
35 |     #Create the model
36 |     print()
37 |     print("Initializing model..."
if not reload_model else "Re-initializing model...") 38 | print() 39 | with ChatbotModel(mode = "infer", 40 | model_hparams = chat_settings.model_hparams, 41 | input_vocabulary = input_vocabulary, 42 | output_vocabulary = output_vocabulary, 43 | model_dir = model_dir) as model: 44 | 45 | #Load the weights 46 | print() 47 | print("Loading model weights...") 48 | print() 49 | model.load(checkpoint) 50 | 51 | #Show the commands 52 | if not reload_model: 53 | chat_command_handler.print_commands() 54 | 55 | while True: 56 | #Get the input and check if it is a question or a command, and execute if it is a command 57 | question = input("You: ") 58 | is_command, terminate_chat, reload_model = chat_command_handler.handle_command(question, model, chat_settings) 59 | if terminate_chat or reload_model: 60 | break 61 | elif is_command: 62 | continue 63 | else: 64 | #If it is not a command (it is a question), pass it on to the chatbot model to get the answer 65 | question_with_history, answer = model.chat(question, chat_settings) 66 | 67 | #Print the answer or answer beams and log to chat log 68 | if chat_settings.show_question_context: 69 | print("Question with history (context): {0}".format(question_with_history)) 70 | 71 | if chat_settings.show_all_beams: 72 | for i in range(len(answer)): 73 | print("ChatBot (Beam {0}): {1}".format(i, answer[i])) 74 | else: 75 | print("ChatBot: {0}".format(answer)) 76 | 77 | print() 78 | 79 | if chat_settings.inference_hparams.log_chat: 80 | chat_command_handler.append_to_chatlog(chatlog_filepath, question, answer) -------------------------------------------------------------------------------- /chat_command_handler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Command handler for chat session 3 | """ 4 | import os 5 | 6 | def append_to_chatlog(chatlog_filepath, question, answer): 7 | """Append a question and answer to the chat log. 8 | 9 | Args: 10 | chatlog_filepath: Path to the chat log file 11 | 12 | question: the question string entered by the user 13 | 14 | answer: the answer string returned by the chatbot 15 | If chat_settings.show_all_beams = True, answer is the array of all answer beams with one string per beam. 16 | """ 17 | chatlog_dir = os.path.dirname(chatlog_filepath) 18 | if not os.path.isdir(chatlog_dir): 19 | os.makedirs(chatlog_dir) 20 | with open(chatlog_filepath, "a", encoding="utf-8") as file: 21 | file.write("You: {0}".format(question)) 22 | file.write('\n') 23 | file.write("ChatBot: {0}".format(answer)) 24 | file.write('\n\n') 25 | 26 | def print_commands(): 27 | """Print the list of available commands and their descriptions. 
28 |     """
29 |     print()
30 |     print()
31 |     print("Commands:")
32 |     print("-----------General-----------------")
33 |     print("--help (Show this list of commands); --reset (Reset to default settings from hparams.json [*]);")
34 |     print("--exit (Quit);")
35 |     print()
36 |     print("-----------Chat Options:-----------")
37 |     print("--enableautopunct (Auto add punctuation to questions); --disableautopunct (Enter punctuation exactly as typed);")
38 |     print("--enablenormwords (Auto replace 'don't' with 'do not', etc.); --disablenormwords (Enter words exactly as typed);")
39 |     print("--showquestioncontext (Show conversation history as context); --hidequestioncontext (Show questions only);")
40 |     print("--showbeams (Output all predicted beams); --hidebeams (Output only the highest ranked beam);")
41 |     print("--convhistlength=N (Set conversation history length to N); --clearconvhist (Clear history and start a new conversation);")
42 |     print()
43 |     print("-----------Model Options:----------")
44 |     print("--beamwidth=N (Set beam width to N. 0 disables beamsearch [*]); --beamlenpenalty=N (Set beam length penalty to N);")
45 |     print("--enablesampling (Use sampling decoder if beamwidth=0 [*]); --disablesampling (Use greedy decoder if beamwidth=0 [*]);")
46 |     print("--samplingtemp=N (Set sampling temperature to N); --maxanswerlen=N (Set max words in answer to N);")
47 |     print()
48 |     print()
49 |     print("[*] Causes model to reload")
50 |     print()
51 |     print()
52 | 
53 | def handle_command(input_str, model, chat_settings):
54 |     """Given a user input string, determine if it is a command or a question and process it if it is a command.
55 | 
56 |     Args:
57 |         input_str: the user input string
58 | 
59 |         model: the ChatbotModel instance
60 | 
61 |         chat_settings: the ChatSettings instance
62 |     """
63 |     reload_model = False
64 |     terminate_chat = False
65 |     is_command = True
66 |     cmd_value = _get_command_value(input_str)
67 |     #General Commands
68 |     if input_str == '--help':
69 |         print_commands()
70 |     elif input_str == '--reset':
71 |         chat_settings.reset_to_defaults()
72 |         reload_model = True
73 |         print ("[Reset to default settings.]")
74 |     elif input_str == '--exit':
75 |         terminate_chat = True
76 |     #Chat Options
77 |     elif input_str == '--enableautopunct':
78 |         chat_settings.enable_auto_punctuation = True
79 |         print ("[Auto-punctuation enabled.]")
80 |     elif input_str == '--disableautopunct':
81 |         chat_settings.enable_auto_punctuation = False
82 |         print ("[Auto-punctuation disabled.]")
83 |     elif input_str == '--enablenormwords':
84 |         chat_settings.inference_hparams.normalize_words = True
85 |         print ("[Word normalization enabled.]")
86 |     elif input_str == '--disablenormwords':
87 |         chat_settings.inference_hparams.normalize_words = False
88 |         print ("[Word normalization disabled.]")
89 |     elif input_str == '--showquestioncontext':
90 |         chat_settings.show_question_context = True
91 |         print ("[Show question context enabled.]")
92 |     elif input_str == "--hidequestioncontext":
93 |         chat_settings.show_question_context = False
94 |         print ("[Show question context disabled.]")
95 |     elif input_str == '--showbeams':
96 |         chat_settings.show_all_beams = True
97 |         print ("[Show all beams enabled.]")
98 |     elif input_str == "--hidebeams":
99 |         chat_settings.show_all_beams = False
100 |         print ("[Show all beams disabled.]")
101 |     elif input_str.startswith("--convhistlength"):
102 |         if cmd_value is not None:
103 |             chat_settings.inference_hparams.conv_history_length = int(cmd_value)
104 | 
model.trim_conversation_history(chat_settings.inference_hparams.conv_history_length) 105 | print ("[Conversation history length set to {0}.]".format(chat_settings.inference_hparams.conv_history_length)) 106 | elif input_str == '--clearconvhist': 107 | model.trim_conversation_history(0) 108 | print ("[Conversation history cleared.]") 109 | #Model Options 110 | elif input_str.startswith("--beamwidth"): 111 | if cmd_value is not None: 112 | chat_settings.model_hparams.beam_width = int(cmd_value) 113 | reload_model = True 114 | print ("[Beam width set to {0}.]".format(chat_settings.model_hparams.beam_width)) 115 | elif input_str.startswith("--beamlenpenalty"): 116 | if cmd_value is not None: 117 | chat_settings.inference_hparams.beam_length_penalty_weight = float(cmd_value) 118 | print ("[Beam length penalty weight set to {0}.]".format(chat_settings.inference_hparams.beam_length_penalty_weight)) 119 | elif input_str == '--enablesampling': 120 | chat_settings.model_hparams.enable_sampling = True 121 | if chat_settings.model_hparams.beam_width == 0: 122 | reload_model = True 123 | print ("[Sampling decoder enabled (if beamwidth=0).]") 124 | elif input_str == '--disablesampling': 125 | chat_settings.model_hparams.enable_sampling = False 126 | if chat_settings.model_hparams.beam_width == 0: 127 | reload_model = True 128 | print ("[Sampling decoder disabled. Using greedy decoder (if beamwidth=0).]") 129 | elif input_str.startswith("--samplingtemp"): 130 | if cmd_value is not None: 131 | chat_settings.inference_hparams.sampling_temperature = float(cmd_value) 132 | print ("[Sampling temperature set to {0}.]".format(chat_settings.inference_hparams.sampling_temperature)) 133 | elif input_str.startswith("--maxanswerlen"): 134 | if cmd_value is not None: 135 | chat_settings.inference_hparams.max_answer_words = int(cmd_value) 136 | print ("[Max words in answer set to {0}.]".format(chat_settings.inference_hparams.max_answer_words)) 137 | #Not a command 138 | else: 139 | is_command = False 140 | 141 | return is_command, terminate_chat, reload_model 142 | 143 | def _get_command_value(input_str): 144 | """Parses a command string and returns the value to the right of the '=' sign 145 | 146 | Args: 147 | input_str: the command string 148 | """ 149 | idx = input_str.find("=") 150 | if idx > -1: 151 | return input_str[idx+1:].strip() 152 | else: 153 | return None -------------------------------------------------------------------------------- /chat_settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | ChatSettings class 3 | """ 4 | import copy 5 | 6 | class ChatSettings(object): 7 | """Contains settings for a chat session. 
8 |     """
9 |     def __init__(self, model_hparams, inference_hparams):
10 |         """
11 |         Args:
12 |             model_hparams: the loaded model hparams instance to use as default for this chat session
13 | 
14 |             inference_hparams: the loaded InferenceHparams instance to use as default for this chat session
15 |         """
16 |         self.show_question_context = False
17 |         self.show_all_beams = False
18 |         self.enable_auto_punctuation = True
19 |         self.model_hparams = None
20 |         self.inference_hparams = None
21 | 
22 |         self._default_model_hparams = model_hparams
23 |         self._default_inference_hparams = inference_hparams
24 |         self.reset_to_defaults()
25 | 
26 |     def reset_to_defaults(self):
27 |         """Reset all settings to defaults
28 |         """
29 |         self.show_question_context = False
30 |         self.show_all_beams = False
31 |         self.enable_auto_punctuation = True
32 |         self.model_hparams = copy.copy(self._default_model_hparams)
33 |         self.inference_hparams = copy.copy(self._default_inference_hparams)
--------------------------------------------------------------------------------
/chat_ui.html:
--------------------------------------------------------------------------------
[chat_ui.html (~189 lines): the HTML/JavaScript markup of this file was stripped during extraction; only the page title "ChatBot" and the original line numbering survived. The page is the browser chat UI served by chat_web.py, which presumably calls its /chat/<string:question> endpoint.]
--------------------------------------------------------------------------------
/chat_web.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for serving a trained chatbot model over http
3 | """
4 | import datetime
5 | import click
6 | from os import path
7 | from flask import Flask, request, send_from_directory
8 | from flask_cors import CORS
9 | from flask_restful import Resource, Api
10 | 
11 | import general_utils
12 | import chat_command_handler
13 | from chat_settings import ChatSettings
14 | from chatbot_model import ChatbotModel
15 | from vocabulary import Vocabulary
16 | 
17 | app = Flask(__name__)
18 | CORS(app)
19 | 
20 | @app.cli.command()
21 | @click.argument("checkpointfile")
22 | @click.option("-p", "--port", type=int)
23 | def serve_chat(checkpointfile, port):
24 | 
25 |     api = Api(app)
26 | 
27 |     #Read the hyperparameters and configure paths
28 |     model_dir, hparams, checkpoint = general_utils.initialize_session_server(checkpointfile)
29 | 
30 |     #Load the vocabulary
31 |     print()
32 |     print ("Loading vocabulary...")
33 |     if hparams.model_hparams.share_embedding:
34 |         shared_vocab_filepath = path.join(model_dir, Vocabulary.SHARED_VOCAB_FILENAME)
35 |         input_vocabulary = Vocabulary.load(shared_vocab_filepath)
36 |         output_vocabulary = input_vocabulary
37 |     else:
38 |         input_vocab_filepath = path.join(model_dir, Vocabulary.INPUT_VOCAB_FILENAME)
39 |         input_vocabulary = Vocabulary.load(input_vocab_filepath)
40 |         output_vocab_filepath = path.join(model_dir, Vocabulary.OUTPUT_VOCAB_FILENAME)
41 |         output_vocabulary = Vocabulary.load(output_vocab_filepath)
42 | 
43 |     #Create the model
44 |     print ("Initializing model...")
45 |     print()
46 |     with ChatbotModel(mode = "infer",
47 |                       model_hparams = hparams.model_hparams,
48 |                       input_vocabulary = input_vocabulary,
49 |                       output_vocabulary = output_vocabulary,
50 |                       model_dir = model_dir) as model:
51 | 
52 |         #Load the weights
53 |         print()
54 |         print ("Loading model weights...")
55 |         model.load(checkpoint)
56 | 
57 |         # Setting up the chat
58 |         chatlog_filepath = path.join(model_dir, "chat_logs", "web_chatlog_{0}.txt".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S")))
59 |         chat_settings = ChatSettings(hparams.model_hparams, hparams.inference_hparams)
60 |         chat_command_handler.print_commands()
61 | 
62 |         class Answer(Resource):
63 |             def get(self, question):
64 |                 is_command, terminate_chat, _ = chat_command_handler.handle_command(question, model, chat_settings)
65 |                 if terminate_chat:
66 |                     answer = "[Can't terminate from http request]"
67 |                 elif is_command:
68 |                     answer = "[Command processed]"
69 |                 else:
70 |                     #If it is not a command (it is a question), pass it on to the chatbot model to get the answer
71 |                     _, answer = model.chat(question, chat_settings)
72 | 
73 |                 if chat_settings.inference_hparams.log_chat:
74 |                     chat_command_handler.append_to_chatlog(chatlog_filepath, question, answer)
75 | 
76 |                 return answer
77 | 
78 |         class UI(Resource):
79 |             def get(self):
80 |                 return send_from_directory(".", "chat_ui.html")
81 | 
82 |         api.add_resource(Answer, "/chat/<string:question>")
83 |         api.add_resource(UI, "/chat_ui/")
84 |         app.run(debug=False, port=port)
--------------------------------------------------------------------------------
/chatbot_model.py:
--------------------------------------------------------------------------------
1 | """
2 | ChatbotModel class
3 | """
4 | import numpy as np
5 | import tensorflow as tf
6 | from tensorflow.python.layers import core as layers_core
7 | from tensorflow.contrib.tensorboard.plugins import projector
8 | from os import path
9 | 
10 | from vocabulary import Vocabulary
11 | 
12 | class ChatbotModel(object):
13 |     """Seq2Seq chatbot model class.
14 | 
15 |     This class encapsulates all interaction with TensorFlow. Users of this class need not be aware of the dependency on TF.
16 |     This abstraction allows for future use of alternative DL libraries by writing ChatbotModel implementations for them
17 |     without the calling code needing to change.
18 | 
19 |     This class implementation was influenced by studying the TensorFlow Neural Machine Translation (NMT) tutorial:
20 |     - https://www.tensorflow.org/tutorials/seq2seq
21 |     These files were helpful in understanding the seq2seq implementation:
22 |     - https://github.com/tensorflow/nmt/blob/master/nmt/model.py
23 |     - https://github.com/tensorflow/nmt/blob/master/nmt/model_helper.py
24 |     - https://github.com/tensorflow/nmt/blob/master/nmt/attention_model.py
25 | 
26 |     Although the code in this implementation is original, there are some similarities with the NMT tutorial in certain places.
27 |     """
28 | 
29 |     def __init__(self,
30 |                  mode,
31 |                  model_hparams,
32 |                  input_vocabulary,
33 |                  output_vocabulary,
34 |                  model_dir):
35 |         """Create the Seq2Seq chatbot model.
36 | 
37 |         Args:
38 |             mode: "train" or "infer"
39 | 
40 |             model_hparams: parameters which determine the architecture and complexity of the model.
41 |                 See hparams.py for in-depth comments.
42 | 
43 |             input_vocabulary: the input vocabulary. Each word in this vocabulary gets its own vector
44 |                 in the encoder embedding matrix.
45 | 
46 |             output_vocabulary: the output vocabulary. Each word in this vocabulary gets its own vector
47 |                 in the decoder embedding matrix.
48 | 
49 |             model_dir: the directory to output model summaries and load / save checkpoints.
50 | """ 51 | self.mode = mode 52 | self.model_hparams = model_hparams 53 | self.input_vocabulary = input_vocabulary 54 | self.output_vocabulary = output_vocabulary 55 | self.model_dir = model_dir 56 | 57 | #Validate arguments 58 | if self.model_hparams.share_embedding and self.input_vocabulary.size() != self.output_vocabulary.size(): 59 | raise ValueError("Cannot share embedding matrices when the input and output vocabulary sizes are different.") 60 | 61 | if self.model_hparams.share_embedding and self.model_hparams.encoder_embedding_size != self.model_hparams.decoder_embedding_size: 62 | raise ValueError("Cannot share embedding matrices when the encoder and decoder embedding sizes are different.") 63 | 64 | if self.input_vocabulary.external_embeddings is not None and self.input_vocabulary.external_embeddings.shape[1] != self.model_hparams.encoder_embedding_size: 65 | raise ValueError("Cannot use external embeddings with vector size {0} for the encoder when the hparams encoder embedding size is {1}".format(self.input_vocabulary.external_embeddings.shape[1], 66 | self.model_hparams.encoder_embedding_size)) 67 | if self.output_vocabulary.external_embeddings is not None and self.output_vocabulary.external_embeddings.shape[1] != self.model_hparams.decoder_embedding_size: 68 | raise ValueError("Cannot use external embeddings with vector size {0} for the decoder when the hparams decoder embedding size is {1}".format(self.output_vocabulary.external_embeddings.shape[1], 69 | self.model_hparams.decoder_embedding_size)) 70 | 71 | tf.contrib.learn.ModeKeys.validate(self.mode) 72 | 73 | if self.mode == tf.contrib.learn.ModeKeys.TRAIN or self.model_hparams.beam_width is None: 74 | self.beam_width = 0 75 | else: 76 | self.beam_width = self.model_hparams.beam_width 77 | 78 | #Reset the default TF graph 79 | tf.reset_default_graph() 80 | 81 | #Define general model inputs 82 | self.inputs = tf.placeholder(tf.int32, [None, None], name = "inputs") 83 | self.input_sequence_length = tf.placeholder(tf.int32, [None], name = "input_sequence_length") 84 | 85 | #Build model 86 | initializer_feed_dict = {} 87 | if self.mode == tf.contrib.learn.ModeKeys.TRAIN: 88 | #Define training model inputs 89 | self.targets = tf.placeholder(tf.int32, [None, None], name = "targets") 90 | self.target_sequence_length = tf.placeholder(tf.int32, [None], name = "target_sequence_length") 91 | self.learning_rate = tf.placeholder(tf.float32, name= "learning_rate") 92 | self.keep_prob = tf.placeholder(tf.float32, name = "keep_prob") 93 | if self.input_vocabulary.external_embeddings is not None: 94 | self.input_external_embeddings = tf.placeholder(tf.float32, 95 | shape = self.input_vocabulary.external_embeddings.shape, 96 | name = "input_external_embeddings") 97 | initializer_feed_dict[self.input_external_embeddings] = self.input_vocabulary.external_embeddings 98 | if self.output_vocabulary.external_embeddings is not None and not self.model_hparams.share_embedding: 99 | self.output_external_embeddings = tf.placeholder(tf.float32, 100 | shape = self.output_vocabulary.external_embeddings.shape, 101 | name = "output_external_embeddings") 102 | initializer_feed_dict[self.output_external_embeddings] = self.output_vocabulary.external_embeddings 103 | 104 | self.loss, self.training_step = self._build_model() 105 | 106 | elif self.mode == tf.contrib.learn.ModeKeys.INFER: 107 | #Define inference model inputs 108 | self.max_output_sequence_length = tf.placeholder(tf.int32, [], name="max_output_sequence_length") 109 | 
self.beam_length_penalty_weight = tf.placeholder(tf.float32, name = "beam_length_penalty_weight")
110 |             self.sampling_temperature = tf.placeholder(tf.float32, name = "sampling_temperature")
111 | 
112 |             self.predictions, self.predictions_seq_lengths = self._build_model()
113 |             self.conversation_history = []
114 |         else:
115 |             raise ValueError("Unsupported model mode. Choose 'train' or 'infer'.")
116 | 
117 |         # Get the final merged summary for writing to TensorBoard
118 |         self.merged_summary = tf.summary.merge_all()
119 | 
120 |         # Defining the session, summary writer, and checkpoint saver
121 |         self.session = self._create_session()
122 | 
123 |         self.summary_writer = tf.summary.FileWriter(self.model_dir, self.session.graph)
124 | 
125 |         self.session.run(tf.global_variables_initializer(), initializer_feed_dict)
126 | 
127 |         self.saver = tf.train.Saver()
128 | 
129 |     def load(self, filename):
130 |         """Loads a trained model from a checkpoint
131 | 
132 |         Args:
133 |             filename: Checkpoint filename, such as best_model_checkpoint.ckpt
134 |                 This file must exist within model_dir.
135 |         """
136 |         filepath = path.join(self.model_dir, filename)
137 |         self.saver.restore(self.session, filepath)
138 | 
139 |     def save(self, filename):
140 |         """Saves a checkpoint of the current model weights
141 | 
142 |         Args:
143 |             filename: Checkpoint filename, such as best_model_checkpoint.ckpt.
144 |                 This file will be created (or overwritten) within model_dir.
145 |         """
146 |         filepath = path.join(self.model_dir, filename)
147 |         self.saver.save(self.session, filepath)
148 | 
149 |         config = projector.ProjectorConfig()
150 |         if self.model_hparams.share_embedding:
151 |             shared_embedding = config.embeddings.add()
152 |             shared_embedding.tensor_name = "model/encoder/shared_embeddings_matrix"
153 |             shared_embedding.metadata_path = Vocabulary.SHARED_VOCAB_FILENAME
154 |         else:
155 |             encoder_embedding = config.embeddings.add()
156 |             encoder_embedding.tensor_name = "model/encoder/encoder_embeddings_matrix"
157 |             encoder_embedding.metadata_path = Vocabulary.INPUT_VOCAB_FILENAME
158 |             decoder_embedding = config.embeddings.add()
159 |             decoder_embedding.tensor_name = "model/decoder/decoder_embeddings_matrix"
160 |             decoder_embedding.metadata_path = Vocabulary.OUTPUT_VOCAB_FILENAME
161 | 
162 |         projector.visualize_embeddings(self.summary_writer, config)
163 | 
164 |     def train_batch(self, inputs, targets, input_sequence_length, target_sequence_length, learning_rate, dropout, global_step, log_summary=True):
165 |         """Train the model on one batch, and return the training loss.
166 | 
167 |         Args:
168 |             inputs: The input matrix of shape (batch_size, sequence_length)
169 |                 where each value in the sequences are words encoded as integer indexes of the input vocabulary.
170 | 
171 |             targets: The target matrix of shape (batch_size, sequence_length)
172 |                 where each value in the sequences are words encoded as integer indexes of the output vocabulary.
173 | 
174 |             input_sequence_length: A vector of sequence lengths of shape (batch_size)
175 |                 containing the lengths of every input sequence in the batch. This allows for dynamic sequence lengths.
176 | 
177 |             target_sequence_length: A vector of sequence lengths of shape (batch_size)
178 |                 containing the lengths of every target sequence in the batch. This allows for dynamic sequence lengths.
179 | 
180 |             learning_rate: The learning rate to use for the weight updates.
181 | 
182 |             dropout: The probability (0 <= p <= 1) that any neuron will be randomly disabled for this training step.
183 | This regularization technique can allow the model to learn more independent relationships in the input data 184 | and reduce overfitting. Too much dropout can make the model underfit. Typical values range between 0.2 - 0.5 185 | 186 | global_step: The index of this training step across all batches and all epochs. This allows TensorBoard to trend 187 | the training visually over time. 188 | 189 | log_summary: Flag indicating if the training summary should be logged (for visualization in TensorBoard). 190 | """ 191 | 192 | if self.mode != tf.contrib.learn.ModeKeys.TRAIN: 193 | raise ValueError("train_batch can only be called when the model is initialized in train mode.") 194 | 195 | #Calculate the keep_probability (prob. a neuron will not be dropped) as 1 - dropout rate 196 | keep_probability = 1.0 - dropout 197 | 198 | #Train on the batch 199 | _, batch_training_loss, merged_summary = self.session.run([self.training_step, self.loss, self.merged_summary], 200 | { self.inputs: inputs, 201 | self.targets: targets, 202 | self.input_sequence_length: input_sequence_length, 203 | self.target_sequence_length: target_sequence_length, 204 | self.learning_rate: learning_rate, 205 | self.keep_prob: keep_probability }) 206 | 207 | #Write the training summary for this step if summary logging is enabled. 208 | if log_summary: 209 | self.summary_writer.add_summary(merged_summary, global_step) 210 | 211 | return batch_training_loss 212 | 213 | def validate_batch(self, inputs, targets, input_sequence_length, target_sequence_length, metric = "loss"): 214 | """Evaluate the metric on one batch and return. 215 | 216 | Args: 217 | inputs: The input matrix of shape (batch_size, sequence_length) 218 | where each value in the sequences are words encoded as integer indexes of the input vocabulary. 219 | 220 | targets: The target matrix of shape (batch_size, sequence_length) 221 | where each value in the sequences are words encoded as integer indexes of the output vocabulary. 222 | 223 | input_sequence_length: A vector of sequence lengths of shape (batch_size) 224 | containing the lengths of every input sequence in the batch. This allows for dynamic sequence lengths. 225 | 226 | target_sequence_length: A vector of sequence lengths of shape (batch_size) 227 | containing the lengths of every target sequence in the batch. This allows for dynamic sequence lengths. 228 | 229 | metric: The desired validation metric. Currently only "loss" is supported. This will eventually support 230 | "accuracy", "bleu", and other common validation metrics. 231 | """ 232 | 233 | if self.mode != tf.contrib.learn.ModeKeys.TRAIN: 234 | raise ValueError("validate_batch can only be called when the model is initialized in train mode.") 235 | 236 | if metric == "loss": 237 | metric_op = self.loss 238 | else: 239 | raise ValueError("Unsupported validation metric: '{0}'".format(metric)) 240 | 241 | metric_value = self.session.run(metric_op, { self.inputs: inputs, 242 | self.targets: targets, 243 | self.input_sequence_length: input_sequence_length, 244 | self.target_sequence_length: target_sequence_length, 245 | self.keep_prob: 1 }) 246 | 247 | return metric_value 248 | 249 | def predict_batch(self, inputs, input_sequence_length, max_output_sequence_length, beam_length_penalty_weight, sampling_temperature, log_summary=True): 250 | """Predict a batch of output sequences given a batch of input sequences. 
251 | 
252 |         Args:
253 |             inputs: The input matrix of shape (batch_size, sequence_length)
254 |                 where each value in the sequences are words encoded as integer indexes of the input vocabulary.
255 | 
256 |             input_sequence_length: A vector of sequence lengths of shape (batch_size)
257 |                 containing the lengths of every input sequence in the batch. This allows for dynamic sequence lengths.
258 | 
259 |             max_output_sequence_length: The maximum number of timesteps the decoder can generate.
260 |                 If the decoder generates an EOS token sooner, it will end there. This maximum value just makes sure
261 |                 the decoder doesn't go on forever if no EOS is generated.
262 | 
263 |             beam_length_penalty_weight: When using beam search decoding, this penalty weight influences how
264 |                 beams are ranked. Large negative values rank very short beams first while large positive values rank very long beams first.
265 |                 A value of 0 will not influence the beam ranking. For a chatbot model, positive values between 0 and 2 can be beneficial
266 |                 to help the bot avoid short generic answers.
267 | 
268 |             sampling_temperature: When using sampling decoding, higher temperature values result in more random sampling
269 |                 while lower temperature values behave more like greedy decoding which takes the argmax of the output class distribution
270 |                 (softmax probability distribution over the output vocabulary). If this value is set to 0, sampling is disabled
271 |                 and greedy decoding is used.
272 | 
273 |             log_summary: Flag indicating if the inference summary should be logged (for visualization in TensorBoard).
274 |         """
275 | 
276 |         if self.mode != tf.contrib.learn.ModeKeys.INFER:
277 |             raise ValueError("predict_batch can only be called when the model is initialized in infer mode.")
278 | 
279 |         fetches = [{ "predictions": self.predictions, "predictions_seq_lengths": self.predictions_seq_lengths }]
280 |         if self.merged_summary is not None:
281 |             fetches.append(self.merged_summary)
282 | 
283 |         predicted_output_info = self.session.run(fetches, { self.inputs: inputs,
284 |                                                             self.input_sequence_length: input_sequence_length,
285 |                                                             self.max_output_sequence_length: max_output_sequence_length,
286 |                                                             self.beam_length_penalty_weight: beam_length_penalty_weight,
287 |                                                             self.sampling_temperature: sampling_temperature })
288 | 
289 |         #Write the inference summary for this prediction if summary logging is enabled.
290 |         if log_summary and len(predicted_output_info) == 2:
291 |             merged_summary = predicted_output_info[1]
292 |             self.summary_writer.add_summary(merged_summary)
293 | 
294 |         return predicted_output_info[0]
295 | 
296 |     def chat(self, question, chat_settings):
297 |         """Chat with the chatbot model by predicting an answer to a question.
298 |         'question' and 'answer' in this context are generic terms for the interactions in a dialog exchange
299 |         and can be statements, remarks, queries, requests, or any other type of dialog speech.
300 |         For example:
301 |             Question: "How are you?"  Answer: "Fine."
302 |             Question: "That's great."  Answer: "Yeah."
303 | 
304 |         Args:
305 |             question: The input question for which the model should predict an answer.
306 | 
307 |             chat_settings: The ChatSettings instance containing the chat settings and inference hyperparameters
308 | 
309 |         Returns:
310 |             q_with_hist: question with history if chat_settings.show_question_context = True otherwise None.
311 | 
312 |             answers: array of answer beams if chat_settings.show_all_beams = True otherwise the single selected answer.
313 | 
314 |         """
315 |         #Process the question by cleaning it and converting it to an integer encoded vector
316 |         if chat_settings.enable_auto_punctuation:
317 |             question = Vocabulary.auto_punctuate(question)
318 |         question = Vocabulary.clean_text(question, normalize_words = chat_settings.inference_hparams.normalize_words)
319 |         question = self.input_vocabulary.words2ints(question)
320 | 
321 |         #Prepend the currently tracked steps of the conversation history separated by EOS tokens.
322 |         #This allows for deeper dialog context to influence the answer prediction.
323 |         question_with_history = []
324 |         for i in range(len(self.conversation_history)):
325 |             question_with_history += self.conversation_history[i] + [self.input_vocabulary.eos_int()]
326 |         question_with_history += question
327 | 
328 |         #Get the answer prediction
329 |         batch = np.zeros((1, len(question_with_history)))
330 |         batch[0] = question_with_history
331 |         max_output_sequence_length = chat_settings.inference_hparams.max_answer_words + 1 # + 1 since the EOS token is counted as a timestep
332 |         predicted_answer_info = self.predict_batch(inputs = batch,
333 |                                                    input_sequence_length = np.array([len(question_with_history)]),
334 |                                                    max_output_sequence_length = max_output_sequence_length,
335 |                                                    beam_length_penalty_weight = chat_settings.inference_hparams.beam_length_penalty_weight,
336 |                                                    sampling_temperature = chat_settings.inference_hparams.sampling_temperature,
337 |                                                    log_summary = chat_settings.inference_hparams.log_summary)
338 | 
339 |         #Read the answer prediction
340 |         answer_beams = []
341 |         if self.beam_width > 0:
342 |             #For beam search decoding: if show_all_beams is enabled then output all beams (sequences), otherwise take the first beam.
343 |             #   The beams (in the "predictions" matrix) are ordered with the highest ranked beams first.
344 |             beam_count = 1 if not chat_settings.show_all_beams else len(predicted_answer_info["predictions_seq_lengths"][0])
345 |             for i in range(beam_count):
346 |                 predicted_answer_seq_length = predicted_answer_info["predictions_seq_lengths"][0][i] - 1 #-1 to exclude the EOS token
347 |                 predicted_answer = predicted_answer_info["predictions"][0][:predicted_answer_seq_length, i].tolist()
348 |                 answer_beams.append(predicted_answer)
349 |         else:
350 |             #For greedy / sampling decoding: only one beam (sequence) is returned, based on the argmax for greedy decoding
351 |             #   or the sampling distribution for sampling decoding. Return this beam.
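            #   (Note: with beam search, "predictions" has shape (batch, time, beam_width) and "predictions_seq_lengths"
            #   has shape (batch, beam_width); with greedy / sampling decoding they are (batch, time) and (batch,),
            #   which is why the indexing below differs from the beam search branch above.)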
352 |             beam_count = 1
353 |             predicted_answer_seq_length = predicted_answer_info["predictions_seq_lengths"][0] - 1 #-1 to exclude the EOS token
354 |             predicted_answer = predicted_answer_info["predictions"][0][:predicted_answer_seq_length].tolist()
355 |             answer_beams.append(predicted_answer)
356 | 
357 |         #Add new conversation steps to the end of the history and trim from the beginning if it is longer than conv_history_length
358 |         #Answers need to be converted from output_vocabulary ints to input_vocabulary ints (since they will be fed back in to the encoder)
359 |         self.conversation_history.append(question)
360 |         answer_for_history = self.output_vocabulary.ints2words(answer_beams[0], is_punct_discrete_word = True, capitalize_i = False)
361 |         answer_for_history = self.input_vocabulary.words2ints(answer_for_history)
362 |         self.conversation_history.append(answer_for_history)
363 |         self.trim_conversation_history(chat_settings.inference_hparams.conv_history_length)
364 | 
365 |         #Convert the answer(s) to text and return
366 |         answers = []
367 |         for i in range(beam_count):
368 |             answer = self.output_vocabulary.ints2words(answer_beams[i])
369 |             answers.append(answer)
370 | 
371 |         q_with_hist = None if not chat_settings.show_question_context else self.input_vocabulary.ints2words(question_with_history)
372 |         if chat_settings.show_all_beams:
373 |             return q_with_hist, answers
374 |         else:
375 |             return q_with_hist, answers[0]
376 | 
377 |     def trim_conversation_history(self, length):
378 |         """Trims the conversation history to the desired length by removing entries from the beginning of the array.
379 |         This is the same conversation history prepended to each question to enable deep dialog context, so the shorter
380 |         the length the less context the next question will have.
381 | 
382 |         Args:
383 |             length: The desired length to trim the conversation history down to.
384 |         """
385 |         while len(self.conversation_history) > length:
386 |             self.conversation_history.pop(0)
387 | 
388 |     def __enter__(self):
389 |         return self
390 | 
391 |     def __exit__(self, exception_type, exception_value, traceback):
392 |         try:
393 |             self.summary_writer.close()
394 |         except:
395 |             pass
396 |         try:
397 |             self.session.close()
398 |         except:
399 |             pass
400 | 
401 |     def _build_model(self):
402 |         """Create the seq2seq model graph.
403 | 
404 |         Since TensorFlow's default behavior is deferred execution, none of the tensor objects below actually have values until
405 |         session.run is called to train, validate, or predict a batch of inputs.
406 | 
407 |         Eager execution was introduced in TF 1.5, but as of now this code does not use it.
408 |         """
409 |         with tf.variable_scope("model"):
410 |             #Batch size for each batch is inferred by looking at the first dimension of the input matrix.
411 |             #While batch size is generally defined as a hyperparameter and does not change, in practice it can vary.
412 |             #An example of this is if the number of samples in the training set is not evenly divisible by the batch size
413 |             #in which case the last batch of each epoch will be smaller than the preset hyperparameter value.
414 |             batch_size = tf.shape(self.inputs)[0]
415 | 
416 |             #encoder
417 |             with tf.variable_scope("encoder"):
418 |                 #encoder_embeddings_matrix is a trainable matrix of values that contain the word embeddings for the input sequences.
419 |                 #   when a word is "embedded", it means that the input to the model is a dense N-dimensional vector that represents the word
420 |                 #   instead of a sparse one-hot encoded vector with the dimension of the word's index in the entire vocabulary set to 1.
421 |                 #   At training time, the dense embedding values that represent each word are updated in the direction of the loss gradient
422 |                 #   just like normal weights. Thus the model learns the contextual relationships between the words (the embedding) along with
423 |                 #   the objective function that depends on the words (the decoding).
424 |                 encoder_embeddings_matrix_name = "shared_embeddings_matrix" if self.model_hparams.share_embedding else "encoder_embeddings_matrix"
425 |                 encoder_embeddings_matrix_shape = [self.input_vocabulary.size(), self.model_hparams.encoder_embedding_size]
426 |                 encoder_embeddings_initial_value = (
427 |                     self.input_external_embeddings if self.input_vocabulary.external_embeddings is not None
428 |                     else tf.random_uniform(encoder_embeddings_matrix_shape, 0, 1))
429 |                 encoder_embeddings_matrix = tf.Variable(encoder_embeddings_initial_value,
430 |                                                         name = encoder_embeddings_matrix_name,
431 |                                                         trainable = self.model_hparams.encoder_embedding_trainable,
432 |                                                         expected_shape = encoder_embeddings_matrix_shape)
433 | 
434 |                 #As described above, the sequences of word vocabulary indexes in the inputs matrix are converted to sequences of
435 |                 #N-dimensional dense vectors, by "looking them up" by index in the encoder_embeddings_matrix.
436 |                 encoder_embedded_input = tf.nn.embedding_lookup(encoder_embeddings_matrix, self.inputs)
437 | 
438 |                 #Build the encoder RNN
439 |                 encoder_outputs, encoder_state = self._build_encoder(encoder_embedded_input)
440 | 
441 |             #Decoder
442 |             with tf.variable_scope("decoder") as decoder_scope:
443 |                 #For description of word embeddings, see comments above on the encoder_embeddings_matrix.
444 |                 #If the share_embedding flag is set to True, the same matrix is used to embed words in the input and target sequences.
445 |                 #This is useful to avoid redundancy when the same vocabulary is used for the inputs and targets
446 |                 #   (why learn two ways to embed the same words?)
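                #   (Sharing is only possible when the two vocabularies have the same size and the encoder and decoder
                #   embedding sizes match; both conditions are validated in __init__ above.)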
447 | if self.model_hparams.share_embedding: 448 | decoder_embeddings_matrix = encoder_embeddings_matrix 449 | else: 450 | decoder_embeddings_matrix_shape = [self.output_vocabulary.size(), self.model_hparams.decoder_embedding_size] 451 | decoder_embeddings_initial_value = ( 452 | self.output_external_embeddings if self.output_vocabulary.external_embeddings is not None 453 | else tf.random_uniform(decoder_embeddings_matrix_shape, 0, 1)) 454 | decoder_embeddings_matrix = tf.Variable(decoder_embeddings_initial_value, 455 | name = "decoder_embeddings_matrix", 456 | trainable = self.model_hparams.decoder_embedding_trainable, 457 | expected_shape = decoder_embeddings_matrix_shape) 458 | 459 | #Create the attentional decoder cell 460 | decoder_cell, decoder_initial_state = self._build_attention_decoder_cell(encoder_outputs, 461 | encoder_state, 462 | batch_size) 463 | 464 | #Output (projection) layer 465 | weights = tf.truncated_normal_initializer(stddev=0.1) 466 | biases = tf.zeros_initializer() 467 | output_layer = layers_core.Dense(units = self.output_vocabulary.size(), 468 | kernel_initializer = weights, 469 | bias_initializer = biases, 470 | use_bias = True, 471 | name = "output_dense") 472 | 473 | #Build the decoder RNN using the attentional decoder cell and output layer 474 | if self.mode != tf.contrib.learn.ModeKeys.INFER: 475 | #In train / validate mode, the training step and loss are returned. 476 | loss, training_step = self._build_training_decoder(batch_size, 477 | decoder_embeddings_matrix, 478 | decoder_cell, 479 | decoder_initial_state, 480 | decoder_scope, 481 | output_layer) 482 | return loss, training_step 483 | else: 484 | #In inference mode, the predictions and prediction sequence lengths are returned. 485 | #The sequence lengths can differ, but the predictions matrix will be one fixed size. 486 | #The predictions_seq_lengths array can be used to properly read the sequences of variable lengths. 487 | predictions, predictions_seq_lengths = self._build_inference_decoder(batch_size, 488 | decoder_embeddings_matrix, 489 | decoder_cell, 490 | decoder_initial_state, 491 | decoder_scope, 492 | output_layer) 493 | return predictions, predictions_seq_lengths 494 | 495 | def _build_encoder(self, encoder_embedded_input): 496 | """Create the encoder RNN 497 | 498 | Args: 499 | encoder_embedded_input: The embedded input sequences. 500 | """ 501 | keep_prob = self.keep_prob if self.mode == tf.contrib.learn.ModeKeys.TRAIN else None 502 | if self.model_hparams.use_bidirectional_encoder: 503 | #Bi-directional encoding designates one or more RNN cells to read the sequence forward and one or more RNN cells to read 504 | #the sequence backward. The resulting states are concatenated before sending them on to the decoder. 
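            #   (Note: this assumes encoder_num_layers is even; the integer division below floors an odd value,
            #   which would silently drop one layer.)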
505 | num_bi_layers = int(self.model_hparams.encoder_num_layers / 2) 506 | 507 | encoder_cell_forward = self._create_rnn_cell(self.model_hparams.rnn_size, num_bi_layers, keep_prob) 508 | encoder_cell_backward = self._create_rnn_cell(self.model_hparams.rnn_size, num_bi_layers, keep_prob) 509 | 510 | bi_encoder_outputs, bi_encoder_state = tf.nn.bidirectional_dynamic_rnn( 511 | cell_fw = encoder_cell_forward, 512 | cell_bw = encoder_cell_backward, 513 | sequence_length = self.input_sequence_length, 514 | inputs = encoder_embedded_input, 515 | dtype = tf.float32, 516 | swap_memory=True) 517 | 518 | #Manipulating encoder state to handle multi bidirectional layers 519 | encoder_outputs = tf.concat(bi_encoder_outputs, -1) 520 | 521 | if num_bi_layers == 1: 522 | encoder_state = bi_encoder_state 523 | else: 524 | # alternatively concat forward and backward states 525 | encoder_state = [] 526 | for layer_id in range(num_bi_layers): 527 | encoder_state.append(bi_encoder_state[0][layer_id]) # forward 528 | encoder_state.append(bi_encoder_state[1][layer_id]) # backward 529 | encoder_state = tuple(encoder_state) 530 | 531 | else: 532 | #Uni-directional encoding uses all RNN cells to read the sequence forward. 533 | encoder_cell = self._create_rnn_cell(self.model_hparams.rnn_size, self.model_hparams.encoder_num_layers, keep_prob) 534 | 535 | encoder_outputs, encoder_state = tf.nn.dynamic_rnn( 536 | cell = encoder_cell, 537 | sequence_length = self.input_sequence_length, 538 | inputs = encoder_embedded_input, 539 | dtype = tf.float32, 540 | swap_memory=True) 541 | 542 | return encoder_outputs, encoder_state 543 | 544 | def _build_attention_decoder_cell(self, encoder_outputs, encoder_state, batch_size): 545 | """Create the RNN cell to be used as the decoder and apply an attention mechanism. 546 | 547 | Args: 548 | encoder_outputs: a tensor containing the output of the encoder at each timestep of the input sequence. 549 | this is used as the input to the attention mechanism. 550 | 551 | encoder_state: a tensor containing the final encoder state for each encoder cell after reading the input sequence. 552 | if the encoder and decoder have the same structure, this becomes the decoder initial state. 553 | 554 | batch_size: the batch size tensor 555 | (defined at the beginning of the model graph as the length of the first dimension of the input matrix) 556 | """ 557 | #If beam search decoding - repeat the input sequence length, encoder output, encoder state, and batch size tensors 558 | #once for every beam. 
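        #   (tf.contrib.seq2seq.tile_batch repeats each batch entry beam_width times along the batch dimension,
        #   so the attention memory, its sequence lengths, and the initial state line up one-to-one with the
        #   beam-expanded decoder batch.)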
559 | input_sequence_length = self.input_sequence_length 560 | if self.beam_width > 0: 561 | encoder_outputs = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier = self.beam_width) 562 | input_sequence_length = tf.contrib.seq2seq.tile_batch(input_sequence_length, multiplier = self.beam_width) 563 | encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier = self.beam_width) 564 | batch_size = batch_size * self.beam_width 565 | 566 | #Construct the attention mechanism 567 | if self.model_hparams.attention_type == "bahdanau" or self.model_hparams.attention_type == "normed_bahdanau": 568 | normalize = self.model_hparams.attention_type == "normed_bahdanau" 569 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units = self.model_hparams.rnn_size, 570 | memory = encoder_outputs, 571 | memory_sequence_length = input_sequence_length, 572 | normalize = normalize) 573 | elif self.model_hparams.attention_type == "luong" or self.model_hparams.attention_type == "scaled_luong": 574 | scale = self.model_hparams.attention_type == "scaled_luong" 575 | attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units = self.model_hparams.rnn_size, 576 | memory = encoder_outputs, 577 | memory_sequence_length = input_sequence_length, 578 | scale = scale) 579 | else: 580 | raise ValueError("Unsupported attention type. Use ('bahdanau' / 'normed_bahdanau') for Bahdanau attention or ('luong' / 'scaled_luong') for Luong attention.") 581 | 582 | #Create the decoder cell and wrap with the attention mechanism 583 | with tf.variable_scope("decoder_cell"): 584 | keep_prob = self.keep_prob if self.mode == tf.contrib.learn.ModeKeys.TRAIN else None 585 | decoder_cell = self._create_rnn_cell(self.model_hparams.rnn_size, self.model_hparams.decoder_num_layers, keep_prob) 586 | 587 | alignment_history = self.mode == tf.contrib.learn.ModeKeys.INFER and self.beam_width == 0 588 | output_attention = self.model_hparams.attention_type == "luong" or self.model_hparams.attention_type == "scaled_luong" 589 | attention_decoder_cell = tf.contrib.seq2seq.AttentionWrapper(cell = decoder_cell, 590 | attention_mechanism = attention_mechanism, 591 | attention_layer_size = self.model_hparams.rnn_size, 592 | alignment_history = alignment_history, 593 | output_attention = output_attention, 594 | name = "attention_decoder_cell") 595 | 596 | #If the encoder and decoder are the same structure, set the decoder initial state to the encoder final state. 597 | decoder_initial_state = attention_decoder_cell.zero_state(batch_size, tf.float32) 598 | if self.model_hparams.encoder_num_layers == self.model_hparams.decoder_num_layers: 599 | decoder_initial_state = decoder_initial_state.clone(cell_state = encoder_state) 600 | 601 | return attention_decoder_cell, decoder_initial_state 602 | 603 | def _build_training_decoder(self, batch_size, decoder_embeddings_matrix, decoder_cell, decoder_initial_state, decoder_scope, output_layer): 604 | """Build the decoder RNN for training mode. 605 | 606 | Currently this is implemented using the TensorFlow TrainingHelper, which uses the Teacher Forcing technique. 607 | "Teacher Forcing" means that at each output timestep, the next word of the target sequence is fed as input 608 | to the decoder without regard to the output prediction at the previous timestep. 
609 | 610 | Args: 611 | batch_size: the batch size tensor 612 | (defined at the beginning of the model graph as the length of the first dimension of the input matrix) 613 | 614 | decoder_embeddings_matrix: The matrix containing the decoder embeddings 615 | 616 | decoder_cell: The RNN cell (or cells) used in the decoder. 617 | 618 | decoder_initial_state: The initial cell state of the decoder. This is the final encoder cell state if the encoder 619 | and decoder cells are structured the same. Otherwise it is a memory cell in zero state. 620 | """ 621 | #Prepend each target sequence with the SOS token, as this is always the input of the first decoder timestep. 622 | preprocessed_targets = self._preprocess_targets(batch_size) 623 | #The sequences of word vocabulary indexes in the targets matrix are converted to sequences of 624 | #N-dimensional dense vectors, by "looking them up" by index in the decoder_embeddings_matrix. 625 | decoder_embedded_input = tf.nn.embedding_lookup(decoder_embeddings_matrix, preprocessed_targets) 626 | 627 | #Create the training decoder 628 | helper = tf.contrib.seq2seq.TrainingHelper(inputs = decoder_embedded_input, 629 | sequence_length = self.target_sequence_length) 630 | 631 | decoder = tf.contrib.seq2seq.BasicDecoder(cell = decoder_cell, 632 | helper = helper, 633 | initial_state = decoder_initial_state) 634 | 635 | #Get the decoder output 636 | decoder_output, _final_context_state, _final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder = decoder, 637 | swap_memory = True, 638 | scope = decoder_scope) 639 | #Pass the decoder output through the output dense layer which will become a softmax distribution of classes - 640 | #one class per word in the output vocabulary. 641 | logits = output_layer(decoder_output.rnn_output) 642 | 643 | #Calculate the softmax loss. Since the logits tensor is fixed size and the output sequences are variable length, 644 | #we need to "mask" the timesteps beyond the sequence length for each sequence. This multiplies the loss calculated 645 | #for those extra timesteps by 0, cancelling them out so they do not affect the final sequence loss. 646 | loss_mask = tf.sequence_mask(self.target_sequence_length, tf.shape(self.targets)[1], dtype=tf.float32) 647 | loss = tf.contrib.seq2seq.sequence_loss(logits = logits, 648 | targets = self.targets, 649 | weights = loss_mask) 650 | tf.summary.scalar("sequence_loss", loss) 651 | 652 | #Set up the optimizer 653 | if self.model_hparams.optimizer == "sgd": 654 | optimizer = tf.train.GradientDescentOptimizer(self.learning_rate) 655 | elif self.model_hparams.optimizer == "adam": 656 | optimizer = tf.train.AdamOptimizer(self.learning_rate) 657 | else: 658 | raise ValueError("Unsupported optimizer. Use 'sgd' for GradientDescentOptimizer or 'adam' for AdamOptimizer.") 659 | 660 | tf.summary.scalar("learning_rate", self.learning_rate) 661 | #If max_gradient_norm is provided, enable gradient clipping. 662 | #This prevents an exploding gradient from derailing the training by too much. 
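        #   (tf.clip_by_global_norm rescales each gradient g to g * max_gradient_norm / max(global_norm, max_gradient_norm),
        #   where global_norm is the L2 norm of all gradients taken together - a no-op when global_norm is already within the limit.)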
663 | if self.model_hparams.max_gradient_norm > 0.0: 664 | params = tf.trainable_variables() 665 | gradients = tf.gradients(loss, params) 666 | clipped_gradients, gradient_norm = tf.clip_by_global_norm(gradients, self.model_hparams.max_gradient_norm) 667 | tf.summary.scalar("gradient_norm", gradient_norm) 668 | tf.summary.scalar("clipped_gradient", tf.global_norm(clipped_gradients)) 669 | training_step = optimizer.apply_gradients(zip(clipped_gradients, params)) 670 | else: 671 | training_step = optimizer.minimize(loss = loss) 672 | 673 | return loss, training_step 674 | 675 | def _build_inference_decoder(self, batch_size, decoder_embeddings_matrix, decoder_cell, decoder_initial_state, decoder_scope, output_layer): 676 | """Build the decoder RNN for inference mode. 677 | 678 | In inference mode, the decoder takes the output of each timestep and feeds it in as the input to the next timestep 679 | and repeats this process until either an EOS token is generated or the maximum sequence length is reached. 680 | 681 | If beam_width > 0, beam search decoding is used. 682 | Beam search will sample from the top most likely words at each timestep and branch off (create a beam) 683 | on one or more of these words to explore a different version of the output sequence. For example: 684 | Question: How are you? 685 | Answer (beam 1): I am fine. 686 | Answer (beam 2): I am doing well. 687 | Answer (beam 3): I am doing alright considering the circumstances. 688 | Answer (beam 4): Good and yourself? 689 | Answer (beam 5): Good good. 690 | Etc... 691 | 692 | If beam_width = 0, greedy or sampling decoding is used. 693 | 694 | Args: 695 | See _build_training_decoder 696 | """ 697 | #Get the SOS and EOS tokens 698 | start_tokens = tf.fill([batch_size], self.output_vocabulary.sos_int()) 699 | end_token = self.output_vocabulary.eos_int() 700 | 701 | #Build the beam search, greedy, or sampling decoder 702 | if self.beam_width > 0: 703 | decoder = tf.contrib.seq2seq.BeamSearchDecoder(cell = decoder_cell, 704 | embedding = decoder_embeddings_matrix, 705 | start_tokens = start_tokens, 706 | end_token = end_token, 707 | initial_state = decoder_initial_state, 708 | beam_width = self.beam_width, 709 | output_layer = output_layer, 710 | length_penalty_weight = self.beam_length_penalty_weight) 711 | else: 712 | if self.model_hparams.enable_sampling: 713 | helper = tf.contrib.seq2seq.SampleEmbeddingHelper(embedding = decoder_embeddings_matrix, 714 | start_tokens = start_tokens, 715 | end_token = end_token, 716 | softmax_temperature = self.sampling_temperature) 717 | else: 718 | helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding = decoder_embeddings_matrix, 719 | start_tokens = start_tokens, 720 | end_token = end_token) 721 | 722 | decoder = tf.contrib.seq2seq.BasicDecoder(cell = decoder_cell, 723 | helper = helper, 724 | initial_state = decoder_initial_state, 725 | output_layer = output_layer) 726 | 727 | #Get the decoder output 728 | decoder_output, final_context_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder = decoder, 729 | maximum_iterations = self.max_output_sequence_length, 730 | swap_memory = True, 731 | scope = decoder_scope) 732 | 733 | #Return the predicted sequences along with an array of the sequence lengths for each predicted sequence in the batch 734 | if self.beam_width > 0: 735 | predictions = decoder_output.predicted_ids 736 | predictions_seq_lengths = final_context_state.lengths 737 | else: 738 | predictions = decoder_output.sample_id 739 | predictions_seq_lengths = 
final_sequence_lengths 740 | #Create attention alignment summary for visualization in TensorBoard 741 | self._create_attention_images_summary(final_context_state) 742 | 743 | return predictions, predictions_seq_lengths 744 | 745 | def _create_rnn_cell(self, rnn_size, num_layers, keep_prob): 746 | """Create a single RNN cell or stack of RNN cells (depending on num_layers) 747 | 748 | Args: 749 | rnn_size: number of units (neurons) in each RNN cell 750 | 751 | num_layers: number of stacked RNN cells to create 752 | 753 | keep_prob: probability of not being dropped out (1 - dropout), or None for inference mode 754 | """ 755 | 756 | cells = [] 757 | for _ in range(num_layers): 758 | if self.model_hparams.rnn_cell_type == "lstm": 759 | rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=rnn_size) 760 | elif self.model_hparams.rnn_cell_type == "gru": 761 | rnn_cell = tf.nn.rnn_cell.GRUCell(num_units=rnn_size) 762 | else: 763 | raise ValueError("Unsupported RNN cell type. Use 'lstm' for LSTM or 'gru' for GRU.") 764 | 765 | if keep_prob is not None: 766 | rnn_cell = tf.contrib.rnn.DropoutWrapper(cell=rnn_cell, input_keep_prob=keep_prob) 767 | 768 | cells.append(rnn_cell) 769 | 770 | if len(cells) == 1: 771 | return cells[0] 772 | else: 773 | return tf.contrib.rnn.MultiRNNCell(cells = cells) 774 | 775 | def _preprocess_targets(self, batch_size): 776 | """Prepend the SOS token to all target sequences in the batch 777 | 778 | Args: 779 | batch_size: the batch size tensor 780 | (defined at the beginning of the model graph as the length of the first dimension of the input matrix) 781 | """ 782 | left_side = tf.fill([batch_size, 1], self.output_vocabulary.sos_int()) 783 | right_side = tf.strided_slice(self.targets, [0,0], [batch_size, -1], [1,1]) 784 | preprocessed_targets = tf.concat([left_side, right_side], 1) 785 | return preprocessed_targets 786 | 787 | def _create_session(self): 788 | """Initialize the TensorFlow session 789 | """ 790 | if self.model_hparams.gpu_dynamic_memory_growth: 791 | config = tf.ConfigProto() 792 | config.gpu_options.allow_growth = True 793 | session = tf.Session(config=config) 794 | else: 795 | session = tf.Session() 796 | 797 | return session 798 | 799 | def _create_attention_images_summary(self, final_context_state): 800 | """Create attention image and attention summary. 801 | 802 | TODO: this method was taken as is from the NMT tutorial and does not seem to work with the current model. 803 | figure this out and adjust as needed to the chatbot model. 
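Note: alignment_history is only enabled when the decoder cell is built for inference with beam_width == 0
(see _build_decoder_cell), so when beam search decoding is used, final_context_state carries no recorded
alignments to stack - a likely reason this summary does not work as-is.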
804 | 805 | Args: 806 | final_context_state: final state of the decoder 807 | """ 808 | attention_images = (final_context_state.alignment_history.stack()) 809 | # Reshape to (batch, src_seq_len, tgt_seq_len,1) 810 | attention_images = tf.expand_dims( 811 | tf.transpose(attention_images, [1, 2, 0]), -1) 812 | # Scale to range [0, 255] 813 | attention_images *= 255 814 | tf.summary.image("attention_images", attention_images) -------------------------------------------------------------------------------- /chatbot_screenshots/seq1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq1.png -------------------------------------------------------------------------------- /chatbot_screenshots/seq10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq10.png -------------------------------------------------------------------------------- /chatbot_screenshots/seq11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq11.png -------------------------------------------------------------------------------- /chatbot_screenshots/seq2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq2.png -------------------------------------------------------------------------------- /chatbot_screenshots/seq3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq3.png -------------------------------------------------------------------------------- /chatbot_screenshots/seq4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq4.png -------------------------------------------------------------------------------- /chatbot_screenshots/seq5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq5.png -------------------------------------------------------------------------------- /chatbot_screenshots/seq6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq6.png -------------------------------------------------------------------------------- /chatbot_screenshots/seq7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq7.png -------------------------------------------------------------------------------- 
/chatbot_screenshots/seq8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq8.png
--------------------------------------------------------------------------------
/chatbot_screenshots/seq9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq9.png
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset class
3 | """
4 | import math
5 | import random
6 | import numpy as np
7 | from os import path
8 | 
9 | class Dataset(object):
10 | """Class representing a chatbot dataset with questions, answers, and vocabulary.
11 | """
12 | 
13 | def __init__(self, questions, answers, input_vocabulary, output_vocabulary):
14 | """Initializes a Dataset instance with a list of questions, answers, and input/output vocabularies.
15 | 
16 | Args:
17 | questions: Can be a list of questions as space delimited sentence(s) of words
18 | or a list of lists of integer encoded words
19 | 
20 | answers: Can be a list of answers as space delimited sentence(s) of words
21 | or a list of lists of integer encoded words
22 | 
23 | input_vocabulary: The Vocabulary instance to use for encoding questions
24 | 
25 | output_vocabulary: The Vocabulary instance to use for encoding answers
26 | """
27 | if len(questions) != len(answers):
28 | raise RuntimeError("questions and answers lists must be the same length, as they are lists of input-output pairs.")
29 | 
30 | self.input_vocabulary = input_vocabulary
31 | self.output_vocabulary = output_vocabulary
32 | #If the questions and answers are already integer encoded, accept them as is.
33 | #Otherwise use the Vocabulary instances to encode the question and answer sequences.
34 | if len(questions) > 0 and isinstance(questions[0], str):
35 | self.questions_into_int = [self.input_vocabulary.words2ints(q) for q in questions]
36 | self.answers_into_int = [self.output_vocabulary.words2ints(a) for a in answers]
37 | else:
38 | self.questions_into_int = questions
39 | self.answers_into_int = answers
40 | 
41 | def size(self):
42 | """ The size (number of samples) of the Dataset.
43 | """
44 | return len(self.questions_into_int)
45 | 
46 | def train_val_split(self, val_percent = 20, random_split = True, move_samples = True):
47 | """Splits the dataset into training and validation sets.
48 | 
49 | Args:
50 | val_percent: the percentage of the dataset to use as validation data.
51 | 
52 | random_split: True to split the dataset randomly.
53 | False to split the dataset sequentially (validation samples are the last N samples, where N = samples * (val_percent / 100))
54 | 
55 | move_samples: True to physically move the samples into the returned training and validation dataset objects (saves memory).
56 | False to copy the samples into the returned training and validation dataset objects, and preserve this dataset instance.
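For example, a dataset of 1,000 samples with val_percent = 20 yields a training dataset of 800 samples
and a validation dataset of 200 samples.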
57 | """
58 | 
59 | if move_samples:
60 | questions = self.questions_into_int
61 | answers = self.answers_into_int
62 | else:
63 | questions = self.questions_into_int[:]
64 | answers = self.answers_into_int[:]
65 | 
66 | num_validation_samples = int(len(questions) * (val_percent / 100))
67 | num_training_samples = len(questions) - num_validation_samples
68 | 
69 | training_questions = []
70 | training_answers = []
71 | validation_questions = []
72 | validation_answers = []
73 | if random_split:
74 | for _ in range(num_validation_samples):
75 | random_index = random.randint(0, len(questions) - 1)
76 | validation_questions.append(questions.pop(random_index))
77 | validation_answers.append(answers.pop(random_index))
78 | 
79 | for _ in range(num_training_samples):
80 | training_questions.append(questions.pop(0))
81 | training_answers.append(answers.pop(0))
82 | else:
83 | for _ in range(num_training_samples):
84 | training_questions.append(questions.pop(0))
85 | training_answers.append(answers.pop(0))
86 | 
87 | for _ in range(num_validation_samples):
88 | validation_questions.append(questions.pop(0))
89 | validation_answers.append(answers.pop(0))
90 | 
91 | training_dataset = Dataset(training_questions, training_answers, self.input_vocabulary, self.output_vocabulary)
92 | validation_dataset = Dataset(validation_questions, validation_answers, self.input_vocabulary, self.output_vocabulary)
93 | 
94 | return training_dataset, validation_dataset
95 | 
96 | def sort(self):
97 | """Sorts the dataset by the lengths of the questions. This can speed up training by reducing the
98 | amount of padding the input sequences need.
99 | """
100 | if self.size() > 0:
101 | self.questions_into_int, self.answers_into_int = zip(*sorted(zip(self.questions_into_int, self.answers_into_int),
102 | key = lambda qa_pair: len(qa_pair[0])))
103 | 
104 | def save(self, filepath):
105 | """Saves the dataset questions & answers exactly as represented by input_vocabulary and output_vocabulary.
106 | """
107 | filename, ext = path.splitext(filepath)
108 | questions_filepath = "{0}_questions{1}".format(filename, ext)
109 | answers_filepath = "{0}_answers{1}".format(filename, ext)
110 | 
111 | with open(questions_filepath, mode="w", encoding="utf-8") as file:
112 | for question_into_int in self.questions_into_int:
113 | question = self.input_vocabulary.ints2words(question_into_int, is_punct_discrete_word = True, capitalize_i = False)
114 | file.write(question)
115 | file.write('\n')
116 | 
117 | with open(answers_filepath, mode="w", encoding="utf-8") as file:
118 | for answer_into_int in self.answers_into_int:
119 | answer = self.output_vocabulary.ints2words(answer_into_int, is_punct_discrete_word = True, capitalize_i = False)
120 | file.write(answer)
121 | file.write('\n')
122 | 
123 | 
124 | 
125 | def batches(self, batch_size):
126 | """Provide the dataset as an enumerable collection of batches of size batch_size.
127 | Each batch will be a matrix of a fixed shape (batch_size, max_seq_length_in_batch).
128 | Sequences that are shorter than the largest one are padded at the end with the PAD token.
129 | Padding is largely just used as a placeholder since the dynamic encoder and decoder RNNs
130 | will never see the padded timesteps.
131 | 
132 | Args:
133 | batch_size: size of each batch.
134 | If the total number of samples is not evenly divisible by batch_size, the last batch will contain the remainder
135 | which will be less than batch_size.
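For example, 1,050 samples with batch_size = 128 yield 9 batches: 8 full batches of 128 samples followed
by a final batch of 26.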
136 | 
137 | Returns:
138 | padded_questions_in_batch: A list of padded, integer-encoded question sequences.
139 | 
140 | padded_answers_in_batch: A list of padded, integer-encoded answer sequences.
141 | 
142 | seqlen_questions_in_batch: A list of actual sequence lengths for each question in the batch.
143 | 
144 | seqlen_answers_in_batch: A list of actual sequence lengths for each answer in the batch.
145 | 
146 | """
147 | for batch_index in range(0, math.ceil(len(self.questions_into_int) / batch_size)):
148 | start_index = batch_index * batch_size
149 | questions_in_batch = self.questions_into_int[start_index : start_index + batch_size]
150 | answers_in_batch = self.answers_into_int[start_index : start_index + batch_size]
151 | 
152 | seqlen_questions_in_batch = np.array([len(q) for q in questions_in_batch])
153 | seqlen_answers_in_batch = np.array([len(a) for a in answers_in_batch])
154 | 
155 | padded_questions_in_batch = np.array(self._apply_padding(questions_in_batch, self.input_vocabulary))
156 | padded_answers_in_batch = np.array(self._apply_padding(answers_in_batch, self.output_vocabulary))
157 | 
158 | yield padded_questions_in_batch, padded_answers_in_batch, seqlen_questions_in_batch, seqlen_answers_in_batch
159 | 
160 | 
161 | def _apply_padding(self, batch_of_sequences, vocabulary):
162 | """Pad the sequences with the PAD token to ensure all sequences in the batch are the same physical size.
163 | 
164 | Input and target sequences can be any length, but each batch that is fed into the model
165 | must be defined as a fixed size matrix of shape (batch_size, max_seq_length_in_batch).
166 | 
167 | Padding allows for this dynamic sequence length within a fixed matrix.
168 | However, the actual padded sequence timesteps are never seen by the encoder or decoder nor are they
169 | counted toward the softmax loss. This is possible since we provide the actual sequence lengths to the model
170 | as a separate vector of shape (batch_size), where the RNNs are instructed to only unroll for the number of
171 | timesteps that have real (unpadded) values. The sequence loss also accepts a masking weight matrix where
172 | we can specify that loss values for padded timesteps should be ignored.
173 | 
174 | Args:
175 | batch_of_sequences: list of integer-encoded sequences to pad
176 | 
177 | vocabulary: Vocabulary instance to use to look up the integer encoding of the PAD token
178 | """
179 | max_sequence_length = max([len(sequence) for sequence in batch_of_sequences])
180 | return [sequence + ([vocabulary.pad_int()] * (max_sequence_length - len(sequence))) for sequence in batch_of_sequences]
--------------------------------------------------------------------------------
/dataset_readers/cornell_dataset_reader.py:
--------------------------------------------------------------------------------
1 | """
2 | Reader class for the Cornell movie dialog dataset
3 | """
4 | from os import path
5 | 
6 | from dataset_readers.dataset_reader import DatasetReader
7 | 
8 | class CornellDatasetReader(DatasetReader):
9 | """Reader implementation for the Cornell movie dialog dataset
10 | """
11 | def __init__(self):
12 | super(CornellDatasetReader, self).__init__("cornell_movie_dialog")
13 | 
14 | def _get_dialog_lines_and_conversations(self, dataset_dir):
15 | """Get dialog lines and conversations. See base class for explanation.
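Each line of movie_lines.txt is expected to contain 5 fields delimited by " +++$+++ ", with the line ID
first and the dialog text last, e.g. (an illustrative sample line):
L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!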
16 | 17 | Args: 18 | See base class 19 | """ 20 | movie_lines_filepath = path.join(dataset_dir, "movie_lines.txt") 21 | movie_conversations_filepath = path.join(dataset_dir, "movie_conversations.txt") 22 | 23 | # Importing the dataset 24 | with open(movie_lines_filepath, encoding="utf-8", errors="ignore") as file: 25 | lines = file.read().split("\n") 26 | 27 | with open(movie_conversations_filepath, encoding="utf-8", errors="ignore") as file: 28 | conversations = file.read().split("\n") 29 | 30 | # Creating a dictionary that maps each line and its id 31 | id2line = {} 32 | for line in lines: 33 | _line = line.split(" +++$+++ ") 34 | if len(_line) == 5: 35 | id2line[_line[0]] = _line[4] 36 | 37 | # Creating a list of all of the conversations 38 | conversations_ids = [] 39 | for conversation in conversations[:-1]: 40 | _conversation = conversation.split(" +++$+++ ")[-1][1:-1].replace("'", "").replace(" ", "") 41 | conv_ids = _conversation.split(",") 42 | conversations_ids.append(conv_ids) 43 | 44 | return id2line, conversations_ids 45 | -------------------------------------------------------------------------------- /dataset_readers/csv_dataset_reader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reader class for generic CSV question-answer datasets 3 | """ 4 | from os import path 5 | import pandas as pd 6 | 7 | from dataset_readers.dataset_reader import DatasetReader 8 | 9 | class CSVDatasetReader(DatasetReader): 10 | """Reader implementation for generic CSV question-answer datasets 11 | """ 12 | def __init__(self): 13 | super(CSVDatasetReader, self).__init__("csv") 14 | 15 | def _get_dialog_lines_and_conversations(self, dataset_dir): 16 | """Get dialog lines and conversations. See base class for explanation. 17 | Args: 18 | See base class 19 | """ 20 | csv_filepath = path.join(dataset_dir, "csv_data.csv") 21 | 22 | # Importing the dataset 23 | dataset = pd.read_csv(csv_filepath, dtype=str, na_filter=False) 24 | questions = dataset.iloc[:, 0].values 25 | answers = dataset.iloc[:, 1].values 26 | 27 | # Creating a dictionary that maps each line and its id 28 | conversations_ids = [] 29 | id2line = {} 30 | for i in range(len(questions)): 31 | question = questions[i].strip() 32 | answer = answers[i].strip() 33 | if question != '' and answer != '': 34 | q_line_id = "{}_q".format(i) 35 | a_line_id = "{}_a".format(i) 36 | id2line[q_line_id] = question 37 | id2line[a_line_id] = answer 38 | conversations_ids.append([q_line_id, a_line_id]) 39 | 40 | return id2line, conversations_ids -------------------------------------------------------------------------------- /dataset_readers/dataset_reader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for dataset readers 3 | """ 4 | import abc 5 | from os import path 6 | 7 | from vocabulary_importers import vocabulary_importer_factory 8 | from vocabulary_importers.vocabulary_importer import VocabularyImportMode 9 | from vocabulary import Vocabulary 10 | from dataset import Dataset 11 | 12 | class DatasetReadStats(object): 13 | """Contains information about the read dataset. 14 | """ 15 | 16 | def __init__(self): 17 | self.input_vocabulary_import_stats = None 18 | self.output_vocabulary_import_stats = None 19 | 20 | class DatasetReader(object): 21 | """Base class for dataset readers 22 | """ 23 | 24 | def __init__(self, dataset_name): 25 | """Initialize the DatasetReader. 26 | 27 | Args: 28 | dataset_name: Name of the dataset. 
Subclass must pass this in.
29 | """
30 | self.dataset_name = dataset_name
31 | 
32 | @abc.abstractmethod
33 | def _get_dialog_lines_and_conversations(self, dataset_dir):
34 | """Subclass must implement this
35 | 
36 | Read the raw dataset files and extract a dictionary of dialog lines and a list of conversations.
37 | A conversation is a list of dictionary keys for dialog lines that sequentially form a conversation.
38 | 
39 | Args:
40 | dataset_dir: directory to load the raw dataset file(s) from
41 | """
42 | pass
43 | 
44 | def read_dataset(self, dataset_dir, model_dir, training_hparams, share_vocab = True, encoder_embeddings_dir = None, decoder_embeddings_dir = None):
45 | """Read and return a chatbot dataset from the specified dataset directory.
46 | 
47 | Args:
48 | dataset_dir: directory to load the raw dataset file(s) from
49 | 
50 | model_dir: directory to save the vocabulary to
51 | 
52 | training_hparams: training parameters which determine how the dataset will be read.
53 | See hparams.py for in-depth comments.
54 | 
55 | share_vocab: True to generate a single vocabulary file from the question and answer words.
56 | False to generate separate input and output vocabulary files, from the question and answer words respectively.
57 | (If training_hparams.conv_history_length > 0, share_vocab should be set to True since previous answers will be appended to the questions.
58 | This could cause many of these previous answer words to map to the out-of-vocabulary token when looking up against the input vocabulary.
59 | An exception to this is if the output vocabulary is a subset of the input vocabulary.)
60 | 
61 | encoder_embeddings_dir: Path to directory containing external embeddings to import for the encoder.
62 | If this is specified, the input vocabulary will be loaded from this source and optionally joined with the generated
63 | dataset vocabulary (see training_hparams.input_vocab_import_mode)
64 | If share_vocab is True, the imported vocabulary is used for both input and output.
65 | 
66 | decoder_embeddings_dir: Path to directory containing external embeddings to import for the decoder.
67 | If this is specified, the output vocabulary will be loaded from this source and optionally joined with the generated
68 | dataset vocabulary (see training_hparams.output_vocab_import_mode)
69 | If share_vocab is True, this argument must be None or the same as encoder_embeddings_dir (both are equivalent).
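Returns:
dataset: the Dataset instance built from the question/answer pairs and the input/output vocabularies.

read_stats: a DatasetReadStats instance containing vocabulary import statistics, if external embeddings were imported.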
70 | """ 71 | 72 | if share_vocab: 73 | if training_hparams.input_vocab_threshold != training_hparams.output_vocab_threshold and (encoder_embeddings_dir is None or training_hparams.input_vocab_import_mode != VocabularyImportMode.External): 74 | raise ValueError("Cannot share generated or joined imported vocabulary when the input and output vocab thresholds are different.") 75 | if encoder_embeddings_dir is not None: 76 | if training_hparams.input_vocab_import_mode != training_hparams.output_vocab_import_mode: 77 | raise ValueError("Cannot share imported vocabulary when input and output vocab import modes are different.") 78 | if training_hparams.input_vocab_import_normalized != training_hparams.output_vocab_import_normalized: 79 | raise ValueError("Cannot share imported vocabulary when input and output normalization modes are different.") 80 | if decoder_embeddings_dir is not None and decoder_embeddings_dir != encoder_embeddings_dir: 81 | raise ValueError("Cannot share imported vocabulary from two different sources or share import and generated vocabulary.") 82 | 83 | 84 | read_stats = DatasetReadStats() 85 | 86 | #Get dialog line and conversation collections 87 | id2line, conversations_ids = self._get_dialog_lines_and_conversations(dataset_dir) 88 | 89 | #Clean dialog lines 90 | for line_id in id2line: 91 | id2line[line_id] = Vocabulary.clean_text(id2line[line_id], training_hparams.max_question_answer_words, training_hparams.normalize_words) 92 | 93 | #Output cleaned lines for debugging purposes 94 | if training_hparams.log_cleaned_dataset: 95 | self._log_cleaned_dataset(model_dir, id2line.values()) 96 | 97 | # Getting separately the questions and the answers 98 | questions_for_count = [] 99 | questions = [] 100 | answers = [] 101 | for conversation in conversations_ids[:training_hparams.max_conversations]: 102 | for i in range(len(conversation) - 1): 103 | conv_up_to_question = '' 104 | for j in range(max(0, i - training_hparams.conv_history_length), i): 105 | conv_up_to_question += id2line[conversation[j]] + " {0} ".format(Vocabulary.EOS) 106 | question = id2line[conversation[i]] 107 | question_with_history = conv_up_to_question + question 108 | answer = id2line[conversation[i+1]] 109 | if training_hparams.min_question_words <= len(question_with_history.split()): 110 | questions.append(conv_up_to_question + question) 111 | questions_for_count.append(question) 112 | answers.append(answer) 113 | 114 | # Create the vocabulary object & add the question & answer words 115 | if share_vocab: 116 | questions_and_answers = [] 117 | for i in range(len(questions_for_count)): 118 | question = questions_for_count[i] 119 | answer = answers[i] 120 | if i == 0 or question != answers[i - 1]: 121 | questions_and_answers.append(question) 122 | questions_and_answers.append(answer) 123 | 124 | input_vocabulary, read_stats.input_vocabulary_import_stats = self._create_and_save_vocab(questions_and_answers, 125 | training_hparams.input_vocab_threshold, 126 | model_dir, 127 | Vocabulary.SHARED_VOCAB_FILENAME, 128 | encoder_embeddings_dir, 129 | training_hparams.input_vocab_import_normalized, 130 | training_hparams.input_vocab_import_mode) 131 | output_vocabulary = input_vocabulary 132 | read_stats.output_vocabulary_import_stats = read_stats.input_vocabulary_import_stats 133 | else: 134 | input_vocabulary, read_stats.input_vocabulary_import_stats = self._create_and_save_vocab(questions_for_count, 135 | training_hparams.input_vocab_threshold, 136 | model_dir, 137 | Vocabulary.INPUT_VOCAB_FILENAME, 138 | 
encoder_embeddings_dir, 139 | training_hparams.input_vocab_import_normalized, 140 | training_hparams.input_vocab_import_mode) 141 | 142 | output_vocabulary, read_stats.output_vocabulary_import_stats = self._create_and_save_vocab(answers, 143 | training_hparams.output_vocab_threshold, 144 | model_dir, 145 | Vocabulary.OUTPUT_VOCAB_FILENAME, 146 | decoder_embeddings_dir, 147 | training_hparams.output_vocab_import_normalized, 148 | training_hparams.output_vocab_import_mode) 149 | 150 | # Adding the End Of String tokens to the end of every answer 151 | for i in range(len(answers)): 152 | answers[i] += " {0}".format(Vocabulary.EOS) 153 | 154 | #Create the Dataset object from the questions / answers lists and the vocab object. 155 | dataset = Dataset(questions, answers, input_vocabulary, output_vocabulary) 156 | 157 | return dataset, read_stats 158 | 159 | def _create_and_save_vocab(self, word_sequences, vocab_threshold, model_dir, vocab_filename, embeddings_dir, normalize_imported_vocab, vocab_import_mode): 160 | """Create a Vocabulary instance from a list of word sequences, and save it to disk. 161 | 162 | Args: 163 | word_sequences: List of word sequences (sentence(s)) to use as basis for the vocabulary. 164 | 165 | vocab_threshold: Minimum number of times any word must appear within word_sequences 166 | in order to be included in the vocabulary. 167 | 168 | model_dir: directory to save the vocabulary file to 169 | 170 | vocab_filename: file name of the vocabulary file 171 | 172 | embeddings_dir: Optional directory to import external vocabulary & embeddings 173 | If provided, the external vocabulary will be imported and processed according to the vocab_import_mode. 174 | If None, only the generated vocabulary will be used. 175 | 176 | normalize_imported_vocab: See VocabularyImporter.import_vocabulary 177 | 178 | vocab_import_mode: If embeddings_dir is specified, this flag indicates if the dataset vocabulary should be generated 179 | and used in combination with the external vocabulary according to the rules of VocabularyImportMode. 
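For example, with vocab_threshold = 2, any word that appears only once across all of word_sequences is
excluded from the generated vocabulary.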
180 | """ 181 | vocabulary = None 182 | if embeddings_dir is None or vocab_import_mode != VocabularyImportMode.External: 183 | vocabulary = Vocabulary() 184 | for i in range(len(word_sequences)): 185 | word_seq = word_sequences[i] 186 | vocabulary.add_words(word_seq.split()) 187 | vocabulary.compile(vocab_threshold) 188 | 189 | vocabulary_import_stats = None 190 | if embeddings_dir is not None: 191 | vocabulary_importer = vocabulary_importer_factory.get_vocabulary_importer(embeddings_dir) 192 | vocabulary, vocabulary_import_stats = vocabulary_importer.import_vocabulary(embeddings_dir, 193 | normalize_imported_vocab, 194 | vocab_import_mode, 195 | vocabulary) 196 | 197 | vocab_filepath = path.join(model_dir, vocab_filename) 198 | vocabulary.save(vocab_filepath) 199 | return vocabulary, vocabulary_import_stats 200 | 201 | def _log_cleaned_dataset(self, model_dir, lines): 202 | """Write the cleaned dataset to disk 203 | """ 204 | log_filepath = path.join(model_dir, "cleaned_dataset.txt") 205 | with open(log_filepath, mode="w", encoding="utf-8") as file: 206 | for line in lines: 207 | file.write(line) 208 | file.write('\n') -------------------------------------------------------------------------------- /dataset_readers/dataset_reader_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset reader implementation factory 3 | """ 4 | from os import path 5 | from dataset_readers.cornell_dataset_reader import CornellDatasetReader 6 | from dataset_readers.csv_dataset_reader import CSVDatasetReader 7 | 8 | def get_dataset_reader(dataset_dir): 9 | """Gets the appropriate reader implementation for the specified dataset name. 10 | 11 | Args: 12 | dataset_dir: The directory of the dataset to get a reader implementation for. 13 | """ 14 | dataset_name = path.basename(dataset_dir) 15 | 16 | #When adding support for new datasets, add an instance of their reader class to the reader array below. 17 | readers = [CornellDatasetReader(), CSVDatasetReader()] 18 | 19 | for reader in readers: 20 | if reader.dataset_name == dataset_name: 21 | return reader 22 | 23 | raise ValueError("There is no dataset reader implementation for '{0}'. 
If this is a new dataset, please add one!".format(dataset_name)) 24 | -------------------------------------------------------------------------------- /datasets/cornell_movie_dialog/README.md: -------------------------------------------------------------------------------- 1 | # Cornell Movie Dialog Corpus 2 | 3 | "This corpus contains a large metadata-rich collection of fictional conversations extracted from raw movie scripts: 4 | 5 | - 220,579 conversational exchanges between 10,292 pairs of movie characters 6 | 7 | - involves 9,035 characters from 617 movies 8 | 9 | - in total 304,713 utterances" 10 | 11 | Source: [https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html](https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html) -------------------------------------------------------------------------------- /datasets/cornell_movie_dialog/movie_lines.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/datasets/cornell_movie_dialog/movie_lines.txt -------------------------------------------------------------------------------- /datasets/cornell_movie_dialog/train_with_dependency_based_embeddings.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\cornell_movie_dialog --encoderembeddingsdir=embeddings\dependency_based --decoderembeddingsdir=embeddings\dependency_based 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/cornell_movie_dialog/train_with_dependency_based_embeddings_decoder_only.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\cornell_movie_dialog --decoderembeddingsdir=embeddings\dependency_based 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/cornell_movie_dialog/train_with_dependency_based_embeddings_encoder_only.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\cornell_movie_dialog --encoderembeddingsdir=embeddings\dependency_based 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/cornell_movie_dialog/train_with_nnlm_en_embeddings.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\cornell_movie_dialog --encoderembeddingsdir=embeddings\nnlm_en --decoderembeddingsdir=embeddings\nnlm_en 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/cornell_movie_dialog/train_with_nnlm_en_embeddings_decoder_only.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 
3 | python train.py --datasetdir=datasets\cornell_movie_dialog --decoderembeddingsdir=embeddings\nnlm_en 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/cornell_movie_dialog/train_with_nnlm_en_embeddings_encoder_only.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\cornell_movie_dialog --encoderembeddingsdir=embeddings\nnlm_en 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/cornell_movie_dialog/train_with_random_embeddings.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\cornell_movie_dialog 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/cornell_movie_dialog/train_with_word2vec_wikipedia_embeddings.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\cornell_movie_dialog --encoderembeddingsdir=embeddings\word2vec_wikipedia --decoderembeddingsdir=embeddings\word2vec_wikipedia 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/cornell_movie_dialog/train_with_word2vec_wikipedia_embeddings_decoder_only.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\cornell_movie_dialog --decoderembeddingsdir=embeddings\word2vec_wikipedia 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/cornell_movie_dialog/train_with_word2vec_wikipedia_embeddings_encoder_only.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\cornell_movie_dialog --encoderembeddingsdir=embeddings\word2vec_wikipedia 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/csv/README.md: -------------------------------------------------------------------------------- 1 | # CSV question-answer dataset 2 | A generic CSV dataset with two columns - question and answer. 3 | 4 | ## Instructions: 5 | 6 | 1) Drop a CSV file here with the name "csv_data.csv" (replace the sample included in this repo with your own data). 7 | 8 | 2) The CSV must have two columns - the first one is for questions and the second one is for answers (responses). 9 | 10 | 3) Train using the included batch files or the command line. 
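For example, to train from the command line with randomly initialized embeddings (this mirrors the included train_with_random_embeddings.bat; run from the repository root):

```shell
python train.py --datasetdir=datasets\csv
```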
11 | -------------------------------------------------------------------------------- /datasets/csv/csv_data.csv: -------------------------------------------------------------------------------- 1 | question,answer 2 | your question,your answer 3 | what is your name,bob -------------------------------------------------------------------------------- /datasets/csv/train_with_dependency_based_embeddings.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\csv --encoderembeddingsdir=embeddings\dependency_based --decoderembeddingsdir=embeddings\dependency_based 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/csv/train_with_dependency_based_embeddings_decoder_only.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\csv --decoderembeddingsdir=embeddings\dependency_based 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/csv/train_with_dependency_based_embeddings_encoder_only.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\csv --encoderembeddingsdir=embeddings\dependency_based 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/csv/train_with_nnlm_en_embeddings.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\csv --encoderembeddingsdir=embeddings\nnlm_en --decoderembeddingsdir=embeddings\nnlm_en 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/csv/train_with_nnlm_en_embeddings_decoder_only.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\csv --decoderembeddingsdir=embeddings\nnlm_en 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/csv/train_with_nnlm_en_embeddings_encoder_only.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\csv --encoderembeddingsdir=embeddings\nnlm_en 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/csv/train_with_random_embeddings.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 
3 | python train.py --datasetdir=datasets\csv 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/csv/train_with_word2vec_wikipedia_embeddings.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\csv --encoderembeddingsdir=embeddings\word2vec_wikipedia --decoderembeddingsdir=embeddings\word2vec_wikipedia 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/csv/train_with_word2vec_wikipedia_embeddings_decoder_only.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\csv --decoderembeddingsdir=embeddings\word2vec_wikipedia 4 | 5 | cmd /k -------------------------------------------------------------------------------- /datasets/csv/train_with_word2vec_wikipedia_embeddings_encoder_only.bat: -------------------------------------------------------------------------------- 1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3 2 | cd ..\.. 3 | python train.py --datasetdir=datasets\csv --encoderembeddingsdir=embeddings\word2vec_wikipedia 4 | 5 | cmd /k -------------------------------------------------------------------------------- /general_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | General utility methods 3 | """ 4 | import os 5 | import argparse 6 | import datetime 7 | import platform 8 | from shutil import copyfile 9 | 10 | from hparams import Hparams 11 | 12 | def initialize_session(mode): 13 | """Helper method for initializing a chatbot training session 14 | by loading the model dir from command line args and reading the hparams in 15 | 16 | Args: 17 | mode: "train" or "chat" 18 | """ 19 | parser = argparse.ArgumentParser("Train a chatbot model" if mode == "train" else "Chat with a trained chatbot model") 20 | if mode == "train": 21 | ex_group = parser.add_mutually_exclusive_group(required=True) 22 | ex_group.add_argument("--datasetdir", "-d", help="Path structured as datasets/dataset_name. A new model will be trained using the dataset contained in this directory.") 23 | ex_group.add_argument("--checkpointfile", "-c", help="Path structured as 'models/dataset_name/model_name/checkpoint_name.ckpt'. Training will resume from the selected checkpoint. The hparams.json file should exist in the same directory as the checkpoint.") 24 | em_group = parser.add_argument_group() 25 | em_group.add_argument("--encoderembeddingsdir", "--embeddingsdir", "-e", help="Path structured as embeddings/embeddings_name. Encoder (& Decoder if shared) vocabulary and embeddings will be initialized from the checkpoint file and tokens file contained in this directory.") 26 | em_group.add_argument("--decoderembeddingsdir", help="Path structured as embeddings/embeddings_name. Decoder vocabulary and embeddings will be initialized from the checkpoint file and tokens file contained in this directory.") 27 | elif mode == "chat": 28 | parser.add_argument("checkpointfile", help="Path structured as 'models/dataset_name/model_name/checkpoint_name.ckpt'. The hparams.json file and the vocabulary file(s) should exist in the same directory as the checkpoint.") 29 | else: 30 | raise ValueError("Unsupported session mode. 
Choose 'train' or 'chat'.") 31 | args = parser.parse_args() 32 | 33 | #Make sure script was run in the correct working directory 34 | models_dir = "models" 35 | datasets_dir = "datasets" 36 | if not os.path.isdir(models_dir) or not os.path.isdir(datasets_dir): 37 | raise NotADirectoryError("Cannot find models directory 'models' and datasets directory 'datasets' within working directory '{0}'. Make sure to set the working directory to the chatbot root folder." 38 | .format(os.getcwd())) 39 | 40 | encoder_embeddings_dir = decoder_embeddings_dir = None 41 | if mode == "train": 42 | #If provided, make sure the embeddings exist 43 | if args.encoderembeddingsdir: 44 | encoder_embeddings_dir = os.path.relpath(args.encoderembeddingsdir) 45 | if not os.path.isdir(encoder_embeddings_dir): 46 | raise NotADirectoryError("Cannot find embeddings directory '{0}'".format(os.path.realpath(encoder_embeddings_dir))) 47 | if args.decoderembeddingsdir: 48 | decoder_embeddings_dir = os.path.relpath(args.decoderembeddingsdir) 49 | if not os.path.isdir(decoder_embeddings_dir): 50 | raise NotADirectoryError("Cannot find embeddings directory '{0}'".format(os.path.realpath(decoder_embeddings_dir))) 51 | 52 | if mode == "train" and args.datasetdir: 53 | #Make sure dataset exists 54 | dataset_dir = os.path.relpath(args.datasetdir) 55 | if not os.path.isdir(dataset_dir): 56 | raise NotADirectoryError("Cannot find dataset directory '{0}'".format(os.path.realpath(dataset_dir))) 57 | #Create the new model directory 58 | dataset_name = os.path.basename(dataset_dir) 59 | model_dir = os.path.join("models", dataset_name, datetime.datetime.now().strftime("%Y%m%d_%H%M%S")) 60 | os.makedirs(model_dir, exist_ok=True) 61 | copyfile("hparams.json", os.path.join(model_dir, "hparams.json")) 62 | checkpoint = None 63 | elif args.checkpointfile: 64 | #Make sure checkpoint file & hparams file exists 65 | checkpoint_filepath = os.path.relpath(args.checkpointfile) 66 | if not os.path.isfile(checkpoint_filepath + ".meta"): 67 | raise FileNotFoundError("The checkpoint file '{0}' was not found.".format(os.path.realpath(checkpoint_filepath))) 68 | #Get the checkpoint model directory 69 | checkpoint = os.path.basename(checkpoint_filepath) 70 | model_dir = os.path.dirname(checkpoint_filepath) 71 | dataset_name = os.path.basename(os.path.dirname(model_dir)) 72 | dataset_dir = os.path.join(datasets_dir, dataset_name) 73 | else: 74 | raise ValueError("Invalid arguments. 
Use --help for proper usage.") 75 | 76 | #Load the hparams from file 77 | hparams_filepath = os.path.join(model_dir, "hparams.json") 78 | hparams = Hparams.load(hparams_filepath) 79 | 80 | return dataset_dir, model_dir, hparams, checkpoint, encoder_embeddings_dir, decoder_embeddings_dir 81 | 82 | def initialize_session_server(checkpointfile): 83 | #Make sure checkpoint file & hparams file exists 84 | checkpoint_filepath = os.path.relpath(checkpointfile) 85 | if not os.path.isfile(checkpoint_filepath + ".meta"): 86 | raise FileNotFoundError("The checkpoint file '{0}' was not found.".format(os.path.realpath(checkpoint_filepath))) 87 | #Get the checkpoint model directory 88 | checkpoint = os.path.basename(checkpoint_filepath) 89 | model_dir = os.path.dirname(checkpoint_filepath) 90 | 91 | #Load the hparams from file 92 | hparams_filepath = os.path.join(model_dir, "hparams.json") 93 | hparams = Hparams.load(hparams_filepath) 94 | 95 | return model_dir, hparams, checkpoint 96 | 97 | def create_batch_files(model_dir, checkpoint_training, checkpoint_val, encoder_embeddings_dir, decoder_embeddings_dir): 98 | os_type = platform.system().lower() 99 | if os_type == "windows": 100 | if checkpoint_training is not None: 101 | create_windows_batch_files(model_dir, checkpoint_training, encoder_embeddings_dir, decoder_embeddings_dir) 102 | if checkpoint_val is not None: 103 | create_windows_batch_files(model_dir, checkpoint_val, encoder_embeddings_dir, decoder_embeddings_dir) 104 | elif os_type == "darwin": 105 | pass 106 | elif os_type == "linux": 107 | pass 108 | else: 109 | pass 110 | 111 | def create_windows_batch_files(model_dir, checkpoint, encoder_embeddings_dir, decoder_embeddings_dir): 112 | if "CONDA_PREFIX" in os.environ: 113 | conda_prefix = os.environ["CONDA_PREFIX"] 114 | conda_activate = os.path.join(conda_prefix, r"scripts\activate.bat") 115 | checkpoint_file = os.path.join(model_dir, checkpoint) 116 | checkpoint_name = os.path.splitext(checkpoint)[0] 117 | 118 | #Resume training batch file 119 | batch_file = os.path.join(model_dir, "resume_training_{0}.bat".format(checkpoint_name)) 120 | with open(batch_file, mode="w", encoding="utf-8") as file: 121 | file.write("\n".join([ 122 | "call {0} {1}".format(conda_activate, conda_prefix), 123 | r"cd ..\..\..", 124 | "python train.py --checkpointfile=\"{0}\"{1}{2}".format(checkpoint_file, 125 | " --encoderembeddingsdir={0}".format(encoder_embeddings_dir) if encoder_embeddings_dir is not None else "", 126 | " --decoderembeddingsdir={0}".format(decoder_embeddings_dir) if decoder_embeddings_dir is not None else ""), 127 | "", 128 | "cmd /k" 129 | ])) 130 | 131 | #Chat batch file 132 | batch_file = os.path.join(model_dir, "chat_console_{0}.bat".format(checkpoint_name)) 133 | with open(batch_file, mode="w", encoding="utf-8") as file: 134 | file.write("\n".join([ 135 | "call {0} {1}".format(conda_activate, conda_prefix), 136 | r"cd ..\..\..", 137 | "python chat.py \"{0}\"".format(checkpoint_file), 138 | "", 139 | "cmd /k" 140 | ])) 141 | 142 | #Chat web batch file 143 | batch_file = os.path.join(model_dir, "chat_web_{0}.bat".format(checkpoint_name)) 144 | with open(batch_file, mode="w", encoding="utf-8") as file: 145 | file.write("\n".join([ 146 | "call {0} {1}".format(conda_activate, conda_prefix), 147 | r"cd ..\..\..", 148 | "set FLASK_APP=chat_web.py", 149 | "flask serve_chat \"{0}\" -p 8080".format(checkpoint_file), 150 | "", 151 | "cmd /k" 152 | ])) 153 | 154 | #Tensorboard batch file 155 | batch_file = os.path.join(model_dir, 
"tensorboard_{0}.bat".format(checkpoint_name)) 156 | with open(batch_file, mode="w", encoding="utf-8") as file: 157 | file.write("\n".join([ 158 | "call {0} {1}".format(conda_activate, conda_prefix), 159 | "tensorboard --logdir=." 160 | ])) 161 | -------------------------------------------------------------------------------- /hparams.json: -------------------------------------------------------------------------------- 1 | { 2 | "py/object": "hparams.Hparams", 3 | "model_hparams" : { 4 | "py/object": "hparams.ModelHparams", 5 | "rnn_cell_type": "lstm", 6 | "rnn_size": 1024, 7 | "use_bidirectional_encoder": true, 8 | "encoder_num_layers": 4, 9 | "decoder_num_layers": 4, 10 | "encoder_embedding_size": 128, 11 | "decoder_embedding_size": 128, 12 | "encoder_embedding_trainable": true, 13 | "decoder_embedding_trainable": true, 14 | "share_embedding": true, 15 | "attention_type": "normed_bahdanau", 16 | "beam_width": 20, 17 | "enable_sampling": false, 18 | "optimizer": "sgd", 19 | "max_gradient_norm": 5.0, 20 | "gpu_dynamic_memory_growth": true 21 | }, 22 | "training_hparams": { 23 | "py/object": "hparams.TrainingHparams", 24 | "min_question_words": 1, 25 | "max_question_answer_words": 30, 26 | "max_conversations": -1, 27 | "conv_history_length": 6, 28 | "normalize_words": false, 29 | "input_vocab_threshold": 1, 30 | "output_vocab_threshold": 1, 31 | "input_vocab_import_normalized": true, 32 | "output_vocab_import_normalized": true, 33 | "input_vocab_import_mode": "Dataset", 34 | "output_vocab_import_mode": "Dataset", 35 | "validation_set_percent": 0.0, 36 | "random_train_val_split": true, 37 | "validation_metric": "loss", 38 | "epochs": 500, 39 | "early_stopping_epochs": 500, 40 | "batch_size": 128, 41 | "learning_rate": 2.0, 42 | "learning_rate_decay": 0.9975, 43 | "min_learning_rate": 0.1, 44 | "dropout": 0.2, 45 | "checkpoint_on_training": true, 46 | "checkpoint_on_validation": false, 47 | "log_summary": true, 48 | "log_cleaned_dataset": true, 49 | "log_training_data": true, 50 | "stats_after_n_batches": 100, 51 | "backup_on_training_loss": [3.0, 2.5, 2.0, 1.5, 1.0, 0.5] 52 | }, 53 | "inference_hparams": { 54 | "py/object": "hparams.InferenceHparams", 55 | "beam_length_penalty_weight": 1.25, 56 | "sampling_temperature": 0.5, 57 | "max_answer_words": 100, 58 | "conv_history_length": 20, 59 | "normalize_words": false, 60 | "log_summary": true, 61 | "log_chat": true 62 | } 63 | } -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hyperparameters class 3 | """ 4 | 5 | import jsonpickle 6 | from vocabulary_importers.vocabulary_importer import VocabularyImportMode 7 | 8 | class Hparams(object): 9 | """Container for model, training, and inference hyperparameters. 10 | 11 | Members: 12 | model_hparams: ModelHparams instance 13 | 14 | training_hparams: TrainingHparams instance 15 | 16 | inference_hparams: InferenceHparams instance 17 | """ 18 | def __init__(self): 19 | """Initializes the Hparams instance. 20 | """ 21 | self.model_hparams = ModelHparams() 22 | self.training_hparams = TrainingHparams() 23 | self.inference_hparams = InferenceHparams() 24 | 25 | @staticmethod 26 | def load(filepath): 27 | """Loads the hyperparameters from a JSON file. 28 | 29 | Args: 30 | filepath: path of the JSON file. 
31 | """
32 | with open(filepath, "r") as file:
33 | json = file.read()
34 | hparams = jsonpickle.decode(json)
35 | hparams.training_hparams.input_vocab_import_mode = VocabularyImportMode[hparams.training_hparams.input_vocab_import_mode]
36 | hparams.training_hparams.output_vocab_import_mode = VocabularyImportMode[hparams.training_hparams.output_vocab_import_mode]
37 | return hparams
38 | 
39 | 
40 | class ModelHparams(object):
41 | """Hyperparameters which determine the architecture and complexity of the chatbot model.
42 | 
43 | Members:
44 | rnn_cell_type: The architecture of RNN cell: "lstm" or "gru"
45 | LSTM: "Long Short-Term Memory"
46 | GRU: "Gated Recurrent Unit"
47 | 
48 | rnn_size: the number of units (neurons) in each RNN cell. Applies to the encoder and decoder.
49 | 
50 | use_bidirectional_encoder: True to use a bi-directional encoder.
51 | Bi-directional encoder: Two separate RNN cells (or stacks of cells) are used -
52 | one receives the input sequence (question) in forward order, one receives the input sequence (question) in reverse order.
53 | When creating stacked RNN layers, each direction is stacked separately, with one stack for forward cells
54 | and one stack for reverse cells.
55 | Uni-directional encoder: One RNN cell (or stack of cells) is used in the forward direction (traditional RNN)
56 | 
57 | encoder_num_layers: the number of RNN cells to stack in the encoder.
58 | If use_bidirectional_encoder is set to true, this number is divided in half and applied to
59 | each direction. For example: 4 layers with bidirectional encoder means 2 forward & 2 backward cells.
60 | 
61 | decoder_num_layers: the number of RNN cells to stack in the decoder.
62 | The encoder state can only be passed in to the decoder as its initial state if this value
63 | is the same as encoder_num_layers.
64 | 
65 | encoder_embedding_size: the number of dimensions for each vector in the encoder embedding matrix.
66 | This matrix will be shaped (input_vocabulary.size(), encoder_embedding_size)
67 | 
68 | decoder_embedding_size: the number of dimensions for each vector in the decoder embedding matrix.
69 | This matrix will be shaped (output_vocabulary.size(), decoder_embedding_size)
70 | 
71 | encoder_embedding_trainable: True to allow gradient updates to be applied to the encoder embedding matrix.
72 | False to freeze the embedding matrix and only train the encoder & decoder RNNs, enabling greater training
73 | efficiency when loading pre-trained embeddings such as Word2Vec.
74 | 
75 | decoder_embedding_trainable: True to allow gradient updates to be applied to the decoder embedding matrix.
76 | False to freeze the embedding matrix and only train the encoder & decoder RNNs, enabling greater training
77 | efficiency when loading pre-trained embeddings such as Word2Vec.
78 | 
79 | share_embedding: True to reuse the same embedding matrix for the encoder and decoder.
80 | If the vocabulary is identical between input questions and output answers (as in a chatbot), then this should be True.
81 | If the vocabulary is different between input questions and output answers (as in a domain-specific Q&A system), then this should be False.
82 | If True - 83 | 1) input_vocabulary.size() & output_vocabulary.size() must have the same value 84 | 2) encoder_embedding_size & decoder_embedding_size must have the same value 85 | 3) encoder_embedding_trainable & decoder_embedding_trainable must have the same value 86 | 4) If loading pre-trained embeddings, --encoderembeddingsdir & --decoderembeddingsdir args 87 | must be supplied with the same value (or --embeddingsdir can be used instead) 88 | If any of these conditions is not met, an error is raised. 89 | 90 | attention_type: Type of attention mechanism to use. 91 | ("bahdanau", "normed_bahdanau", "luong", "scaled_luong") 92 | 93 | beam_width: If mode is "infer", the number of beams to generate with the BeamSearchDecoder. 94 | Set to 0 for greedy / sampling decoding. 95 | This value is ignored if mode is "train". 96 | NOTE: this parameter should ideally be in InferenceHparams instead of ModelHparams, but is here for now 97 | because the graph of the model physically changes based on the beam width. 98 | 99 | enable_sampling: If True while beam_width = 0, the sampling decoder is used instead of the greedy decoder. 100 | 101 | optimizer: Type of optimizer to use when training. 102 | ("sgd", "adam") 103 | NOTE: this parameter should ideally be in TrainingHparams instead of ModelHparams, but is here for now 104 | because the graph of the model physically changes based on which optimizer is used. 105 | 106 | max_gradient_norm: max value to clip the gradients if gradient clipping is enabled. 107 | Set to 0 to disable gradient clipping. Defaults to 5. 108 | This value is ignored if mode is "infer". 109 | NOTE: this parameter should ideally be in TrainingHparams instead of ModelHparams, but is here for now 110 | because the graph of the model physically changes based on whether or not gradient clipping is used. 111 | 112 | gpu_dynamic_memory_growth: Configures the TensorFlow session to only allocate GPU memory as needed, 113 | instead of the default behavior of trying to aggressively allocate as much memory as possible. 114 | Defaults to True. 115 | """ 116 | def __init__(self): 117 | """Initializes the ModelHparams instance. 118 | """ 119 | self.rnn_cell_type = "lstm" 120 | 121 | self.rnn_size = 256 122 | 123 | self.use_bidirectional_encoder = True 124 | 125 | self.encoder_num_layers = 2 126 | 127 | self.decoder_num_layers = 2 128 | 129 | self.encoder_embedding_size = 256 130 | 131 | self.decoder_embedding_size = 256 132 | 133 | self.encoder_embedding_trainable = True 134 | 135 | self.decoder_embedding_trainable = True 136 | 137 | self.share_embedding = True 138 | 139 | self.attention_type = "normed_bahdanau" 140 | 141 | self.beam_width = 10 142 | 143 | self.enable_sampling = False 144 | 145 | self.optimizer = "adam" 146 | 147 | self.max_gradient_norm = 5. 148 | 149 | self.gpu_dynamic_memory_growth = True 150 | 151 | class TrainingHparams(object): 152 | """Hyperparameters used when training the chatbot model. 153 | 154 | Members: 155 | min_question_words: minimum length (in words) for a question. 156 | set this to a higher number if you wish to exclude shorter questions which 157 | can sometimes lead to higher training error. 158 | 159 | max_question_answer_words: maximum length (in words) for a question or answer. 160 | any questions or answers longer than this are truncated to fit. The higher this number, the more 161 | timesteps the encoder RNN will need to be unrolled. 162 | 163 | max_conversations: number of conversations to use from the cornell dataset. Specify -1 for no limit.
164 | pick a lower limit if training on the whole dataset is too slow (for lower-end GPUs) 165 | 166 | conv_history_length: number of conversation steps to prepend every question. 167 | For example, a length of 2 would output: 168 | "hello how are you ? <EOS> i am fine thank you <EOS> how is the new job?" 169 | where "how is the new job?" is the question and the rest is the prepended conversation history. 170 | the intent is to let the attention mechanism be able to pick up context clues from earlier in the 171 | conversation in order to determine the best way to respond. 172 | pick a lower limit if training is too slow or causes out of memory errors. The higher this number, 173 | the more timesteps the encoder RNN will need to be unrolled. 174 | 175 | normalize_words: True to preprocess the words in the training dataset by replacing word contractions 176 | with their full forms (e.g. i'm -> i am) and then stripping out any remaining apostrophes. 177 | 178 | input_vocab_threshold: the minimum number of times a word must appear in the questions in order to be included 179 | in the vocabulary embedding. Any words that are not included in the vocabulary 180 | get replaced with an <OUT> token before training and inference. 181 | if model_params.share_embedding = True, this must equal output_vocab_threshold. 182 | 183 | output_vocab_threshold: the minimum number of times a word must appear in the answers in order to be included 184 | in the vocabulary embedding. For more info see input_vocab_threshold. 185 | if model_params.share_embedding = True, this must equal input_vocab_threshold. 186 | 187 | input_vocab_import_normalized: True to normalize external word vocabularies and embeddings before import as input vocabulary. 188 | In this context normalization means converting all word tokens to lower case and then averaging the embedding vectors for any duplicate words. 189 | For example, "JOHN", "John", and "john" will be converted to "john" and it will take the mean of all three embedding vectors. 190 | 191 | output_vocab_import_normalized: True to normalize external word vocabularies and embeddings before import as output vocabulary. 192 | For more info see input_vocab_import_normalized. 193 | 194 | input_vocab_import_mode: Mode to govern how external vocabularies and embeddings are imported as input vocabulary. 195 | Ignored if no external vocabulary is specified. 196 | See VocabularyImportMode. 197 | 198 | output_vocab_import_mode: Mode to govern how external vocabularies and embeddings are imported as output vocabulary. 199 | Ignored if no external vocabulary is specified. 200 | See VocabularyImportMode. 201 | This should be set to 'Dataset' or 'ExternalIntersectDataset' for large vocabularies, since the size of the 202 | decoder output layer is the vocabulary size. For example, an external embedding may have a vocabulary size 203 | of 1 million, but only 30k words appear in the dataset and having an output layer of 30k dimensions is 204 | much more efficient than an output layer of 1m dimensions. 205 | 206 | validation_set_percent: the percentage of the training dataset to use as the validation set. 207 | 208 | random_train_val_split: 209 | True to split the dataset randomly. 210 | False to split the dataset sequentially 211 | (validation samples are the last N samples, where N = samples * (val_percent / 100)) 212 | 213 | validation_metric: the metric to use to measure the model during validation.
214 | "loss" - cross-entropy loss between predictions and targets 215 | "accuracy" (coming soon) 216 | "bleu" (coming soon) 217 | 218 | epochs: Number of epochs to train (1 epoch = all samples in dataset) 219 | 220 | early_stopping_epochs: stop early if no improvement in the validation metric 221 | after training for the given number of epochs in a row. 222 | 223 | batch_size: Training batch size 224 | 225 | learning_rate: learning rate used by SGD. 226 | 227 | learning_rate_decay: rate at which the learning rate drops. 228 | for each epoch, current_lr = starting_lr * (decay) ^ (epoch - 1) 229 | 230 | min_learning_rate: lowest value that the learning rate can go. 231 | 232 | dropout: probability that any neuron will be temporarily disabled during any training iteration. 233 | this is a regularization technique that helps the model learn more independent correlations in the data 234 | and can reduce overfitting. 235 | 236 | checkpoint_on_training: Write a checkpoint after an epoch if the training loss improved. 237 | 238 | checkpoint_on_validation: Write a checkpoint after an epoch if the validation metric improved. 239 | 240 | log_summary: True to log training stats & graph for visualization in tensorboard. 241 | 242 | log_cleaned_dataset: True to save a copy of the cleaned dataset to disk for debugging purposes before training begins. 243 | 244 | log_training_data: True to save a copy of the training question-answer pairs as represented by their vocabularies to disk. 245 | this is useful to see how frequently words are replaced by and also how dialog context is prepended to questions. 246 | 247 | stats_after_n_batches: Output training statistics (loss, time, etc.) after every N batches. 248 | 249 | backup_on_training_loss: List of training loss values upon which to backup the model 250 | Backups are full copies of the latest checkpoint files to another directory, also including vocab and hparam files. 251 | """ 252 | def __init__(self): 253 | """Initializes the TrainingHparams instance. 254 | """ 255 | self.min_question_words = 1 256 | 257 | self.max_question_answer_words = 30 258 | 259 | self.max_conversations = -1 260 | 261 | self.conv_history_length = 6 262 | 263 | self.normalize_words = True 264 | 265 | self.input_vocab_threshold = 2 266 | 267 | self.output_vocab_threshold = 2 268 | 269 | self.input_vocab_import_normalized = True 270 | 271 | self.output_vocab_import_normalized = True 272 | 273 | self.input_vocab_import_mode = VocabularyImportMode.External 274 | 275 | self.output_vocab_import_mode = VocabularyImportMode.Dataset 276 | 277 | self.validation_set_percent = 0 278 | 279 | self.random_train_val_split = True 280 | 281 | self.validation_metric = "loss" 282 | 283 | self.epochs = 500 284 | 285 | self.early_stopping_epochs = 500 286 | 287 | self.batch_size = 128 288 | 289 | self.learning_rate = 2.0 290 | 291 | self.learning_rate_decay = 0.99 292 | 293 | self.min_learning_rate = 0.1 294 | 295 | self.dropout = 0.2 296 | 297 | self.checkpoint_on_training = True 298 | 299 | self.checkpoint_on_validation = True 300 | 301 | self.log_summary = True 302 | 303 | self.log_cleaned_dataset = True 304 | 305 | self.log_training_data = True 306 | 307 | self.stats_after_n_batches = 100 308 | 309 | self.backup_on_training_loss = [] 310 | 311 | class InferenceHparams(object): 312 | """Hyperparameters used when chatting with the chatbot model (a.k.a prediction or inference). 
313 | 314 | Members: 315 | beam_length_penalty_weight: higher values mean longer beams are scored better 316 | while lower (or negative) values mean shorter beams are scored better. 317 | Ignored if beam_width = 0. 318 | 319 | sampling_temperature: This value sets the softmax temperature of the sampling decoder, if enabled. 320 | 321 | max_answer_words: Max length (in words) for an answer. 322 | 323 | conv_history_length: number of conversation steps to prepend every question. 324 | This can be different from the value used during training. 325 | 326 | normalize_words: True to preprocess the words in the input question by replacing word contractions 327 | with their full forms (e.g. i'm -> i am) and then stripping out any remaining apostrophes. 328 | 329 | log_summary: True to log attention alignment images and inference graph for visualization in tensorboard. 330 | 331 | log_chat: True to log conversation history (chatlog) to a file. 332 | """ 333 | def __init__(self): 334 | """Initializes the InferenceHparams instance. 335 | """ 336 | self.beam_length_penalty_weight = 1.25 337 | 338 | self.sampling_temperature = 0.5 339 | 340 | self.max_answer_words = 100 341 | 342 | self.conv_history_length = 6 343 | 344 | self.normalize_words = True 345 | 346 | self.log_summary = True 347 | 348 | self.log_chat = True 349 | 350 | -------------------------------------------------------------------------------- /models/cornell_movie_dialog/README.md: -------------------------------------------------------------------------------- 1 | # Trained Models for Cornell Movie Dialog Corpus 2 | 3 | ## Trained Model v2 (Jul 21 2018) 4 | This model was initialized using the [nnlm_en](../../embeddings/nnlm_en/README.md) pre-trained word embedding vectors. Embeddings were updated during training. 5 | 6 | You can download it from [here](https://drive.google.com/uc?id=1y1b1vXeSti5lpBACNdYlo8HbVJDUO3ir&export=download). 7 | 8 | After downloading, unzip the folder **trained_model_v2** into this directory. 9 | 10 | ## Trained Model v1 (Mar 29 2018) 11 | You can download it from [here](https://drive.google.com/uc?id=1Ig-sgdka5QpgE-b9g4ZQqnGGCrE-2f4p&export=download). 12 | 13 | After downloading, unzip the folder **trained_model_v1** into this directory. 14 | 15 | To chat with the trained cornell movie dialog model **trained_model_v2**: 16 | 17 | 1. Download and unzip [trained_model_v2](seq2seq-chatbot/models/cornell_movie_dialog/README.md) into the [seq2seq-chatbot/models/cornell_movie_dialog](seq2seq-chatbot/models/cornell_movie_dialog) folder 18 | 19 | 2. Set console working directory to the **seq2seq-chatbot** directory 20 | 21 | 3. Run: 22 | ```shell 23 | run chat.py models\cornell_movie_dialog\trained_model_v2\best_weights_training.ckpt 24 | ``` 25 | -------------------------------------------------------------------------------- /roadmap.md: -------------------------------------------------------------------------------- 1 | # Roadmap 2 | For bugs and other issues see the repository issues on GitHub. 3 | 4 | ## General Improvements 5 | - Load training data from a file instead of storing it in memory to allow for quick resume of training without needing to regenerate the dataset from the source corpus. 6 | 7 | - Read training data from a file in chunks to allow for datasets beyond the size of available RAM (a rough sketch of this idea follows below). 8 | 9 | - Output attention alignments correctly in tensorboard summary logs.
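A minimal sketch of the chunked-reading idea mentioned above (illustrative only; the function name and the tab-separated one-pair-per-line file layout are assumptions, not part of the current codebase):

```python
def read_qa_pairs_in_chunks(filepath, chunk_size=10000):
    """Yield lists of (question, answer) pairs, at most chunk_size at a time,
    so the full dataset never needs to be resident in memory."""
    chunk = []
    with open(filepath, encoding="utf-8") as file:
        for line in file:
            # Assumed layout: one tab-separated question/answer pair per line
            question, answer = line.rstrip("\n").split("\t")
            chunk.append((question, answer))
            if len(chunk) == chunk_size:
                yield chunk
                chunk = []
    if chunk:  # Flush the final partial chunk
        yield chunk
```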
10 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for training the chatbot model 3 | """ 4 | import time 5 | import math 6 | from os import path 7 | from shutil import copytree 8 | 9 | import general_utils 10 | import train_console_helper 11 | from dataset_readers import dataset_reader_factory 12 | from vocabulary_importers import vocabulary_importer_factory 13 | from vocabulary import Vocabulary 14 | from chatbot_model import ChatbotModel 15 | from training_stats import TrainingStats 16 | 17 | #Read the hyperparameters and paths 18 | dataset_dir, model_dir, hparams, resume_checkpoint, encoder_embeddings_dir, decoder_embeddings_dir = general_utils.initialize_session("train") 19 | training_stats_filepath = path.join(model_dir, "training_stats.json") 20 | 21 | #Read the chatbot dataset and generate / import the vocabulary 22 | dataset_reader = dataset_reader_factory.get_dataset_reader(dataset_dir) 23 | 24 | print() 25 | print("Reading dataset '{0}'...".format(dataset_reader.dataset_name)) 26 | dataset, dataset_read_stats = dataset_reader.read_dataset(dataset_dir = dataset_dir, 27 | model_dir = model_dir, 28 | training_hparams = hparams.training_hparams, 29 | share_vocab = hparams.model_hparams.share_embedding, 30 | encoder_embeddings_dir = encoder_embeddings_dir, 31 | decoder_embeddings_dir = decoder_embeddings_dir) 32 | if encoder_embeddings_dir is not None: 33 | print() 34 | print("Imported {0} vocab '{1}'...".format("shared" if hparams.model_hparams.share_embedding else "input", encoder_embeddings_dir)) 35 | train_console_helper.write_vocabulary_import_stats(dataset_read_stats.input_vocabulary_import_stats) 36 | 37 | if decoder_embeddings_dir is not None and not hparams.model_hparams.share_embedding: 38 | print() 39 | print("Imported output vocab '{0}'...".format(decoder_embeddings_dir)) 40 | train_console_helper.write_vocabulary_import_stats(dataset_read_stats.output_vocabulary_import_stats) 41 | 42 | print() 43 | print("Final {0} vocab size: {1}".format("shared" if hparams.model_hparams.share_embedding else "input", dataset.input_vocabulary.size())) 44 | if not hparams.model_hparams.share_embedding: 45 | print("Final output vocab size: {0}".format(dataset.output_vocabulary.size())) 46 | 47 | #Split the chatbot dataset into training & validation datasets 48 | print() 49 | print("Splitting {0} samples into training & validation sets ({1}% used for validation)..." 50 | .format(dataset.size(), hparams.training_hparams.validation_set_percent)) 51 | 52 | training_dataset, validation_dataset = dataset.train_val_split(val_percent = hparams.training_hparams.validation_set_percent, 53 | random_split = hparams.training_hparams.random_train_val_split) 54 | training_dataset_size = training_dataset.size() 55 | validation_dataset_size = validation_dataset.size() 56 | print("Training set: {0} samples. Validation set: {1} samples." 
57 | .format(training_dataset_size, validation_dataset_size)) 58 | 59 | print("Sorting training & validation sets to increase training efficiency...") 60 | training_dataset.sort() 61 | validation_dataset.sort() 62 | 63 | #Log the final training dataset if configured to do so 64 | if hparams.training_hparams.log_training_data: 65 | training_data_log_filepath = path.join(model_dir, "training_data.txt") 66 | training_dataset.save(training_data_log_filepath) 67 | 68 | #Create the model 69 | print("Initializing model...") 70 | print() 71 | with ChatbotModel(mode = "train", 72 | model_hparams = hparams.model_hparams, 73 | input_vocabulary = dataset.input_vocabulary, 74 | output_vocabulary = dataset.output_vocabulary, 75 | model_dir = model_dir) as model: 76 | 77 | print() 78 | 79 | #Restore from checkpoint if specified 80 | best_train_checkpoint = "best_weights_training.ckpt" 81 | best_val_checkpoint = "best_weights_validation.ckpt" 82 | training_stats = TrainingStats(hparams.training_hparams) 83 | if resume_checkpoint is not None: 84 | print("Resuming training from checkpoint {0}...".format(resume_checkpoint)) 85 | model.load(resume_checkpoint) 86 | training_stats.load(training_stats_filepath) 87 | else: 88 | print("Creating checkpoint batch files...") 89 | general_utils.create_batch_files(model_dir, 90 | best_train_checkpoint if hparams.training_hparams.checkpoint_on_training else None, 91 | best_val_checkpoint if hparams.training_hparams.checkpoint_on_validation else None, 92 | encoder_embeddings_dir, 93 | decoder_embeddings_dir) 94 | 95 | print("Initializing training...") 96 | 97 | print("Epochs: {0}".format(hparams.training_hparams.epochs)) 98 | print("Batch Size: {0}".format(hparams.training_hparams.batch_size)) 99 | print("Optimizer: {0}".format(hparams.model_hparams.optimizer)) 100 | 101 | backup_on_training_loss = sorted(hparams.training_hparams.backup_on_training_loss.copy(), reverse=True) 102 | 103 | #Train on all batches in epoch 104 | for epoch in range(1, hparams.training_hparams.epochs + 1): 105 | batch_counter = 0 106 | batches_starting_time = time.time() 107 | batches_total_train_loss = 0 108 | epoch_starting_time = time.time() 109 | epoch_total_train_loss = 0 110 | train_batches = training_dataset.batches(hparams.training_hparams.batch_size) 111 | for batch_index, (questions, answers, seqlen_questions, seqlen_answers) in enumerate(train_batches): 112 | batch_train_loss = model.train_batch(inputs = questions, 113 | targets = answers, 114 | input_sequence_length = seqlen_questions, 115 | target_sequence_length = seqlen_answers, 116 | learning_rate = training_stats.learning_rate, 117 | dropout = hparams.training_hparams.dropout, 118 | global_step = training_stats.global_step, 119 | log_summary = hparams.training_hparams.log_summary) 120 | batches_total_train_loss += batch_train_loss 121 | epoch_total_train_loss += batch_train_loss 122 | batch_counter += 1 123 | training_stats.global_step += 1 124 | if batch_counter == hparams.training_hparams.stats_after_n_batches or batch_index == (training_dataset_size // hparams.training_hparams.batch_size): 125 | batches_average_train_loss = batches_total_train_loss / batch_counter 126 | epoch_average_train_loss = epoch_total_train_loss / (batch_index + 1) 127 | print('Epoch: {:>3}/{}, Batch: {:>4}/{}, Stats for last {} batches: (Training Loss: {:>6.3f}, Training Time: {:d} seconds), Stats for epoch: (Training Loss: {:>6.3f}, Training Time: {:d} seconds)'.format( 128 | epoch, 129 | hparams.training_hparams.epochs, 130 | batch_index + 1, 131 | 
math.ceil(training_dataset_size / hparams.training_hparams.batch_size), 132 | batch_counter, 133 | batches_average_train_loss, 134 | int(time.time() - batches_starting_time), 135 | epoch_average_train_loss, 136 | int(time.time() - epoch_starting_time))) 137 | batches_total_train_loss = 0 138 | batch_counter = 0 139 | batches_starting_time = time.time() 140 | 141 | #End of epoch activities 142 | #Run validation 143 | if validation_dataset_size > 0: 144 | total_val_metric_value = 0 145 | batches_starting_time = time.time() 146 | val_batches = validation_dataset.batches(hparams.training_hparams.batch_size) 147 | for batch_index_validation, (questions, answers, seqlen_questions, seqlen_answers) in enumerate(val_batches): 148 | batch_val_metric_value = model.validate_batch(inputs = questions, 149 | targets = answers, 150 | input_sequence_length = seqlen_questions, 151 | target_sequence_length = seqlen_answers, 152 | metric = hparams.training_hparams.validation_metric) 153 | total_val_metric_value += batch_val_metric_value 154 | average_val_metric_value = total_val_metric_value / math.ceil(validation_dataset_size / hparams.training_hparams.batch_size) 155 | print('Epoch: {:>3}/{}, Validation {}: {:>6.3f}, Batch Validation Time: {:d} seconds'.format( 156 | epoch, 157 | hparams.training_hparams.epochs, 158 | hparams.training_hparams.validation_metric, 159 | average_val_metric_value, 160 | int(time.time() - batches_starting_time))) 161 | 162 | #Apply learning rate decay 163 | if hparams.training_hparams.learning_rate_decay > 0: 164 | prev_learning_rate, learning_rate = training_stats.decay_learning_rate() 165 | print('Learning rate decay: adjusting from {:>6.3f} to {:>6.3f}'.format(prev_learning_rate, learning_rate)) 166 | 167 | #Checkpoint - training 168 | if training_stats.compare_training_loss(epoch_average_train_loss): 169 | if hparams.training_hparams.checkpoint_on_training: 170 | model.save(best_train_checkpoint) 171 | training_stats.save(training_stats_filepath) 172 | print('Training loss improved!') 173 | 174 | #Checkpoint - validation 175 | if validation_dataset_size > 0: 176 | if training_stats.compare_validation_metric(average_val_metric_value): 177 | if hparams.training_hparams.checkpoint_on_validation: 178 | model.save(best_val_checkpoint) 179 | training_stats.save(training_stats_filepath) 180 | print('Validation {0} improved!'.format(hparams.training_hparams.validation_metric)) 181 | else: 182 | if training_stats.early_stopping_check == hparams.training_hparams.early_stopping_epochs: 183 | print("Early stopping checkpoint reached - validation loss has not improved in {0} epochs. Terminating training...".format(hparams.training_hparams.early_stopping_epochs)) 184 | break 185 | 186 | #Backup 187 | do_backup = False 188 | while len(backup_on_training_loss) > 0 and epoch_average_train_loss <= backup_on_training_loss[0]: 189 | backup_on_training_loss.pop(0) 190 | do_backup = True 191 | if do_backup: 192 | backup_dir = "{0}_backup_{1}".format(model_dir, "{:0.3f}".format(epoch_average_train_loss).replace(".", "_")) 193 | copytree(model_dir, backup_dir) 194 | general_utils.create_batch_files(backup_dir, 195 | best_train_checkpoint if hparams.training_hparams.checkpoint_on_training else None, 196 | best_val_checkpoint if hparams.training_hparams.checkpoint_on_validation else None, 197 | encoder_embeddings_dir, 198 | decoder_embeddings_dir) 199 | print('Backup to {0} complete!'.format(backup_dir)) 200 | 201 | #Training is complete... 
if no checkpointing was turned on, save the final model state 202 | if not hparams.training_hparams.checkpoint_on_training and not hparams.training_hparams.checkpoint_on_validation: 203 | model.save(best_train_checkpoint) 204 | model.save(best_val_checkpoint) 205 | training_stats.save(training_stats_filepath) 206 | print('Model saved.') 207 | print("Training Complete!") 208 | -------------------------------------------------------------------------------- /train_console_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Console helper for training session 3 | """ 4 | 5 | def write_vocabulary_import_stats(vocabulary_import_stats): 6 | print(" Stats:") 7 | print(" External vocab size: {0}".format(vocabulary_import_stats.external_vocabulary_size)) 8 | if vocabulary_import_stats.dataset_vocabulary_size is not None: 9 | print(" Dataset vocab size: {0}".format(vocabulary_import_stats.dataset_vocabulary_size)) 10 | if vocabulary_import_stats.intersection_size is not None: 11 | print(" Intersection size: {0}".format(vocabulary_import_stats.intersection_size)) -------------------------------------------------------------------------------- /training_stats.py: -------------------------------------------------------------------------------- 1 | """ 2 | TrainingStats class 3 | """ 4 | import jsonpickle 5 | 6 | class TrainingStats(object): 7 | """Class that contains a set of metrics & stats that represent a model at a point in time. 8 | """ 9 | 10 | def __init__(self, training_hparams): 11 | """Initializes the TrainingStats instance. 12 | 13 | Args: 14 | training_hparams: the training hyperparameters. 15 | """ 16 | self.training_hparams = training_hparams 17 | self.best_validation_metric_value = self._get_metric_baseline(self.training_hparams.validation_metric) 18 | self.best_training_loss = self._get_metric_baseline("loss") 19 | self.learning_rate = self.training_hparams.learning_rate 20 | self.early_stopping_check = 0 21 | self.global_step = 0 22 | 23 | def __getstate__(self): 24 | state = self.__dict__.copy() 25 | del state["training_hparams"] 26 | state["best_validation_metric_value"] = float(self.best_validation_metric_value) 27 | state["best_training_loss"] = float(self.best_training_loss) 28 | return state 29 | 30 | def __setstate__(self, state): 31 | self.__dict__.update(state) 32 | 33 | def compare_training_loss(self, new_value): 34 | """Compare the best training loss against a new value. 35 | 36 | Args: 37 | new_value: the new training loss value to compare against. 38 | 39 | Returns: 40 | True if new_value is better than the best training loss. 41 | False if the best training loss is better than or equal to new_value 42 | """ 43 | if self._compare_metric("loss", self.best_training_loss, new_value): 44 | self.best_training_loss = new_value 45 | return True 46 | else: 47 | return False 48 | 49 | def compare_validation_metric(self, new_value): 50 | """Compare the best validation metric value against a new value. 51 | Validation metric is specified in training_hparams. 52 | 53 | Args: 54 | new_value: the new validation metric value to compare against. 55 | 56 | Returns: 57 | True if new_value is better than the best validation metric value. 
58 | False if the best validation metric value is better than or equal to new_value 59 | """ 60 | if self._compare_metric(self.training_hparams.validation_metric, self.best_validation_metric_value, new_value): 61 | self.best_validation_metric_value = new_value 62 | self.early_stopping_check = 0 63 | return True 64 | else: 65 | self.early_stopping_check += 1 66 | return False 67 | 68 | def decay_learning_rate(self): 69 | """Multiply the current learning rate by the decay coefficient specified in training_hparams. 70 | 71 | If the learning rate falls below the minimum learning rate, it is set to the minimum. 72 | """ 73 | prev_learning_rate = self.learning_rate 74 | self.learning_rate *= self.training_hparams.learning_rate_decay 75 | if self.learning_rate < self.training_hparams.min_learning_rate: 76 | self.learning_rate = self.training_hparams.min_learning_rate 77 | return prev_learning_rate, self.learning_rate 78 | 79 | def save(self, filepath): 80 | """Saves the TrainingStats to disk. 81 | 82 | Args: 83 | filepath: The path of the file to save to 84 | """ 85 | json = jsonpickle.encode(self) 86 | with open(filepath, "w") as file: 87 | file.write(json) 88 | 89 | def load(self, filepath): 90 | """Loads the TrainingStats from a JSON file. 91 | 92 | Args: 93 | filepath: path of the JSON file. 94 | """ 95 | with open(filepath) as file: 96 | json = file.read() 97 | training_stats = jsonpickle.decode(json) 98 | self.best_validation_metric_value = training_stats.best_validation_metric_value 99 | self.best_training_loss = training_stats.best_training_loss 100 | self.learning_rate = training_stats.learning_rate 101 | self.early_stopping_check = training_stats.early_stopping_check 102 | self.global_step = training_stats.global_step 103 | 104 | 105 | def _compare_metric(self, metric, previous_value, new_value): 106 | """Compare a new metric value with its previous known value and determine which value is better. 107 | 108 | Which value is better is specific to the metric. 109 | For instance, loss is a lower-is-better metric while accuracy is a higher-is-better metric. 110 | 111 | Args: 112 | metric: The metric being compared 113 | 114 | previous_value: The previous known value for the metric. 115 | 116 | new_value: The new value to compare against the previous value. 117 | 118 | Returns: 119 | True if new_value is better than previous_value 120 | False if previous_value is better than or equal to new_value 121 | """ 122 | if metric == "loss": 123 | return new_value < previous_value 124 | else: 125 | raise ValueError("Unsupported metric: '{0}'".format(metric)) 126 | 127 | def _get_metric_baseline(self, metric): 128 | """Gets a baseline value for a metric that can be used to compare the first measurement against. 129 | 130 | For lower-is-better metrics such as loss, this will be a very large number (99999) 131 | 132 | For higher-is-better metrics such as accuracy, this will be 0. 133 | 134 | Args: 135 | metric: The metric for which to get a baseline value 136 | """ 137 | if metric == "loss": 138 | return 99999 139 | else: 140 | raise ValueError("Unsupported metric: '{0}'".format(metric)) -------------------------------------------------------------------------------- /vocabulary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vocabulary class 3 | """ 4 | import re 5 | 6 | class Vocabulary(object): 7 | """Class representing a chatbot vocabulary. 8 | 9 | The Vocabulary class is responsible for encoding words into integers and decoding integers into words. 
10 | The number of times each word occurs in the source corpus is also tracked for visualization purposes. 11 | 12 | Special tokens that exist in every vocabulary instance: 13 | - PAD ("<PAD>"): The token used for extra sequence timesteps in a batch 14 | - SOS ("<SOS>"): Start Of Sequence token is used as the input of the first decoder timestep 15 | - EOS ("<EOS>"): End Of Sequence token is used to signal that the decoder should stop generating a sequence. 16 | It is also used to separate conversation history (context) questions prepended to the current input question. 17 | - OUT ("<OUT>"): If a word does not exist in the vocabulary, it is substituted with this token. 18 | """ 19 | 20 | SHARED_VOCAB_FILENAME = "shared_vocab.tsv" 21 | INPUT_VOCAB_FILENAME = "input_vocab.tsv" 22 | OUTPUT_VOCAB_FILENAME = "output_vocab.tsv" 23 | 24 | PAD = "<PAD>" 25 | SOS = "<SOS>" 26 | EOS = "<EOS>" 27 | OUT = "<OUT>" 28 | special_tokens = [PAD, SOS, EOS, OUT] 29 | 30 | def __init__(self, external_embeddings = None): 31 | """Initializes the Vocabulary instance in a non-compiled state. 32 | Compile must be called before the Vocabulary instance can be used to integer encode/decode words. 33 | 34 | Args: 35 | external_embeddings: An optional 2d numpy array (matrix) containing external embedding vectors 36 | """ 37 | self._word2count = {} 38 | self._words2int = {} 39 | self._ints2word = {} 40 | self._compiled = False 41 | self.external_embeddings = external_embeddings 42 | 43 | def load_word(self, word, word_int, count = 1): 44 | """Load a word and its integer encoding into the vocabulary instance. 45 | 46 | Args: 47 | word: The word to load. 48 | 49 | word_int: The integer encoding of the word to load. 50 | 51 | count: (Optional) The number of times the word occurs in the source corpus. 52 | """ 53 | self._validate_compile(False) 54 | 55 | self._word2count[word] = count 56 | self._words2int[word] = word_int 57 | self._ints2word[word_int] = word 58 | 59 | def add_words(self, words): 60 | """Add a sequence of words to the vocabulary instance. 61 | If a word occurs more than once, its count will be incremented accordingly. 62 | 63 | Args: 64 | words: The sequence of words to add. 65 | """ 66 | self._validate_compile(False) 67 | 68 | for i in range(len(words)): 69 | word = words[i] 70 | if word in self._word2count: 71 | self._word2count[word] += 1 72 | else: 73 | self._word2count[word] = 1 74 | 75 | def compile(self, vocab_threshold = 1, loading = False): 76 | """Compile the internal lookup dictionaries that enable words to be integer encoded / decoded. 77 | 78 | Args: 79 | vocab_threshold: Minimum number of times any word must appear within word_sequences in order to be included in the vocabulary. 80 | This is useful for filtering out rarely used words in order to reduce the size of the vocabulary 81 | (which consequently reduces the size of the model's embedding matrices & reduces the dimensionality of the output softmax) 82 | This value is ignored if loading is True. 83 | 84 | loading: Indicates if the vocabulary is being loaded from disk, in which case the compilation is already done and this method 85 | only needs to set the flag to indicate as such. 86 | 87 | 88 | """ 89 | self._validate_compile(False) 90 | 91 | if not loading: 92 | #Add the special tokens to the lookup dictionaries 93 | for i, special_token in enumerate(Vocabulary.special_tokens): 94 | self._words2int[special_token] = i 95 | self._ints2word[i] = special_token 96 | 97 | #Add the words in _word2count to the lookup dictionaries if their count meets the threshold.
98 | #Any words that don't meet the threshold are removed. 99 | word_int = len(self._words2int) 100 | for word, count in sorted(self._word2count.items()): 101 | if count >= vocab_threshold: 102 | self._words2int[word] = word_int 103 | self._ints2word[word_int] = word 104 | word_int += 1 105 | else: 106 | del self._word2count[word] 107 | 108 | #Add the special tokens to _word2count so they have count values for saving to disk 109 | self.add_words(Vocabulary.special_tokens) 110 | 111 | #The Vocabulary instance may now be used for integer encoding / decoding 112 | self._compiled = True 113 | 114 | 115 | 116 | def size(self): 117 | """The size (number of words) of the Vocabulary 118 | """ 119 | self._validate_compile(True) 120 | return len(self._word2count) 121 | 122 | def word_exists(self, word): 123 | """Check if the given word exists in the vocabulary. 124 | 125 | Args: 126 | word: The word to check. 127 | """ 128 | self._validate_compile(True) 129 | return word in self._words2int 130 | 131 | def words2ints(self, words): 132 | """Encode a sequence of space delimited words into a sequence of integers 133 | 134 | Args: 135 | words: The sequence of space delimited words to encode 136 | """ 137 | return [self.word2int(w) for w in words.split()] 138 | 139 | def word2int(self, word): 140 | """Encode a word into an integer 141 | 142 | Args: 143 | word: The word to encode 144 | """ 145 | self._validate_compile(True) 146 | return self._words2int[word] if word in self._words2int else self.out_int() 147 | 148 | def ints2words(self, words_ints, is_punct_discrete_word = False, capitalize_i = True): 149 | """Decode a sequence of integers into a sequence of space delimited words 150 | 151 | Args: 152 | words_ints: The sequence of integers to decode 153 | 154 | is_punct_discrete_word: True to output a space before punctuation 155 | False to place punctuation immediately after the end of the preceding word (normal usage). 156 | """ 157 | words = "" 158 | for i in words_ints: 159 | word = self.int2word(i, capitalize_i) 160 | if is_punct_discrete_word or word not in ['.', '!', '?']: 161 | words += " " 162 | words += word 163 | words = words.strip() 164 | return words 165 | 166 | def int2word(self, word_int, capitalize_i = True): 167 | """Decode an integer into a word 168 | 169 | Args: 170 | word_int: The integer to decode 171 | """ 172 | self._validate_compile(True) 173 | word = self._ints2word[word_int] 174 | if capitalize_i and word == 'i': 175 | word = 'I' 176 | return word 177 | 178 | def pad_int(self): 179 | """Get the integer encoding of the PAD token 180 | """ 181 | return self.word2int(Vocabulary.PAD) 182 | 183 | def sos_int(self): 184 | """Get the integer encoding of the SOS token 185 | """ 186 | return self.word2int(Vocabulary.SOS) 187 | 188 | def eos_int(self): 189 | """Get the integer encoding of the EOS token 190 | """ 191 | return self.word2int(Vocabulary.EOS) 192 | 193 | def out_int(self): 194 | """Get the integer encoding of the OUT token 195 | """ 196 | return self.word2int(Vocabulary.OUT) 197 | 198 | def save(self, filepath): 199 | """Saves the vocabulary to disk.
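The output is a tab-separated file: a "word" / "count" header row followed by one word and its count per line, written in integer-encoding order.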
200 | 201 | Args: 202 | filepath: The path of the file to save to 203 | """ 204 | total_words = self.size() 205 | with open(filepath, "w", encoding="utf-8") as file: 206 | file.write('\t'.join(["word", "count"])) 207 | file.write('\n') 208 | for i in range(total_words): 209 | word = self._ints2word[i] 210 | count = self._word2count[word] 211 | file.write('\t'.join([word, str(count)])) 212 | if i < total_words - 1: 213 | file.write('\n') 214 | 215 | def _validate_compile(self, expected_status): 216 | """Validate that the vocabulary is compiled or not based on the needs of the attempted operation 217 | 218 | Args: 219 | expected_status: The compilation status expected by the attempted operation 220 | """ 221 | if self._compiled and not expected_status: 222 | raise ValueError("This vocabulary instance has already been compiled.") 223 | if not self._compiled and expected_status: 224 | raise ValueError("This vocabulary instance has not been compiled yet.") 225 | 226 | @staticmethod 227 | def load(filepath): 228 | """Loads the vocabulary from disk. 229 | 230 | Args: 231 | filepath: The path of the file to load from 232 | """ 233 | vocabulary = Vocabulary() 234 | 235 | with open(filepath, encoding="utf-8") as file: 236 | for index, line in enumerate(file): 237 | if index > 0: #Skip header line 238 | word, count = line.split('\t') 239 | word_int = index - 1 240 | vocabulary.load_word(word, word_int, int(count)) 241 | 242 | vocabulary.compile(loading = True) 243 | return vocabulary 244 | 245 | @staticmethod 246 | def clean_text(text, max_words = None, normalize_words = True): 247 | """Clean text to prepare for training and inference. 248 | 249 | Clean by removing unsupported special characters & extra whitespace, 250 | and by normalizing common word permutations (e.g. can't, cannot -> can not) 251 | 252 | Args: 253 | text: the text to clean 254 | 255 | max_words: maximum number of words to output (assuming words are separated by spaces). 256 | any words beyond this limit are truncated. 257 | Defaults to None (unlimited number of words) 258 | 259 | normalize_words: True to replace word contractions with their full forms (e.g. i'm -> i am) 260 | and then strip out any remaining apostrophes. 261 | """ 262 | text = text.lower() 263 | text = re.sub(r"'+", "'", text) 264 | if normalize_words: 265 | text = re.sub(r"i'm", "i am", text) 266 | text = re.sub(r"he's", "he is", text) 267 | text = re.sub(r"she's", "she is", text) 268 | text = re.sub(r"that's", "that is", text) 269 | text = re.sub(r"there's", "there is", text) 270 | text = re.sub(r"what's", "what is", text) 271 | text = re.sub(r"where's", "where is", text) 272 | text = re.sub(r"who's", "who is", text) 273 | text = re.sub(r"how's", "how is", text) 274 | text = re.sub(r"it's", "it is", text) 275 | text = re.sub(r"let's", "let us", text) 276 | text = re.sub(r"\'ll", " will", text) 277 | text = re.sub(r"\'ve", " have", text) 278 | text = re.sub(r"\'re", " are", text) 279 | text = re.sub(r"\'d", " would", text) 280 | text = re.sub(r"won't", "will not", text) 281 | text = re.sub(r"shan't", "shall not", text) 282 | text = re.sub(r"can't", "can not", text) 283 | text = re.sub(r"cannot", "can not", text) 284 | text = re.sub(r"n't", " not", text) 285 | text = re.sub(r"'", "", text) 286 | else: 287 | text = re.sub(r"(\W)'", r"\1", text) 288 | text = re.sub(r"'(\W)", r"\1", text) 289 | text = re.sub(r"[()\"#/@;:<>{}`+=~|$&*%\[\]_]", "", text) 290 | text = re.sub(r"[.]+", " . ", text) 291 | text = re.sub(r"[!]+", " ! ", text) 292 | text = re.sub(r"[?]+", " ?
", text) 293 | text = re.sub(r"[,-]+", " ", text) 294 | text = re.sub(r"[\t]+", " ", text) 295 | text = re.sub(r" +", " ", text) 296 | text = text.strip() 297 | 298 | #Truncate words beyond the limit, if provided. Remove partial sentences from the end if punctuation exists within the limit. 299 | if max_words is not None: 300 | text_parts = text.split() 301 | if len(text_parts) > max_words: 302 | truncated_text_parts = text_parts[:max_words] 303 | while len(truncated_text_parts) > 0 and not re.match("[.!?]", truncated_text_parts[-1]): 304 | truncated_text_parts.pop(-1) 305 | if len(truncated_text_parts) == 0: 306 | truncated_text_parts = text_parts[:max_words] 307 | text = " ".join(truncated_text_parts) 308 | 309 | return text 310 | 311 | @staticmethod 312 | def auto_punctuate(text): 313 | """Automatically apply punctuation to text that does not end with any punctuation marks. 314 | 315 | Args: 316 | text: the text to apply punctuation to. 317 | """ 318 | text = text.strip() 319 | if not (text.endswith(".") or text.endswith("?") or text.endswith("!") or text.startswith("--")): 320 | tmp = re.sub(r"'", "", text.lower()) 321 | if (tmp.startswith("who") or tmp.startswith("what") or tmp.startswith("when") or 322 | tmp.startswith("where") or tmp.startswith("why") or tmp.startswith("how") or 323 | tmp.endswith("who") or tmp.endswith("what") or tmp.endswith("when") or 324 | tmp.endswith("where") or tmp.endswith("why") or tmp.endswith("how") or 325 | tmp.startswith("are") or tmp.startswith("will") or tmp.startswith("wont") or tmp.startswith("can")): 326 | text = "{}?".format(text) 327 | else: 328 | text = "{}.".format(text) 329 | return text -------------------------------------------------------------------------------- /vocabulary_importers/__pycache__/vocabulary_importer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/vocabulary_importers/__pycache__/vocabulary_importer.cpython-36.pyc -------------------------------------------------------------------------------- /vocabulary_importers/checkpoint_vocabulary_importer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for TensorFlow Checkpoint vocabulary importers 3 | """ 4 | import tensorflow as tf 5 | from collections import OrderedDict 6 | from os import path 7 | from vocabulary_importers.vocabulary_importer import VocabularyImporter 8 | 9 | class CheckpointVocabularyImporter(VocabularyImporter): 10 | """Base class for TensorFlow Checkpoint vocabulary importers 11 | """ 12 | 13 | def __init__(self, vocabulary_name, tokens_filename, embeddings_variable_name): 14 | super(CheckpointVocabularyImporter, self).__init__(vocabulary_name) 15 | """Initialize the CheckpointVocabularyImporter. 16 | 17 | Args: 18 | vocabulary_name: See base class 19 | 20 | tokens_filename: Name of the file containing the token/word list. Subclass must pass this in. 21 | 22 | embeddings_variable_name: Name of the variable to read out of the checkpoint. Subclass must pass this in. 
23 | """ 24 | 25 | self.tokens_filename = tokens_filename 26 | 27 | self.embeddings_variable_name = embeddings_variable_name 28 | 29 | def _read_vocabulary_and_embeddings(self, vocabulary_dir): 30 | """Read the raw vocabulary file(s) and return the tokens list with corresponding word vectors 31 | 32 | Args: 33 | vocabulary_dir: See base class 34 | """ 35 | 36 | #Import embeddings 37 | tf.reset_default_graph() 38 | embeddings = tf.Variable(tf.contrib.framework.load_variable(vocabulary_dir, self.embeddings_variable_name), name = "embeddings") 39 | with tf.Session() as sess: 40 | sess.run(tf.global_variables_initializer()) 41 | loaded_embeddings_matrix = sess.run(embeddings) 42 | 43 | #Import vocabulary 44 | tokens_filepath = path.join(vocabulary_dir, self.tokens_filename) 45 | tokens_with_embeddings = OrderedDict() 46 | with open(tokens_filepath, encoding="utf-8") as file: 47 | for index, line in enumerate(file): 48 | token = line.strip() 49 | if token != "": 50 | token = self._process_token(token) 51 | tokens_with_embeddings[token] = loaded_embeddings_matrix[index] 52 | 53 | return tokens_with_embeddings -------------------------------------------------------------------------------- /vocabulary_importers/dependency_based_vocabulary_importer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Importer class for the Dependency-Based vocabulary (Levy & Goldberg, 2014) 3 | https://levyomer.wordpress.com/2014/04/25/dependency-based-word-embeddings/ 4 | https://levyomer.files.wordpress.com/2014/04/dependency-based-word-embeddings-acl-2014.pdf 5 | """ 6 | from os import path 7 | 8 | from vocabulary_importers.flatfile_vocabulary_importer import FlatFileVocabularyImporter 9 | from vocabulary import Vocabulary 10 | 11 | class DependencyBasedVocabularyImporter(FlatFileVocabularyImporter): 12 | """Importer implementation for the Dependency-Based vocabulary 13 | """ 14 | def __init__(self): 15 | super(DependencyBasedVocabularyImporter, self).__init__("dependency_based", "deps.words", " ") 16 | 17 | def _process_token(self, token): 18 | """Perform token preprocessing (See base class for explanation) 19 | 20 | Args: 21 | See base class 22 | 23 | Returns: 24 | See base class 25 | """ 26 | 27 | if token == "[": 28 | token = Vocabulary.SOS 29 | elif token == "]": 30 | token = Vocabulary.EOS 31 | elif token == "iz": 32 | token = Vocabulary.OUT 33 | elif token == "--": 34 | token = Vocabulary.PAD 35 | elif token == "''": 36 | token = "?" 37 | 38 | return token -------------------------------------------------------------------------------- /vocabulary_importers/flatfile_vocabulary_importer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for Flat File vocabulary importers 3 | """ 4 | import numpy as np 5 | from collections import OrderedDict 6 | from os import path 7 | from vocabulary_importers.vocabulary_importer import VocabularyImporter 8 | 9 | class FlatFileVocabularyImporter(VocabularyImporter): 10 | """Base class for Flat File vocabulary importers 11 | """ 12 | 13 | def __init__(self, vocabulary_name, tokens_and_embeddings_filename, delimiter): 14 | super(FlatFileVocabularyImporter, self).__init__(vocabulary_name) 15 | """Initialize the FlatFileVocabularyImporter. 16 | 17 | Args: 18 | vocabulary_name: See base class 19 | 20 | tokens_and_embeddings_filename: Name of the file containing the token/word list and embeddings. 
21 | Format should be one line per word where the word is at the beginning of the line and the embedding vector follows 22 | separated by a delimiter. 23 | 24 | delimiter: Character that separates the word and the values of the embedding vector. 25 | """ 26 | 27 | self.tokens_and_embeddings_filename = tokens_and_embeddings_filename 28 | 29 | self.delimiter = delimiter 30 | 31 | def _read_vocabulary_and_embeddings(self, vocabulary_dir): 32 | """Read the raw vocabulary file(s) and return the tokens list with corresponding word vectors 33 | 34 | Args: 35 | vocabulary_dir: See base class 36 | """ 37 | 38 | tokens_and_embeddings_filepath = path.join(vocabulary_dir, self.tokens_and_embeddings_filename) 39 | tokens_with_embeddings = OrderedDict() 40 | with open(tokens_and_embeddings_filepath, encoding="utf-8") as file: 41 | for _, line in enumerate(file): 42 | values = line.split(self.delimiter) 43 | token = values[0].strip() 44 | if token != "": 45 | token = self._process_token(token) 46 | tokens_with_embeddings[token] = np.array(values[1:], dtype=np.float32) 47 | 48 | return tokens_with_embeddings -------------------------------------------------------------------------------- /vocabulary_importers/nnlm_en_vocabulary_importer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Importer class for the nnlm english vocabulary (Bengio et al, 2003) 3 | https://www.tensorflow.org/hub/modules/google/nnlm-en-dim128/1 4 | http://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf 5 | """ 6 | from os import path 7 | 8 | from vocabulary_importers.checkpoint_vocabulary_importer import CheckpointVocabularyImporter 9 | from vocabulary import Vocabulary 10 | 11 | class NnlmEnVocabularyImporter(CheckpointVocabularyImporter): 12 | """Importer implementation for the nnlm english vocabulary 13 | """ 14 | def __init__(self): 15 | super(NnlmEnVocabularyImporter, self).__init__("nnlm_en", "tokens.txt", "embeddings") 16 | 17 | def _process_token(self, token): 18 | """Perform token preprocessing (See base class for explanation) 19 | 20 | Args: 21 | See base class 22 | 23 | Returns: 24 | See base class 25 | """ 26 | 27 | if token == "<S>": 28 | token = Vocabulary.SOS 29 | elif token == "</S>": 30 | token = Vocabulary.EOS 31 | elif token == "<UNK>": 32 | token = Vocabulary.OUT 33 | elif token == "--": 34 | token = Vocabulary.PAD 35 | 36 | return token -------------------------------------------------------------------------------- /vocabulary_importers/vocabulary_importer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for all vocabulary importers 3 | """ 4 | import abc 5 | import numpy as np 6 | from enum import Enum 7 | from collections import OrderedDict 8 | from vocabulary import Vocabulary 9 | 10 | class VocabularyImportMode(Enum): 11 | """Vocabulary import modes 12 | 13 | Members: 14 | External: Vocabulary is all the external words. Entire vocabulary has pre-trained embeddings. 15 | Dataset words not in vocabulary are mapped to <OUT> during training and chatting. 16 | 17 | ExternalIntersectDataset: Vocabulary is only the external words that also exist in the dataset. 18 | Entire vocabulary has pre-trained embeddings. Dataset words not in vocabulary are mapped to <OUT> during training and chatting. 19 | 20 | ExternalUnionDataset: Vocabulary is the unique set of all external words and all dataset words. 21 | Vocabulary has partially pre-trained and partially randomized embeddings to be learned during training.
22 | Dataset words are always in vocabulary. 23 | 24 | Dataset: Vocabulary is all the dataset words. Vocabulary has partially pre-trained and partially randomized embeddings 25 | to be learned during training. Dataset words are always in vocabulary. 26 | """ 27 | External = 1 28 | ExternalIntersectDataset = 2 29 | ExternalUnionDataset = 3 30 | Dataset = 4 31 | 32 | class VocabularyImportStats(object): 33 | """Contains information about the imported vocabulary. 34 | """ 35 | 36 | def __init__(self): 37 | self.external_vocabulary_size = None 38 | self.dataset_vocabulary_size = None 39 | self.intersection_size = None 40 | 41 | class VocabularyImporter(object): 42 | """Base class for all vocabulary importers 43 | """ 44 | 45 | def __init__(self, vocabulary_name): 46 | """Initialize the VocabularyImporter. 47 | 48 | Args: 49 | vocabulary_name: Name of the vocabulary. Subclass must pass this in. 50 | """ 51 | self.vocabulary_name = vocabulary_name 52 | 53 | @abc.abstractmethod 54 | def _process_token(self, token): 55 | """Subclass must implement this. 56 | 57 | Perform any preprocessing (e.g. replacement) to each token before importing it. 58 | 59 | Args: 60 | token: the token/word to process 61 | 62 | Returns: 63 | the processed token/word 64 | """ 65 | pass 66 | 67 | @abc.abstractmethod 68 | def _read_vocabulary_and_embeddings(self, vocabulary_dir): 69 | """Subclass must implement this. 70 | 71 | Read the raw vocabulary file(s) and return the tokens list with corresponding word vectors 72 | 73 | Args: 74 | vocabulary_dir: See import_vocabulary 75 | 76 | Returns: 77 | tokens_with_embeddings: OrderedDict of words in the vocabulary with corresponding word vectors as numpy arrays 78 | """ 79 | pass 80 | 81 | def import_vocabulary(self, vocabulary_dir, normalize = True, import_mode = VocabularyImportMode.External, dataset_vocab = None): 82 | """Read the raw vocabulary file(s) and use it to initialize a Vocabulary object 83 | 84 | Args: 85 | vocabulary_dir: directory to load the raw vocabulary file from 86 | 87 | normalize: True to convert all word tokens to lower case and then average the 88 | embedding vectors for any duplicate words before import. 89 | 90 | import_mode: indicates import behavior. See VocabularyImportMode class for more info. 91 | 92 | dataset_vocab: Optionally provide the vocabulary instance generated from the dataset. 93 | This must be provided if import_mode is not 'External'. 94 | 95 | Returns: 96 | vocabulary: The final imported vocabulary instance 97 | 98 | import_stats: Informational statistics on the external vocabulary, the dataset vocabulary, and their intersection. 
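Example (illustrative - the importer instance and directory path are placeholders):
    vocabulary, import_stats = importer.import_vocabulary("embeddings/nnlm_en", normalize = True, import_mode = VocabularyImportMode.Dataset, dataset_vocab = dataset_vocab)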
99 | """ 100 | 101 | if dataset_vocab is None and import_mode != VocabularyImportMode.External: 102 | raise ValueError("dataset_vocab must be provided if import_mode is not 'External'.") 103 | 104 | import_stats = VocabularyImportStats() 105 | 106 | #Read the external vocabulary tokens and embeddings 107 | tokens_with_embeddings = self._read_vocabulary_and_embeddings(vocabulary_dir) 108 | 109 | #If normalize flag is true, normalize casing of the external vocabulary and average embeddings for any resulting duplicate tokens 110 | if normalize: 111 | tokens_with_embeddings = self._normalize_tokens_with_embeddings(tokens_with_embeddings) 112 | 113 | import_stats.external_vocabulary_size = len(tokens_with_embeddings) 114 | 115 | #Apply dataset filters if applicable 116 | if dataset_vocab is not None: 117 | import_stats.dataset_vocabulary_size = dataset_vocab.size() 118 | 119 | if import_mode == VocabularyImportMode.ExternalIntersectDataset or import_mode == VocabularyImportMode.Dataset: 120 | #Get rid of all tokens that exist in the external vocabulary but don't exist in the dataset 121 | for token in list(tokens_with_embeddings.keys()): 122 | if not dataset_vocab.word_exists(token): 123 | del tokens_with_embeddings[token] 124 | import_stats.intersection_size = len(tokens_with_embeddings) 125 | 126 | if import_mode == VocabularyImportMode.ExternalUnionDataset or import_mode == VocabularyImportMode.Dataset: 127 | #Add any tokens that exist in the dataset but don't exist in the external vocabulary. 128 | #These added tokens will get word vectors sampled from the gaussian distributions of their components: 129 | # where the mean of each component is the mean of that component in the external embedding matrix 130 | # and the standard deviation of each component is the standard deviation of that component in the external embedding matrix 131 | embeddings_matrix = np.array(list(tokens_with_embeddings.values()), dtype=np.float32) 132 | emb_size = embeddings_matrix.shape[1] 133 | emb_mean = np.mean(embeddings_matrix, axis=0) 134 | emb_stdev = np.std(embeddings_matrix, axis=0) 135 | for i in range(dataset_vocab.size()): 136 | dataset_token = dataset_vocab.int2word(i, capitalize_i=False) 137 | if dataset_token not in tokens_with_embeddings: 138 | tokens_with_embeddings[dataset_token] = np.random.normal(emb_mean, emb_stdev, emb_size) 139 | 140 | if len(tokens_with_embeddings) == 0: 141 | raise ValueError("Imported vocabulary size is 0. 
Try a different VocabularyImportMode (currently {0})".format( 142 | VocabularyImportMode(import_mode).name)) 143 | 144 | tokens, embeddings_matrix = zip(*tokens_with_embeddings.items()) 145 | embeddings_matrix = np.array(embeddings_matrix, dtype=np.float32) 146 | 147 | #Create the vocabulary instance 148 | vocabulary = Vocabulary(external_embeddings = embeddings_matrix) 149 | for i in range(len(tokens)): 150 | vocabulary.load_word(tokens[i], i) 151 | vocabulary.compile(loading = True) 152 | return vocabulary, import_stats 153 | 154 | def _normalize_tokens_with_embeddings(self, tokens_with_embeddings): 155 | """Convert all word tokens to lower case and then average the embedding vectors for any duplicate words 156 | """ 157 | norm_tokens_with_embeddings = OrderedDict() 158 | for token, embedding in tokens_with_embeddings.items(): 159 | if token not in Vocabulary.special_tokens: 160 | token = token.lower() 161 | if token in norm_tokens_with_embeddings: 162 | norm_tokens_with_embeddings[token].append(embedding) 163 | else: 164 | norm_tokens_with_embeddings[token] = [embedding] 165 | 166 | for token, embedding in norm_tokens_with_embeddings.items(): 167 | norm_tokens_with_embeddings[token] = np.mean(embedding, axis=0) 168 | 169 | return norm_tokens_with_embeddings -------------------------------------------------------------------------------- /vocabulary_importers/vocabulary_importer_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vocabulary importer implementation factory 3 | """ 4 | from os import path 5 | from vocabulary_importers.word2vec_wikipedia_vocabulary_importer import Word2vecWikipediaVocabularyImporter 6 | from vocabulary_importers.nnlm_en_vocabulary_importer import NnlmEnVocabularyImporter 7 | from vocabulary_importers.dependency_based_vocabulary_importer import DependencyBasedVocabularyImporter 8 | 9 | def get_vocabulary_importer(vocabulary_dir): 10 | """Gets the appropriate importer implementation for the specified vocabulary name. 11 | 12 | Args: 13 | vocabulary_dir: The directory of the vocabulary to get a importer implementation for. 14 | """ 15 | vocabulary_name = path.basename(vocabulary_dir) 16 | 17 | #When adding support for new vocabularies, add an instance of their importer class to the importer array below. 18 | importers = [Word2vecWikipediaVocabularyImporter(), 19 | NnlmEnVocabularyImporter(), 20 | DependencyBasedVocabularyImporter()] 21 | 22 | for importer in importers: 23 | if importer.vocabulary_name == vocabulary_name: 24 | return importer 25 | 26 | raise ValueError("There is no vocabulary importer implementation for '{0}'. 
If this is a new vocabulary, please add one!".format(vocabulary_name)) -------------------------------------------------------------------------------- /vocabulary_importers/word2vec_wikipedia_vocabulary_importer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Importer class for the word2vec wikipedia vocabulary (Mikolov et al, 2013) 3 | https://www.tensorflow.org/hub/modules/google/Wiki-words-250/1 4 | https://arxiv.org/abs/1301.3781 5 | """ 6 | from os import path 7 | 8 | from vocabulary_importers.checkpoint_vocabulary_importer import CheckpointVocabularyImporter 9 | from vocabulary import Vocabulary 10 | 11 | class Word2vecWikipediaVocabularyImporter(CheckpointVocabularyImporter): 12 | """Importer implementation for the word2vec wikipedia vocabulary 13 | """ 14 | def __init__(self): 15 | super(Word2vecWikipediaVocabularyImporter, self).__init__("word2vec_wikipedia", "tokens.txt", "embeddings") 16 | 17 | def _process_token(self, token): 18 | """Perform token preprocessing (See base class for explanation) 19 | 20 | Args: 21 | See base class 22 | 23 | Returns: 24 | See base class 25 | """ 26 | 27 | if token == "<S>": 28 | token = Vocabulary.SOS 29 | elif token == "</S>": 30 | token = Vocabulary.EOS 31 | elif token == "<UNK>": 32 | token = Vocabulary.OUT 33 | elif token == "#!#": 34 | token = "!" 35 | elif token == "#.#": 36 | token = "." 37 | elif token == "#?#": 38 | token = "?" 39 | 40 | return token --------------------------------------------------------------------------------