├── README.md
├── __pycache__
│   ├── chat_command_handler.cpython-36.pyc
│   ├── chat_settings.cpython-36.pyc
│   ├── chatbot_model.cpython-36.pyc
│   ├── general_utils.cpython-36.pyc
│   ├── hparams.cpython-36.pyc
│   └── vocabulary.cpython-36.pyc
├── chat.py
├── chat_command_handler.py
├── chat_settings.py
├── chat_ui.html
├── chat_web.py
├── chatbot_model.py
├── chatbot_screenshots
│   ├── seq1.png
│   ├── seq10.png
│   ├── seq11.png
│   ├── seq2.png
│   ├── seq3.png
│   ├── seq4.png
│   ├── seq5.png
│   ├── seq6.png
│   ├── seq7.png
│   ├── seq8.png
│   └── seq9.png
├── dataset.py
├── dataset_readers
│   ├── cornell_dataset_reader.py
│   ├── csv_dataset_reader.py
│   ├── dataset_reader.py
│   └── dataset_reader_factory.py
├── datasets
│   ├── cornell_movie_dialog
│   │   ├── README.md
│   │   ├── movie_conversations.txt
│   │   ├── movie_lines.txt
│   │   ├── train_with_dependency_based_embeddings.bat
│   │   ├── train_with_dependency_based_embeddings_decoder_only.bat
│   │   ├── train_with_dependency_based_embeddings_encoder_only.bat
│   │   ├── train_with_nnlm_en_embeddings.bat
│   │   ├── train_with_nnlm_en_embeddings_decoder_only.bat
│   │   ├── train_with_nnlm_en_embeddings_encoder_only.bat
│   │   ├── train_with_random_embeddings.bat
│   │   ├── train_with_word2vec_wikipedia_embeddings.bat
│   │   ├── train_with_word2vec_wikipedia_embeddings_decoder_only.bat
│   │   └── train_with_word2vec_wikipedia_embeddings_encoder_only.bat
│   └── csv
│       ├── README.md
│       ├── csv_data.csv
│       ├── train_with_dependency_based_embeddings.bat
│       ├── train_with_dependency_based_embeddings_decoder_only.bat
│       ├── train_with_dependency_based_embeddings_encoder_only.bat
│       ├── train_with_nnlm_en_embeddings.bat
│       ├── train_with_nnlm_en_embeddings_decoder_only.bat
│       ├── train_with_nnlm_en_embeddings_encoder_only.bat
│       ├── train_with_random_embeddings.bat
│       ├── train_with_word2vec_wikipedia_embeddings.bat
│       ├── train_with_word2vec_wikipedia_embeddings_decoder_only.bat
│       └── train_with_word2vec_wikipedia_embeddings_encoder_only.bat
├── general_utils.py
├── hparams.json
├── hparams.py
├── models
│   └── cornell_movie_dialog
│       └── README.md
├── roadmap.md
├── train.py
├── train_console_helper.py
├── training_stats.py
├── vocabulary.py
└── vocabulary_importers
    ├── __pycache__
    │   └── vocabulary_importer.cpython-36.pyc
    ├── checkpoint_vocabulary_importer.py
    ├── dependency_based_vocabulary_importer.py
    ├── flatfile_vocabulary_importer.py
    ├── nnlm_en_vocabulary_importer.py
    ├── vocabulary_importer.py
    ├── vocabulary_importer_factory.py
    └── word2vec_wikipedia_vocabulary_importer.py
/README.md:
--------------------------------------------------------------------------------
1 | # seq2seq-chatbot
2 | A sequence-to-sequence (seq2seq) chatbot implementation built with TensorFlow.
3 |
4 | ## Chatting with a trained model
5 |
6 | For console chat:
7 | 1. Run `chat_console_best_weights_training.bat` or `chat_console_best_weights_validation.bat`
8 |
9 | For web chat:
10 | 1. Run `chat_web_best_weights_training.bat` or `chat_web_best_weights_validation.bat`
11 |
12 | 2. Open a browser to the URL indicated by the server console, followed by `/chat_ui.html`. This is typically: [http://localhost:8080/chat_ui.html](http://localhost:8080/chat_ui.html)
13 |
14 | ### To chat with a trained model from a Python console:
15 |
16 | 1. Set the console working directory to the **seq2seq-chatbot** directory. This directory should directly contain the **models** and **datasets** directories.
17 |
18 | 2. Run chat.py with the model checkpoint path:
19 | ```shell
20 | run chat.py models\cornell_movie_dialog\trained_model\best_weights_training.ckpt
21 | ```
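Note: `run` assumes an IPython-style console (the `%run` magic). From a regular shell, the equivalent would be `python chat.py models\cornell_movie_dialog\trained_model\best_weights_training.ckpt`; the same substitution applies to the `run train.py` commands below.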
22 |
23 | ## Training a model
24 | To train a model from a Python console:
25 |
26 | 1. To train a new model, run train.py with the dataset path:
27 | ```shell
28 | run train.py --datasetdir=datasets\dataset_name
29 | ```
30 |
31 | Or to resume training an existing model, run train.py with the model checkpoint path:
32 | ```shell
33 | run train.py --checkpointfile=models\dataset_name\model_name\checkpoint.ckpt
34 | ```
35 | ## Visualizing a model in TensorBoard
36 |
37 | To start TensorBoard from a terminal:
38 | ```shell
39 | tensorboard --logdir=model_dir
40 | ```
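For example, a minimal sketch assuming a model directory at `models\cornell_movie_dialog\trained_model` (the path is illustrative):
```shell
tensorboard --logdir=models\cornell_movie_dialog\trained_model
```
By default TensorBoard serves at http://localhost:6006, where the model graph and the summaries written by the model's `tf.summary.FileWriter` can be viewed.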
41 | ## Dependencies
42 | The following Python packages are used by seq2seq-chatbot
43 | (excluding packages that come with Anaconda):
44 |
45 | - [TensorFlow](https://www.tensorflow.org/)
46 | ```shell
47 | pip install --upgrade tensorflow
48 | ```
49 | For GPU support, [see here for full GPU install instructions, including CUDA and cuDNN](https://www.tensorflow.org/install/):
50 | ```shell
51 | pip install --upgrade tensorflow-gpu
52 | ```
53 |
54 | - [jsonpickle](https://jsonpickle.github.io/)
55 | ```shell
56 | pip install --upgrade jsonpickle
57 | ```
58 |
59 | - [flask 0.12.4](http://flask.pocoo.org/) and [flask-restful](https://flask-restful.readthedocs.io/en/latest/) (required to run the web interface)
60 | ```shell
61 | pip install flask==0.12.4
62 | pip install --upgrade flask-restful
63 | ```
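
- [flask-cors](https://flask-cors.readthedocs.io/) (imported by `chat_web.py` to allow cross-origin requests to the chat API)
```shell
pip install --upgrade flask-cors
```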
64 |
--------------------------------------------------------------------------------
/__pycache__/chat_command_handler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/__pycache__/chat_command_handler.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/chat_settings.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/__pycache__/chat_settings.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/chatbot_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/__pycache__/chatbot_model.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/general_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/__pycache__/general_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/hparams.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/__pycache__/hparams.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/vocabulary.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/__pycache__/vocabulary.cpython-36.pyc
--------------------------------------------------------------------------------
/chat.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for chatting with a trained chatbot model
3 | """
4 | import datetime
5 | from os import path
6 |
7 | import general_utils
8 | import chat_command_handler
9 | from chat_settings import ChatSettings
10 | from chatbot_model import ChatbotModel
11 | from vocabulary import Vocabulary
12 |
13 | #Read the hyperparameters and configure paths
14 | _, model_dir, hparams, checkpoint, _, _ = general_utils.initialize_session("chat")
15 |
16 | #Load the vocabulary
17 | print()
18 | print("Loading vocabulary...")
19 | if hparams.model_hparams.share_embedding:
20 | shared_vocab_filepath = path.join(model_dir, Vocabulary.SHARED_VOCAB_FILENAME)
21 | input_vocabulary = Vocabulary.load(shared_vocab_filepath)
22 | output_vocabulary = input_vocabulary
23 | else:
24 | input_vocab_filepath = path.join(model_dir, Vocabulary.INPUT_VOCAB_FILENAME)
25 | input_vocabulary = Vocabulary.load(input_vocab_filepath)
26 | output_vocab_filepath = path.join(model_dir, Vocabulary.OUTPUT_VOCAB_FILENAME)
27 | output_vocabulary = Vocabulary.load(output_vocab_filepath)
28 |
29 | # Setting up the chat
30 | chatlog_filepath = path.join(model_dir, "chat_logs", "chatlog_{0}.txt".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S")))
31 | chat_settings = ChatSettings(hparams.model_hparams, hparams.inference_hparams)
32 | terminate_chat = False
33 | reload_model = False
34 | while not terminate_chat:
35 | #Create the model
36 | print()
37 | print("Initializing model..." if not reload_model else "Re-initializing model...")
38 | print()
39 | with ChatbotModel(mode = "infer",
40 | model_hparams = chat_settings.model_hparams,
41 | input_vocabulary = input_vocabulary,
42 | output_vocabulary = output_vocabulary,
43 | model_dir = model_dir) as model:
44 |
45 | #Load the weights
46 | print()
47 | print("Loading model weights...")
48 | print()
49 | model.load(checkpoint)
50 |
51 | #Show the commands
52 | if not reload_model:
53 | chat_command_handler.print_commands()
54 |
55 | while True:
56 | #Get the input and check if it is a question or a command, and execute if it is a command
57 | question = input("You: ")
58 | is_command, terminate_chat, reload_model = chat_command_handler.handle_command(question, model, chat_settings)
59 | if terminate_chat or reload_model:
60 | break
61 | elif is_command:
62 | continue
63 | else:
64 | #If it is not a command (it is a question), pass it on to the chatbot model to get the answer
65 | question_with_history, answer = model.chat(question, chat_settings)
66 |
67 | #Print the answer or answer beams and log to chat log
68 | if chat_settings.show_question_context:
69 | print("Question with history (context): {0}".format(question_with_history))
70 |
71 | if chat_settings.show_all_beams:
72 | for i in range(len(answer)):
73 | print("ChatBot (Beam {0}): {1}".format(i, answer[i]))
74 | else:
75 | print("ChatBot: {0}".format(answer))
76 |
77 | print()
78 |
79 | if chat_settings.inference_hparams.log_chat:
80 | chat_command_handler.append_to_chatlog(chatlog_filepath, question, answer)
--------------------------------------------------------------------------------
/chat_command_handler.py:
--------------------------------------------------------------------------------
1 | """
2 | Command handler for chat session
3 | """
4 | import os
5 |
6 | def append_to_chatlog(chatlog_filepath, question, answer):
7 | """Append a question and answer to the chat log.
8 |
9 | Args:
10 | chatlog_filepath: Path to the chat log file
11 |
12 | question: the question string entered by the user
13 |
14 | answer: the answer string returned by the chatbot
15 | If chat_settings.show_all_beams = True, answer is the array of all answer beams with one string per beam.
16 | """
17 | chatlog_dir = os.path.dirname(chatlog_filepath)
18 | if not os.path.isdir(chatlog_dir):
19 | os.makedirs(chatlog_dir)
20 | with open(chatlog_filepath, "a", encoding="utf-8") as file:
21 | file.write("You: {0}".format(question))
22 | file.write('\n')
23 | file.write("ChatBot: {0}".format(answer))
24 | file.write('\n\n')
25 |
26 | def print_commands():
27 | """Print the list of available commands and their descriptions.
28 | """
29 | print()
30 | print()
31 | print("Commands:")
32 | print("-----------General-----------------")
33 | print("--help (Show this list of commands) --reset (Reset to default settings from hparams.json [*]);")
34 | print("--exit (Quit);")
35 | print()
36 | print("-----------Chat Options:-----------")
37 | print("--enableautopunct (Auto add punctuation to questions); --disableautopunct (Enter punctuation exactly as typed);")
38 | print("--enablenormwords (Auto replace 'don't' with 'do not', etc.); --disablenormwords (Enter words exactly as typed);")
39 | print("--showquestioncontext (Show conversation history as context); --hidequestioncontext (Show questions only);")
40 | print("--showbeams (Output all predicted beams); --hidebeams (Output only the highest ranked beam);")
41 | print("--convhistlength=N (Set conversation history length to N); --clearconvhist (Clear history and start a new conversation);")
42 | print()
43 | print("-----------Model Options:----------")
44 | print("--beamwidth=N (Set beam width to N. 0 disables beamsearch [*]); --beamlenpenalty=N (Set beam length penalty to N);")
45 | print("--enablesampling (Use sampling decoder if beamwidth=0 [*]); --disableasampling (Use greedy decoder if beamwidth=0 [*]);")
46 | print("--samplingtemp=N (Set sampling temperature to N); --maxanswerlen=N (Set max words in answer to N);")
47 | print()
48 | print()
49 | print("[*] Causes model to reload")
50 | print()
51 | print()
52 |
53 | def handle_command(input_str, model, chat_settings):
54 | """Given a user input string, determine if it is a command or a question and process if it is a command.
55 |
56 | Args:
57 | input_str: the user input string
58 |
59 | model: the ChatbotModel instance
60 |
61 | chat_settings: the ChatSettings instance
62 | """
63 | reload_model = False
64 | terminate_chat = False
65 | is_command = True
66 | cmd_value = _get_command_value(input_str)
67 | #General Commands
68 | if input_str == '--help':
69 | print_commands()
70 | elif input_str == '--reset':
71 | chat_settings.reset_to_defaults()
72 | reload_model = True
73 | print ("[Reset to default settings.]")
74 | elif input_str == '--exit':
75 | terminate_chat = True
76 | #Chat Options
77 | elif input_str == '--enableautopunct':
78 | chat_settings.enable_auto_punctuation = True
79 | print ("[Auto-punctuation enabled.]")
80 | elif input_str == '--disableautopunct':
81 | chat_settings.enable_auto_punctuation = False
82 | print ("[Auto-punctuation disabled.]")
83 | elif input_str == '--enablenormwords':
84 | chat_settings.inference_hparams.normalize_words = True
85 | print ("[Word normalization enabled.]")
86 | elif input_str == '--disablenormwords':
87 | chat_settings.inference_hparams.normalize_words = False
88 | print ("[Word normalization disabled.]")
89 | elif input_str == '--showquestioncontext':
90 | chat_settings.show_question_context = True
91 | print ("[Show question context enabled.]")
92 | elif input_str == "--hidequestioncontext":
93 | chat_settings.show_question_context = False
94 | print ("[Show question context disabled.]")
95 | elif input_str == '--showbeams':
96 | chat_settings.show_all_beams = True
97 | print ("[Show all beams enabled.]")
98 | elif input_str == "--hidebeams":
99 | chat_settings.show_all_beams = False
100 | print ("[Show all beams disabled.]")
101 | elif input_str.startswith("--convhistlength"):
102 | if cmd_value is not None:
103 | chat_settings.inference_hparams.conv_history_length = int(cmd_value)
104 | model.trim_conversation_history(chat_settings.inference_hparams.conv_history_length)
105 | print ("[Conversation history length set to {0}.]".format(chat_settings.inference_hparams.conv_history_length))
106 | elif input_str == '--clearconvhist':
107 | model.trim_conversation_history(0)
108 | print ("[Conversation history cleared.]")
109 | #Model Options
110 | elif input_str.startswith("--beamwidth"):
111 | if cmd_value is not None:
112 | chat_settings.model_hparams.beam_width = int(cmd_value)
113 | reload_model = True
114 | print ("[Beam width set to {0}.]".format(chat_settings.model_hparams.beam_width))
115 | elif input_str.startswith("--beamlenpenalty"):
116 | if cmd_value is not None:
117 | chat_settings.inference_hparams.beam_length_penalty_weight = float(cmd_value)
118 | print ("[Beam length penalty weight set to {0}.]".format(chat_settings.inference_hparams.beam_length_penalty_weight))
119 | elif input_str == '--enablesampling':
120 | chat_settings.model_hparams.enable_sampling = True
121 | if chat_settings.model_hparams.beam_width == 0:
122 | reload_model = True
123 | print ("[Sampling decoder enabled (if beamwidth=0).]")
124 | elif input_str == '--disablesampling':
125 | chat_settings.model_hparams.enable_sampling = False
126 | if chat_settings.model_hparams.beam_width == 0:
127 | reload_model = True
128 | print ("[Sampling decoder disabled. Using greedy decoder (if beamwidth=0).]")
129 | elif input_str.startswith("--samplingtemp"):
130 | if cmd_value is not None:
131 | chat_settings.inference_hparams.sampling_temperature = float(cmd_value)
132 | print ("[Sampling temperature set to {0}.]".format(chat_settings.inference_hparams.sampling_temperature))
133 | elif input_str.startswith("--maxanswerlen"):
134 | if cmd_value is not None:
135 | chat_settings.inference_hparams.max_answer_words = int(cmd_value)
136 | print ("[Max words in answer set to {0}.]".format(chat_settings.inference_hparams.max_answer_words))
137 | #Not a command
138 | else:
139 | is_command = False
140 |
141 | return is_command, terminate_chat, reload_model
142 |
143 | def _get_command_value(input_str):
144 | """Parses a command string and returns the value to the right of the '=' sign
145 |
146 | Args:
147 | input_str: the command string
148 | """
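    #Example (illustrative): _get_command_value("--convhistlength=6") returns "6", while _get_command_value("--help") returns None since it contains no '='.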
149 | idx = input_str.find("=")
150 | if idx > -1:
151 | return input_str[idx+1:].strip()
152 | else:
153 | return None
--------------------------------------------------------------------------------
/chat_settings.py:
--------------------------------------------------------------------------------
1 | """
2 | ChatSettings class
3 | """
4 | import copy
5 |
6 | class ChatSettings(object):
7 | """Contains settings for a chat session.
8 | """
9 | def __init__(self, model_hparams, inference_hparams):
10 | """
11 | Args:
12 |             model_hparams, inference_hparams: the loaded ModelHparams and InferenceHparams instances to use as defaults for this chat session
13 | """
14 | self.show_question_context = False
15 | self.show_all_beams = False
16 | self.enable_auto_punctuation = True
17 | self.model_hparams = None
18 | self.inference_hparams = None
19 |
20 | self._default_model_hparams = model_hparams
21 | self._default_inference_hparams = inference_hparams
22 | self.reset_to_defaults()
23 |
24 | def reset_to_defaults(self):
25 | """Reset all settings to defaults
26 | """
27 | self.show_question_context = False
28 | self.show_all_beams = False
29 | self.enable_auto_punctuation = True
30 | self.model_hparams = copy.copy(self._default_model_hparams)
31 | self.inference_hparams = copy.copy(self._default_inference_hparams)
--------------------------------------------------------------------------------
/chat_ui.html:
--------------------------------------------------------------------------------
[Markup not captured in this dump: the tag content of this file was stripped, leaving only the page title "ChatBot". The file is the single-page web chat UI, approximately 189 lines of HTML/CSS/JavaScript.]
--------------------------------------------------------------------------------
/chat_web.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for serving a trained chatbot model over http
3 | """
4 | import datetime
5 | import click
6 | from os import path
7 | from flask import Flask, request, send_from_directory
8 | from flask_cors import CORS
9 | from flask_restful import Resource, Api
10 |
11 | import general_utils
12 | import chat_command_handler
13 | from chat_settings import ChatSettings
14 | from chatbot_model import ChatbotModel
15 | from vocabulary import Vocabulary
16 |
17 | app = Flask(__name__)
18 | CORS(app)
19 |
20 | @app.cli.command()
21 | @click.argument("checkpointfile")
22 | @click.option("-p", "--port", type=int)
23 | def serve_chat(checkpointfile, port):
24 |
25 | api = Api(app)
26 |
27 | #Read the hyperparameters and configure paths
28 | model_dir, hparams, checkpoint = general_utils.initialize_session_server(checkpointfile)
29 |
30 | #Load the vocabulary
31 | print()
32 | print ("Loading vocabulary...")
33 | if hparams.model_hparams.share_embedding:
34 | shared_vocab_filepath = path.join(model_dir, Vocabulary.SHARED_VOCAB_FILENAME)
35 | input_vocabulary = Vocabulary.load(shared_vocab_filepath)
36 | output_vocabulary = input_vocabulary
37 | else:
38 | input_vocab_filepath = path.join(model_dir, Vocabulary.INPUT_VOCAB_FILENAME)
39 | input_vocabulary = Vocabulary.load(input_vocab_filepath)
40 | output_vocab_filepath = path.join(model_dir, Vocabulary.OUTPUT_VOCAB_FILENAME)
41 | output_vocabulary = Vocabulary.load(output_vocab_filepath)
42 |
43 | #Create the model
44 | print ("Initializing model...")
45 | print()
46 | with ChatbotModel(mode = "infer",
47 | model_hparams = hparams.model_hparams,
48 | input_vocabulary = input_vocabulary,
49 | output_vocabulary = output_vocabulary,
50 | model_dir = model_dir) as model:
51 |
52 | #Load the weights
53 | print()
54 | print ("Loading model weights...")
55 | model.load(checkpoint)
56 |
57 | # Setting up the chat
58 | chatlog_filepath = path.join(model_dir, "chat_logs", "web_chatlog_{0}.txt".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S")))
59 | chat_settings = ChatSettings(hparams.model_hparams, hparams.inference_hparams)
60 | chat_command_handler.print_commands()
61 |
62 | class Answer(Resource):
63 | def get(self, question):
64 | is_command, terminate_chat, _ = chat_command_handler.handle_command(question, model, chat_settings)
65 | if terminate_chat:
66 | answer = "[Can't terminate from http request]"
67 | elif is_command:
68 | answer = "[Command processed]"
69 | else:
70 | #If it is not a command (it is a question), pass it on to the chatbot model to get the answer
71 | _, answer = model.chat(question, chat_settings)
72 |
73 | if chat_settings.inference_hparams.log_chat:
74 | chat_command_handler.append_to_chatlog(chatlog_filepath, question, answer)
75 |
76 | return answer
77 |
78 | class UI(Resource):
79 | def get(self):
80 | return send_from_directory(".", "chat_ui.html")
81 |
82 |         api.add_resource(Answer, "/chat/<string:question>")
83 |         api.add_resource(UI, "/chat_ui.html")
84 | app.run(debug=False, port=port)
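
#Illustrative usage (assuming the server is started on port 8080):
#  GET http://localhost:8080/chat/<question> returns the chatbot's answer (or a command acknowledgement) as a JSON string
#  GET http://localhost:8080/chat_ui.html serves the web chat page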
--------------------------------------------------------------------------------
/chatbot_model.py:
--------------------------------------------------------------------------------
1 | """
2 | ChatbotModel class
3 | """
4 | import numpy as np
5 | import tensorflow as tf
6 | from tensorflow.python.layers import core as layers_core
7 | from tensorflow.contrib.tensorboard.plugins import projector
8 | from os import path
9 |
10 | from vocabulary import Vocabulary
11 |
12 | class ChatbotModel(object):
13 | """Seq2Seq chatbot model class.
14 |
15 | This class encapsulates all interaction with TensorFlow. Users of this class need not be aware of the dependency on TF.
16 | This abstraction allows for future use of alternative DL libraries* by writing ChatbotModel implementations for them
17 | without the calling code needing to change.
18 |
19 | This class implementation was influenced by studying the TensorFlow Neural Machine Translation (NMT) tutorial:
20 | - https://www.tensorflow.org/tutorials/seq2seq
21 | These files were helpful in understanding the seq2seq implementation:
22 | - https://github.com/tensorflow/nmt/blob/master/nmt/model.py
23 | - https://github.com/tensorflow/nmt/blob/master/nmt/model_helper.py
24 | - https://github.com/tensorflow/nmt/blob/master/nmt/attention_model.py
25 |
26 | Although the code in this implementation is original, there are some similarities with the NMT tutorial in certain places.
27 | """
28 |
29 | def __init__(self,
30 | mode,
31 | model_hparams,
32 | input_vocabulary,
33 | output_vocabulary,
34 | model_dir):
35 | """Create the Seq2Seq chatbot model.
36 |
37 | Args:
38 | mode: "train" or "infer"
39 |
40 | model_hparams: parameters which determine the architecture and complexity of the model.
41 | See hparams.py for in-depth comments.
42 |
43 | input_vocabulary: the input vocabulary. Each word in this vocabulary gets its own vector
44 | in the encoder embedding matrix.
45 |
46 | output_vocabulary: the output vocabulary. Each word in this vocabulary gets its own vector
47 | in the decoder embedding matrix.
48 |
49 | model_dir: the directory to output model summaries and load / save checkpoints.
50 | """
51 | self.mode = mode
52 | self.model_hparams = model_hparams
53 | self.input_vocabulary = input_vocabulary
54 | self.output_vocabulary = output_vocabulary
55 | self.model_dir = model_dir
56 |
57 | #Validate arguments
58 | if self.model_hparams.share_embedding and self.input_vocabulary.size() != self.output_vocabulary.size():
59 | raise ValueError("Cannot share embedding matrices when the input and output vocabulary sizes are different.")
60 |
61 | if self.model_hparams.share_embedding and self.model_hparams.encoder_embedding_size != self.model_hparams.decoder_embedding_size:
62 | raise ValueError("Cannot share embedding matrices when the encoder and decoder embedding sizes are different.")
63 |
64 | if self.input_vocabulary.external_embeddings is not None and self.input_vocabulary.external_embeddings.shape[1] != self.model_hparams.encoder_embedding_size:
65 | raise ValueError("Cannot use external embeddings with vector size {0} for the encoder when the hparams encoder embedding size is {1}".format(self.input_vocabulary.external_embeddings.shape[1],
66 | self.model_hparams.encoder_embedding_size))
67 | if self.output_vocabulary.external_embeddings is not None and self.output_vocabulary.external_embeddings.shape[1] != self.model_hparams.decoder_embedding_size:
68 | raise ValueError("Cannot use external embeddings with vector size {0} for the decoder when the hparams decoder embedding size is {1}".format(self.output_vocabulary.external_embeddings.shape[1],
69 | self.model_hparams.decoder_embedding_size))
70 |
71 | tf.contrib.learn.ModeKeys.validate(self.mode)
72 |
73 | if self.mode == tf.contrib.learn.ModeKeys.TRAIN or self.model_hparams.beam_width is None:
74 | self.beam_width = 0
75 | else:
76 | self.beam_width = self.model_hparams.beam_width
77 |
78 | #Reset the default TF graph
79 | tf.reset_default_graph()
80 |
81 | #Define general model inputs
82 | self.inputs = tf.placeholder(tf.int32, [None, None], name = "inputs")
83 | self.input_sequence_length = tf.placeholder(tf.int32, [None], name = "input_sequence_length")
84 |
85 | #Build model
86 | initializer_feed_dict = {}
87 | if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
88 | #Define training model inputs
89 | self.targets = tf.placeholder(tf.int32, [None, None], name = "targets")
90 | self.target_sequence_length = tf.placeholder(tf.int32, [None], name = "target_sequence_length")
91 | self.learning_rate = tf.placeholder(tf.float32, name= "learning_rate")
92 | self.keep_prob = tf.placeholder(tf.float32, name = "keep_prob")
93 | if self.input_vocabulary.external_embeddings is not None:
94 | self.input_external_embeddings = tf.placeholder(tf.float32,
95 | shape = self.input_vocabulary.external_embeddings.shape,
96 | name = "input_external_embeddings")
97 | initializer_feed_dict[self.input_external_embeddings] = self.input_vocabulary.external_embeddings
98 | if self.output_vocabulary.external_embeddings is not None and not self.model_hparams.share_embedding:
99 | self.output_external_embeddings = tf.placeholder(tf.float32,
100 | shape = self.output_vocabulary.external_embeddings.shape,
101 | name = "output_external_embeddings")
102 | initializer_feed_dict[self.output_external_embeddings] = self.output_vocabulary.external_embeddings
103 |
104 | self.loss, self.training_step = self._build_model()
105 |
106 | elif self.mode == tf.contrib.learn.ModeKeys.INFER:
107 | #Define inference model inputs
108 | self.max_output_sequence_length = tf.placeholder(tf.int32, [], name="max_output_sequence_length")
109 | self.beam_length_penalty_weight = tf.placeholder(tf.float32, name = "beam_length_penalty_weight")
110 | self.sampling_temperature = tf.placeholder(tf.float32, name = "sampling_temperature")
111 |
112 | self.predictions, self.predictions_seq_lengths = self._build_model()
113 | self.conversation_history = []
114 | else:
115 | raise ValueError("Unsupported model mode. Choose 'train' or 'infer'.")
116 |
117 | # Get the final merged summary for writing to TensorBoard
118 | self.merged_summary = tf.summary.merge_all()
119 |
120 | # Defining the session, summary writer, and checkpoint saver
121 | self.session = self._create_session()
122 |
123 | self.summary_writer = tf.summary.FileWriter(self.model_dir, self.session.graph)
124 |
125 | self.session.run(tf.global_variables_initializer(), initializer_feed_dict)
126 |
127 | self.saver = tf.train.Saver()
128 |
129 | def load(self, filename):
130 | """Loads a trained model from a checkpoint
131 |
132 | Args:
133 | filename: Checkpoint filename, such as best_model_checkpoint.ckpt
134 | This file must exist within model_dir.
135 | """
136 | filepath = path.join(self.model_dir, filename)
137 | self.saver.restore(self.session, filepath)
138 |
139 | def save(self, filename):
140 | """Saves a checkpoint of the current model weights
141 |
142 | Args:
143 | filename: Checkpoint filename, such as best_model_checkpoint.ckpt.
144 |                 This file will be saved within model_dir.
145 | """
146 | filepath = path.join(self.model_dir, filename)
147 | self.saver.save(self.session, filepath)
148 |
149 | config = projector.ProjectorConfig()
150 | if self.model_hparams.share_embedding:
151 | shared_embedding = config.embeddings.add()
152 | shared_embedding.tensor_name = "model/encoder/shared_embeddings_matrix"
153 | shared_embedding.metadata_path = Vocabulary.SHARED_VOCAB_FILENAME
154 | else:
155 | encoder_embedding = config.embeddings.add()
156 | encoder_embedding.tensor_name = "model/encoder/encoder_embeddings_matrix"
157 | encoder_embedding.metadata_path = Vocabulary.INPUT_VOCAB_FILENAME
158 | decoder_embedding = config.embeddings.add()
159 | decoder_embedding.tensor_name = "model/decoder/decoder_embeddings_matrix"
160 | decoder_embedding.metadata_path = Vocabulary.OUTPUT_VOCAB_FILENAME
161 |
162 | projector.visualize_embeddings(self.summary_writer, config)
163 |
164 | def train_batch(self, inputs, targets, input_sequence_length, target_sequence_length, learning_rate, dropout, global_step, log_summary=True):
165 | """Train the model on one batch, and return the training loss.
166 |
167 | Args:
168 | inputs: The input matrix of shape (batch_size, sequence_length)
169 | where each value in the sequences are words encoded as integer indexes of the input vocabulary.
170 |
171 | targets: The target matrix of shape (batch_size, sequence_length)
172 | where each value in the sequences are words encoded as integer indexes of the output vocabulary.
173 |
174 | input_sequence_length: A vector of sequence lengths of shape (batch_size)
175 | containing the lengths of every input sequence in the batch. This allows for dynamic sequence lengths.
176 |
177 | target_sequence_length: A vector of sequence lengths of shape (batch_size)
178 | containing the lengths of every target sequence in the batch. This allows for dynamic sequence lengths.
179 |
180 | learning_rate: The learning rate to use for the weight updates.
181 |
182 | dropout: The probability (0 <= p <= 1) that any neuron will be randomly disabled for this training step.
183 | This regularization technique can allow the model to learn more independent relationships in the input data
184 | and reduce overfitting. Too much dropout can make the model underfit. Typical values range between 0.2 - 0.5
185 |
186 | global_step: The index of this training step across all batches and all epochs. This allows TensorBoard to trend
187 | the training visually over time.
188 |
189 | log_summary: Flag indicating if the training summary should be logged (for visualization in TensorBoard).
190 | """
191 |
192 | if self.mode != tf.contrib.learn.ModeKeys.TRAIN:
193 | raise ValueError("train_batch can only be called when the model is initialized in train mode.")
194 |
195 | #Calculate the keep_probability (prob. a neuron will not be dropped) as 1 - dropout rate
196 | keep_probability = 1.0 - dropout
197 |
198 | #Train on the batch
199 | _, batch_training_loss, merged_summary = self.session.run([self.training_step, self.loss, self.merged_summary],
200 | { self.inputs: inputs,
201 | self.targets: targets,
202 | self.input_sequence_length: input_sequence_length,
203 | self.target_sequence_length: target_sequence_length,
204 | self.learning_rate: learning_rate,
205 | self.keep_prob: keep_probability })
206 |
207 | #Write the training summary for this step if summary logging is enabled.
208 | if log_summary:
209 | self.summary_writer.add_summary(merged_summary, global_step)
210 |
211 | return batch_training_loss
212 |
213 | def validate_batch(self, inputs, targets, input_sequence_length, target_sequence_length, metric = "loss"):
214 | """Evaluate the metric on one batch and return.
215 |
216 | Args:
217 | inputs: The input matrix of shape (batch_size, sequence_length)
218 | where each value in the sequences are words encoded as integer indexes of the input vocabulary.
219 |
220 | targets: The target matrix of shape (batch_size, sequence_length)
221 | where each value in the sequences are words encoded as integer indexes of the output vocabulary.
222 |
223 | input_sequence_length: A vector of sequence lengths of shape (batch_size)
224 | containing the lengths of every input sequence in the batch. This allows for dynamic sequence lengths.
225 |
226 | target_sequence_length: A vector of sequence lengths of shape (batch_size)
227 | containing the lengths of every target sequence in the batch. This allows for dynamic sequence lengths.
228 |
229 | metric: The desired validation metric. Currently only "loss" is supported. This will eventually support
230 | "accuracy", "bleu", and other common validation metrics.
231 | """
232 |
233 | if self.mode != tf.contrib.learn.ModeKeys.TRAIN:
234 | raise ValueError("validate_batch can only be called when the model is initialized in train mode.")
235 |
236 | if metric == "loss":
237 | metric_op = self.loss
238 | else:
239 | raise ValueError("Unsupported validation metric: '{0}'".format(metric))
240 |
241 | metric_value = self.session.run(metric_op, { self.inputs: inputs,
242 | self.targets: targets,
243 | self.input_sequence_length: input_sequence_length,
244 | self.target_sequence_length: target_sequence_length,
245 | self.keep_prob: 1 })
246 |
247 | return metric_value
248 |
249 | def predict_batch(self, inputs, input_sequence_length, max_output_sequence_length, beam_length_penalty_weight, sampling_temperature, log_summary=True):
250 | """Predict a batch of output sequences given a batch of input sequences.
251 |
252 | Args:
253 | inputs: The input matrix of shape (batch_size, sequence_length)
254 | where each value in the sequences are words encoded as integer indexes of the input vocabulary.
255 |
256 | input_sequence_length: A vector of sequence lengths of shape (batch_size)
257 | containing the lengths of every input sequence in the batch. This allows for dynamic sequence lengths.
258 |
259 | max_output_sequence_length: The maximum number of timesteps the decoder can generate.
260 | If the decoder generates an EOS token sooner, it will end there. This maximum value just makes sure
261 | the decoder doesn't go on forever if no EOS is generated.
262 |
263 | beam_length_penalty_weight: When using beam search decoding, this penalty weight influences how
264 |                 beams are ranked. Large negative values rank very short beams first while large positive values rank very long beams first.
265 | A value of 0 will not influence the beam ranking. For a chatbot model, positive values between 0 and 2 can be beneficial
266 | to help the bot avoid short generic answers.
267 |
268 | sampling_temperature: When using sampling decoding, higher temperature values result in more random sampling
269 | while lower temperature values behave more like greedy decoding which takes the argmax of the output class distribution
270 | (softmax probability distribution over the output vocabulary). If this value is set to 0, sampling is disabled
271 | and greedy decoding is used.
272 |
273 | log_summary: Flag indicating if the inference summary should be logged (for visualization in TensorBoard).
274 | """
275 |
276 | if self.mode != tf.contrib.learn.ModeKeys.INFER:
277 | raise ValueError("predict_batch can only be called when the model is initialized in infer mode.")
278 |
279 | fetches = [{ "predictions": self.predictions, "predictions_seq_lengths": self.predictions_seq_lengths }]
280 | if self.merged_summary is not None:
281 | fetches.append(self.merged_summary)
282 |
283 | predicted_output_info = self.session.run(fetches, { self.inputs: inputs,
284 | self.input_sequence_length: input_sequence_length,
285 | self.max_output_sequence_length: max_output_sequence_length,
286 | self.beam_length_penalty_weight: beam_length_penalty_weight,
287 | self.sampling_temperature: sampling_temperature })
288 |
289 |         #Write the inference summary for this prediction if summary logging is enabled.
290 | if log_summary and len(predicted_output_info) == 2:
291 | merged_summary = predicted_output_info[1]
292 | self.summary_writer.add_summary(merged_summary)
293 |
294 | return predicted_output_info[0]
295 |
296 | def chat(self, question, chat_settings):
297 | """Chat with the chatbot model by predicting an answer to a question.
298 | 'question' and 'answer' in this context are generic terms for the interactions in a dialog exchange
299 | and can be statements, remarks, queries, requests, or any other type of dialog speech.
300 | For example:
301 | Question: "How are you?" Answer: "Fine."
302 | Question: "That's great." Answer: "Yeah."
303 |
304 | Args:
305 | question: The input question for which the model should predict an answer.
306 |
307 | chat_settings: The ChatSettings instance containing the chat settings and inference hyperparameters
308 |
309 | Returns:
310 | q_with_hist: question with history if chat_settings.show_question_context = True otherwise None.
311 |
312 | answers: array of answer beams if chat_settings.show_all_beams = True otherwise the single selected answer.
313 |
314 | """
315 | #Process the question by cleaning it and converting it to an integer encoded vector
316 | if chat_settings.enable_auto_punctuation:
317 | question = Vocabulary.auto_punctuate(question)
318 | question = Vocabulary.clean_text(question, normalize_words = chat_settings.inference_hparams.normalize_words)
319 | question = self.input_vocabulary.words2ints(question)
320 |
321 | #Prepend the currently tracked steps of the conversation history separated by EOS tokens.
322 | #This allows for deeper dialog context to influence the answer prediction.
323 | question_with_history = []
324 | for i in range(len(self.conversation_history)):
325 | question_with_history += self.conversation_history[i] + [self.input_vocabulary.eos_int()]
326 | question_with_history += question
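        #e.g. with history [q1, a1] and a new question q2, the encoder input becomes: q1 <EOS> a1 <EOS> q2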
327 |
328 | #Get the answer prediction
329 | batch = np.zeros((1, len(question_with_history)))
330 | batch[0] = question_with_history
331 | max_output_sequence_length = chat_settings.inference_hparams.max_answer_words + 1 # + 1 since the EOS token is counted as a timestep
332 | predicted_answer_info = self.predict_batch(inputs = batch,
333 | input_sequence_length = np.array([len(question_with_history)]),
334 | max_output_sequence_length = max_output_sequence_length,
335 | beam_length_penalty_weight = chat_settings.inference_hparams.beam_length_penalty_weight,
336 | sampling_temperature = chat_settings.inference_hparams.sampling_temperature,
337 | log_summary = chat_settings.inference_hparams.log_summary)
338 |
339 | #Read the answer prediction
340 | answer_beams = []
341 | if self.beam_width > 0:
342 |             #For beam search decoding: if show_all_beams is enabled then output all beams (sequences), otherwise take the first beam.
343 | # The beams (in the "predictions" matrix) are ordered with the highest ranked beams first.
344 | beam_count = 1 if not chat_settings.show_all_beams else len(predicted_answer_info["predictions_seq_lengths"][0])
345 | for i in range(beam_count):
346 | predicted_answer_seq_length = predicted_answer_info["predictions_seq_lengths"][0][i] - 1 #-1 to exclude the EOS token
347 | predicted_answer = predicted_answer_info["predictions"][0][:predicted_answer_seq_length, i].tolist()
348 | answer_beams.append(predicted_answer)
349 | else:
350 | #For greedy / sampling decoding: only one beam (sequence) is returned, based on the argmax for greedy decoding
351 | # or the sampling distribution for sampling decoding. Return this beam.
352 | beam_count = 1
353 | predicted_answer_seq_length = predicted_answer_info["predictions_seq_lengths"][0] - 1 #-1 to exclude the EOS token
354 | predicted_answer = predicted_answer_info["predictions"][0][:predicted_answer_seq_length].tolist()
355 | answer_beams.append(predicted_answer)
356 |
357 | #Add new conversation steps to the end of the history and trim from the beginning if it is longer than conv_history_length
358 | #Answers need to be converted from output_vocabulary ints to input_vocabulary ints (since they will be fed back in to the encoder)
359 | self.conversation_history.append(question)
360 | answer_for_history = self.output_vocabulary.ints2words(answer_beams[0], is_punct_discrete_word = True, capitalize_i = False)
361 | answer_for_history = self.input_vocabulary.words2ints(answer_for_history)
362 | self.conversation_history.append(answer_for_history)
363 | self.trim_conversation_history(chat_settings.inference_hparams.conv_history_length)
364 |
365 | #Convert the answer(s) to text and return
366 | answers = []
367 | for i in range(beam_count):
368 | answer = self.output_vocabulary.ints2words(answer_beams[i])
369 | answers.append(answer)
370 |
371 | q_with_hist = None if not chat_settings.show_question_context else self.input_vocabulary.ints2words(question_with_history)
372 | if chat_settings.show_all_beams:
373 | return q_with_hist, answers
374 | else:
375 | return q_with_hist, answers[0]
376 |
377 | def trim_conversation_history(self, length):
378 | """Trims the conversation history to the desired length by removing entries from the beginning of the array.
379 | This is the same conversation history prepended to each question to enable deep dialog context, so the shorter
380 | the length the less context the next question will have.
381 |
382 | Args:
383 | length: The desired length to trim the conversation history down to.
384 | """
385 | while len(self.conversation_history) > length:
386 | self.conversation_history.pop(0)
387 |
388 | def __enter__(self):
389 | return self
390 |
391 | def __exit__(self, exception_type, exception_value, traceback):
392 | try:
393 | self.summary_writer.close()
394 | except:
395 | pass
396 | try:
397 | self.session.close()
398 | except:
399 | pass
400 |
401 | def _build_model(self):
402 | """Create the seq2seq model graph.
403 |
404 | Since TensorFlow's default behavior is deferred execution, none of the tensor objects below actually have values until
405 | session.Run is called to train, validate, or predict a batch of inputs.
406 |
407 | Eager execution was introduced in TF 1.5, but as of now this code does not use it.
408 | """
409 | with tf.variable_scope("model"):
410 |             #Batch size for each batch is inferred by looking at the first dimension of the input matrix.
411 | #While batch size is generally defined as a hyperparameter and does not change, in practice it can vary.
412 | #An example of this is if the number of samples in the training set is not evenly divisible by the batch size
413 | #in which case the last batch of each epoch will be smaller than the preset hyperparameter value.
414 | batch_size = tf.shape(self.inputs)[0]
415 |
416 | #encoder
417 | with tf.variable_scope("encoder"):
418 | #encoder_embeddings_matrix is a trainable matrix of values that contain the word embeddings for the input sequences.
419 | # when a word is "embedded", it means that the input to the model is a dense N-dimensional vector that represents the word
420 | # instead of a sparse one-hot encoded vector with the dimension of the word's index in the entire vocabulary set to 1.
421 | # At training time, the dense embedding values that represent each word are updated in the direction of the loss gradient
422 | # just like normal weights. Thus the model learns the contextual relationships between the words (the embedding) along with
423 | # the objective function that depends on the words (the decoding).
424 | encoder_embeddings_matrix_name = "shared_embeddings_matrix" if self.model_hparams.share_embedding else "encoder_embeddings_matrix"
425 | encoder_embeddings_matrix_shape = [self.input_vocabulary.size(), self.model_hparams.encoder_embedding_size]
426 | encoder_embeddings_initial_value = (
427 | self.input_external_embeddings if self.input_vocabulary.external_embeddings is not None
428 | else tf.random_uniform(encoder_embeddings_matrix_shape, 0, 1))
429 | encoder_embeddings_matrix = tf.Variable(encoder_embeddings_initial_value,
430 | name = encoder_embeddings_matrix_name,
431 | trainable = self.model_hparams.encoder_embedding_trainable,
432 | expected_shape = encoder_embeddings_matrix_shape)
433 |
434 | #As described above, the sequences of word vocabulary indexes in the inputs matrix are converted to sequences of
435 | #N-dimensional dense vectors, by "looking them up" by index in the encoder_embeddings_matrix.
436 | encoder_embedded_input = tf.nn.embedding_lookup(encoder_embeddings_matrix, self.inputs)
437 |
438 | #Build the encoder RNN
439 | encoder_outputs, encoder_state = self._build_encoder(encoder_embedded_input)
440 |
441 | #Decoder
442 | with tf.variable_scope("decoder") as decoder_scope:
443 | #For description of word embeddings, see comments above on the encoder_embeddings_matrix.
444 | #If the share_embedding flag is set to True, the same matrix is used to embed words in the input and target sequences.
445 |                 #This is useful to avoid redundancy when the same vocabulary is used for the inputs and targets
446 | # (why learn two ways to embed the same words?)
447 | if self.model_hparams.share_embedding:
448 | decoder_embeddings_matrix = encoder_embeddings_matrix
449 | else:
450 | decoder_embeddings_matrix_shape = [self.output_vocabulary.size(), self.model_hparams.decoder_embedding_size]
451 | decoder_embeddings_initial_value = (
452 | self.output_external_embeddings if self.output_vocabulary.external_embeddings is not None
453 | else tf.random_uniform(decoder_embeddings_matrix_shape, 0, 1))
454 | decoder_embeddings_matrix = tf.Variable(decoder_embeddings_initial_value,
455 | name = "decoder_embeddings_matrix",
456 | trainable = self.model_hparams.decoder_embedding_trainable,
457 | expected_shape = decoder_embeddings_matrix_shape)
458 |
459 | #Create the attentional decoder cell
460 | decoder_cell, decoder_initial_state = self._build_attention_decoder_cell(encoder_outputs,
461 | encoder_state,
462 | batch_size)
463 |
464 | #Output (projection) layer
465 | weights = tf.truncated_normal_initializer(stddev=0.1)
466 | biases = tf.zeros_initializer()
467 | output_layer = layers_core.Dense(units = self.output_vocabulary.size(),
468 | kernel_initializer = weights,
469 | bias_initializer = biases,
470 | use_bias = True,
471 | name = "output_dense")
472 |
473 | #Build the decoder RNN using the attentional decoder cell and output layer
474 | if self.mode != tf.contrib.learn.ModeKeys.INFER:
475 | #In train / validate mode, the training step and loss are returned.
476 | loss, training_step = self._build_training_decoder(batch_size,
477 | decoder_embeddings_matrix,
478 | decoder_cell,
479 | decoder_initial_state,
480 | decoder_scope,
481 | output_layer)
482 | return loss, training_step
483 | else:
484 | #In inference mode, the predictions and prediction sequence lengths are returned.
485 | #The sequence lengths can differ, but the predictions matrix will be one fixed size.
486 | #The predictions_seq_lengths array can be used to properly read the sequences of variable lengths.
487 | predictions, predictions_seq_lengths = self._build_inference_decoder(batch_size,
488 | decoder_embeddings_matrix,
489 | decoder_cell,
490 | decoder_initial_state,
491 | decoder_scope,
492 | output_layer)
493 | return predictions, predictions_seq_lengths
494 |
495 | def _build_encoder(self, encoder_embedded_input):
496 | """Create the encoder RNN
497 |
498 | Args:
499 | encoder_embedded_input: The embedded input sequences.
500 | """
501 | keep_prob = self.keep_prob if self.mode == tf.contrib.learn.ModeKeys.TRAIN else None
502 | if self.model_hparams.use_bidirectional_encoder:
503 | #Bi-directional encoding designates one or more RNN cells to read the sequence forward and one or more RNN cells to read
504 | #the sequence backward. The resulting states are concatenated before sending them on to the decoder.
505 | num_bi_layers = int(self.model_hparams.encoder_num_layers / 2)
506 |
507 | encoder_cell_forward = self._create_rnn_cell(self.model_hparams.rnn_size, num_bi_layers, keep_prob)
508 | encoder_cell_backward = self._create_rnn_cell(self.model_hparams.rnn_size, num_bi_layers, keep_prob)
509 |
510 | bi_encoder_outputs, bi_encoder_state = tf.nn.bidirectional_dynamic_rnn(
511 | cell_fw = encoder_cell_forward,
512 | cell_bw = encoder_cell_backward,
513 | sequence_length = self.input_sequence_length,
514 | inputs = encoder_embedded_input,
515 | dtype = tf.float32,
516 | swap_memory=True)
517 |
518 |             #Manipulate the encoder state to handle multiple bidirectional layers
519 | encoder_outputs = tf.concat(bi_encoder_outputs, -1)
520 |
521 | if num_bi_layers == 1:
522 | encoder_state = bi_encoder_state
523 | else:
524 | # alternatively concat forward and backward states
525 | encoder_state = []
526 | for layer_id in range(num_bi_layers):
527 | encoder_state.append(bi_encoder_state[0][layer_id]) # forward
528 | encoder_state.append(bi_encoder_state[1][layer_id]) # backward
529 | encoder_state = tuple(encoder_state)
530 |
531 | else:
532 | #Uni-directional encoding uses all RNN cells to read the sequence forward.
533 | encoder_cell = self._create_rnn_cell(self.model_hparams.rnn_size, self.model_hparams.encoder_num_layers, keep_prob)
534 |
535 | encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
536 | cell = encoder_cell,
537 | sequence_length = self.input_sequence_length,
538 | inputs = encoder_embedded_input,
539 | dtype = tf.float32,
540 | swap_memory=True)
541 |
542 | return encoder_outputs, encoder_state
543 |
544 | def _build_attention_decoder_cell(self, encoder_outputs, encoder_state, batch_size):
545 | """Create the RNN cell to be used as the decoder and apply an attention mechanism.
546 |
547 | Args:
548 | encoder_outputs: a tensor containing the output of the encoder at each timestep of the input sequence.
549 | this is used as the input to the attention mechanism.
550 |
551 | encoder_state: a tensor containing the final encoder state for each encoder cell after reading the input sequence.
552 | if the encoder and decoder have the same structure, this becomes the decoder initial state.
553 |
554 | batch_size: the batch size tensor
555 | (defined at the beginning of the model graph as the length of the first dimension of the input matrix)
556 | """
557 | #If beam search decoding - repeat the input sequence length, encoder output, encoder state, and batch size tensors
558 | #once for every beam.
559 | input_sequence_length = self.input_sequence_length
560 | if self.beam_width > 0:
561 | encoder_outputs = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier = self.beam_width)
562 | input_sequence_length = tf.contrib.seq2seq.tile_batch(input_sequence_length, multiplier = self.beam_width)
563 | encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier = self.beam_width)
564 | batch_size = batch_size * self.beam_width
565 |
566 | #Construct the attention mechanism
567 | if self.model_hparams.attention_type == "bahdanau" or self.model_hparams.attention_type == "normed_bahdanau":
568 | normalize = self.model_hparams.attention_type == "normed_bahdanau"
569 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units = self.model_hparams.rnn_size,
570 | memory = encoder_outputs,
571 | memory_sequence_length = input_sequence_length,
572 | normalize = normalize)
573 | elif self.model_hparams.attention_type == "luong" or self.model_hparams.attention_type == "scaled_luong":
574 | scale = self.model_hparams.attention_type == "scaled_luong"
575 | attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units = self.model_hparams.rnn_size,
576 | memory = encoder_outputs,
577 | memory_sequence_length = input_sequence_length,
578 | scale = scale)
579 | else:
580 | raise ValueError("Unsupported attention type. Use ('bahdanau' / 'normed_bahdanau') for Bahdanau attention or ('luong' / 'scaled_luong') for Luong attention.")
581 |
582 | #Create the decoder cell and wrap with the attention mechanism
583 | with tf.variable_scope("decoder_cell"):
584 | keep_prob = self.keep_prob if self.mode == tf.contrib.learn.ModeKeys.TRAIN else None
585 | decoder_cell = self._create_rnn_cell(self.model_hparams.rnn_size, self.model_hparams.decoder_num_layers, keep_prob)
586 |
587 | alignment_history = self.mode == tf.contrib.learn.ModeKeys.INFER and self.beam_width == 0
588 | output_attention = self.model_hparams.attention_type == "luong" or self.model_hparams.attention_type == "scaled_luong"
589 | attention_decoder_cell = tf.contrib.seq2seq.AttentionWrapper(cell = decoder_cell,
590 | attention_mechanism = attention_mechanism,
591 | attention_layer_size = self.model_hparams.rnn_size,
592 | alignment_history = alignment_history,
593 | output_attention = output_attention,
594 | name = "attention_decoder_cell")
595 |
596 | #If the encoder and decoder are the same structure, set the decoder initial state to the encoder final state.
597 | decoder_initial_state = attention_decoder_cell.zero_state(batch_size, tf.float32)
598 | if self.model_hparams.encoder_num_layers == self.model_hparams.decoder_num_layers:
599 | decoder_initial_state = decoder_initial_state.clone(cell_state = encoder_state)
600 |
601 | return attention_decoder_cell, decoder_initial_state
602 |
603 | def _build_training_decoder(self, batch_size, decoder_embeddings_matrix, decoder_cell, decoder_initial_state, decoder_scope, output_layer):
604 | """Build the decoder RNN for training mode.
605 |
606 | Currently this is implemented using the TensorFlow TrainingHelper, which uses the Teacher Forcing technique.
607 | "Teacher Forcing" means that at each output timestep, the next word of the target sequence is fed as input
608 | to the decoder without regard to the output prediction at the previous timestep.
609 |
610 | Args:
611 | batch_size: the batch size tensor
612 | (defined at the beginning of the model graph as the length of the first dimension of the input matrix)
613 |
614 | decoder_embeddings_matrix: The matrix containing the decoder embeddings
615 |
616 | decoder_cell: The RNN cell (or cells) used in the decoder.
617 |
618 | decoder_initial_state: The initial cell state of the decoder. This is the final encoder cell state if the encoder
619 | and decoder cells are structured the same. Otherwise it is a memory cell in zero state.
620 | """
621 | #Prepend each target sequence with the SOS token, as this is always the input of the first decoder timestep.
622 | preprocessed_targets = self._preprocess_targets(batch_size)
623 | #The sequences of word vocabulary indexes in the targets matrix are converted to sequences of
624 | #N-dimensional dense vectors, by "looking them up" by index in the decoder_embeddings_matrix.
625 | decoder_embedded_input = tf.nn.embedding_lookup(decoder_embeddings_matrix, preprocessed_targets)
626 |
627 | #Create the training decoder
628 | helper = tf.contrib.seq2seq.TrainingHelper(inputs = decoder_embedded_input,
629 | sequence_length = self.target_sequence_length)
630 |
631 | decoder = tf.contrib.seq2seq.BasicDecoder(cell = decoder_cell,
632 | helper = helper,
633 | initial_state = decoder_initial_state)
634 |
635 | #Get the decoder output
636 | decoder_output, _final_context_state, _final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder = decoder,
637 | swap_memory = True,
638 | scope = decoder_scope)
639 | #Pass the decoder output through the output dense layer which will become a softmax distribution of classes -
640 | #one class per word in the output vocabulary.
641 | logits = output_layer(decoder_output.rnn_output)
642 |
643 | #Calculate the softmax loss. Since the logits tensor is fixed size and the output sequences are variable length,
644 | #we need to "mask" the timesteps beyond the sequence length for each sequence. This multiplies the loss calculated
645 | #for those extra timesteps by 0, cancelling them out so they do not affect the final sequence loss.
646 | loss_mask = tf.sequence_mask(self.target_sequence_length, tf.shape(self.targets)[1], dtype=tf.float32)
647 | loss = tf.contrib.seq2seq.sequence_loss(logits = logits,
648 | targets = self.targets,
649 | weights = loss_mask)
650 | tf.summary.scalar("sequence_loss", loss)
651 |
652 | #Set up the optimizer
653 | if self.model_hparams.optimizer == "sgd":
654 | optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
655 | elif self.model_hparams.optimizer == "adam":
656 | optimizer = tf.train.AdamOptimizer(self.learning_rate)
657 | else:
658 | raise ValueError("Unsupported optimizer. Use 'sgd' for GradientDescentOptimizer or 'adam' for AdamOptimizer.")
659 |
660 | tf.summary.scalar("learning_rate", self.learning_rate)
661 | #If max_gradient_norm is provided, enable gradient clipping.
662 | #This prevents an exploding gradient from derailing the training.
663 | if self.model_hparams.max_gradient_norm > 0.0:
664 | params = tf.trainable_variables()
665 | gradients = tf.gradients(loss, params)
666 | clipped_gradients, gradient_norm = tf.clip_by_global_norm(gradients, self.model_hparams.max_gradient_norm)
667 | tf.summary.scalar("gradient_norm", gradient_norm)
668 | tf.summary.scalar("clipped_gradient", tf.global_norm(clipped_gradients))
669 | training_step = optimizer.apply_gradients(zip(clipped_gradients, params))
670 | else:
671 | training_step = optimizer.minimize(loss = loss)
672 |
673 | return loss, training_step
674 |
675 | def _build_inference_decoder(self, batch_size, decoder_embeddings_matrix, decoder_cell, decoder_initial_state, decoder_scope, output_layer):
676 | """Build the decoder RNN for inference mode.
677 |
678 | In inference mode, the decoder takes the output of each timestep and feeds it in as the input to the next timestep
679 | and repeats this process until either an EOS token is generated or the maximum sequence length is reached.
680 |
681 | If beam_width > 0, beam search decoding is used.
682 | Beam search keeps the top most likely words at each timestep and branches off (creates a beam)
683 | on one or more of these words to explore different versions of the output sequence. For example:
684 | Question: How are you?
685 | Answer (beam 1): I am fine.
686 | Answer (beam 2): I am doing well.
687 | Answer (beam 3): I am doing alright considering the circumstances.
688 | Answer (beam 4): Good and yourself?
689 | Answer (beam 5): Good good.
690 | Etc...
691 |
692 | If beam_width = 0, greedy or sampling decoding is used.
693 |
694 | Args:
695 | See _build_training_decoder
696 | """
697 | #Get the SOS and EOS tokens
698 | start_tokens = tf.fill([batch_size], self.output_vocabulary.sos_int())
699 | end_token = self.output_vocabulary.eos_int()
700 |
701 | #Build the beam search, greedy, or sampling decoder
702 | if self.beam_width > 0:
703 | decoder = tf.contrib.seq2seq.BeamSearchDecoder(cell = decoder_cell,
704 | embedding = decoder_embeddings_matrix,
705 | start_tokens = start_tokens,
706 | end_token = end_token,
707 | initial_state = decoder_initial_state,
708 | beam_width = self.beam_width,
709 | output_layer = output_layer,
710 | length_penalty_weight = self.beam_length_penalty_weight)
711 | else:
712 | if self.model_hparams.enable_sampling:
713 | helper = tf.contrib.seq2seq.SampleEmbeddingHelper(embedding = decoder_embeddings_matrix,
714 | start_tokens = start_tokens,
715 | end_token = end_token,
716 | softmax_temperature = self.sampling_temperature)
717 | else:
718 | helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding = decoder_embeddings_matrix,
719 | start_tokens = start_tokens,
720 | end_token = end_token)
721 |
722 | decoder = tf.contrib.seq2seq.BasicDecoder(cell = decoder_cell,
723 | helper = helper,
724 | initial_state = decoder_initial_state,
725 | output_layer = output_layer)
726 |
727 | #Get the decoder output
728 | decoder_output, final_context_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder = decoder,
729 | maximum_iterations = self.max_output_sequence_length,
730 | swap_memory = True,
731 | scope = decoder_scope)
732 |
733 | #Return the predicted sequences along with an array of the sequence lengths for each predicted sequence in the batch
734 | if self.beam_width > 0:
735 | predictions = decoder_output.predicted_ids
736 | predictions_seq_lengths = final_context_state.lengths
737 | else:
738 | predictions = decoder_output.sample_id
739 | predictions_seq_lengths = final_sequence_lengths
740 | #Create attention alignment summary for visualization in TensorBoard
741 | self._create_attention_images_summary(final_context_state)
742 |
743 | return predictions, predictions_seq_lengths
744 |
745 | def _create_rnn_cell(self, rnn_size, num_layers, keep_prob):
746 | """Create a single RNN cell or stack of RNN cells (depending on num_layers)
747 |
748 | Args:
749 | rnn_size: number of units (neurons) in each RNN cell
750 |
751 | num_layers: number of stacked RNN cells to create
752 |
753 | keep_prob: probability of not being dropped out (1 - dropout), or None for inference mode
754 | """
755 |
756 | cells = []
757 | for _ in range(num_layers):
758 | if self.model_hparams.rnn_cell_type == "lstm":
759 | rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=rnn_size)
760 | elif self.model_hparams.rnn_cell_type == "gru":
761 | rnn_cell = tf.nn.rnn_cell.GRUCell(num_units=rnn_size)
762 | else:
763 | raise ValueError("Unsupported RNN cell type. Use 'lstm' for LSTM or 'gru' for GRU.")
764 |
765 | if keep_prob is not None:
766 | rnn_cell = tf.contrib.rnn.DropoutWrapper(cell=rnn_cell, input_keep_prob=keep_prob)
767 |
768 | cells.append(rnn_cell)
769 |
770 | if len(cells) == 1:
771 | return cells[0]
772 | else:
773 | return tf.contrib.rnn.MultiRNNCell(cells = cells)
774 |
775 | def _preprocess_targets(self, batch_size):
776 | """Prepend the SOS token to all target sequences in the batch
777 |
778 | Args:
779 | batch_size: the batch size tensor
780 | (defined at the beginning of the model graph as the length of the first dimension of the input matrix)
781 | """
782 | left_side = tf.fill([batch_size, 1], self.output_vocabulary.sos_int())
783 | right_side = tf.strided_slice(self.targets, [0,0], [batch_size, -1], [1,1])
784 | preprocessed_targets = tf.concat([left_side, right_side], 1)
785 | return preprocessed_targets
786 |
787 | def _create_session(self):
788 | """Initialize the TensorFlow session
789 | """
790 | if self.model_hparams.gpu_dynamic_memory_growth:
791 | config = tf.ConfigProto()
792 | config.gpu_options.allow_growth = True
793 | session = tf.Session(config=config)
794 | else:
795 | session = tf.Session()
796 |
797 | return session
798 |
799 | def _create_attention_images_summary(self, final_context_state):
800 | """Create attention image and attention summary.
801 |
802 | TODO: this method was taken as-is from the NMT tutorial and does not seem to work with the current model.
803 | Figure this out and adjust it as needed for the chatbot model.
804 |
805 | Args:
806 | final_context_state: final state of the decoder
807 | """
808 | attention_images = (final_context_state.alignment_history.stack())
809 | # Reshape to (batch, src_seq_len, tgt_seq_len, 1)
810 | attention_images = tf.expand_dims(
811 | tf.transpose(attention_images, [1, 2, 0]), -1)
812 | # Scale to range [0, 255]
813 | attention_images *= 255
814 | tf.summary.image("attention_images", attention_images)
--------------------------------------------------------------------------------
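
A toy, framework-free illustration of the teacher forcing and loss masking described in `_build_training_decoder` above (plain numpy; the token ids are made-up assumptions, not values from this repo):

```python
import numpy as np

# Assumed toy token ids: SOS=1, EOS=2, PAD=0.
SOS, EOS, PAD = 1, 2, 0

# A batch of two integer-encoded, padded target sequences and their true lengths.
targets = np.array([[5, 6, EOS],
                    [7, EOS, PAD]])
target_lengths = np.array([3, 2])

# Teacher forcing input: prepend SOS and drop the last column, so the decoder
# sees the ground-truth token from the previous timestep
# (mirrors _preprocess_targets: tf.fill + tf.strided_slice + tf.concat).
decoder_inputs = np.concatenate(
    [np.full((targets.shape[0], 1), SOS), targets[:, :-1]], axis=1)
print(decoder_inputs)  # [[1 5 6] [1 7 2]]

# Loss mask: 1.0 for real timesteps, 0.0 for padded ones
# (mirrors the tf.sequence_mask that builds loss_mask for sequence_loss).
max_len = targets.shape[1]
loss_mask = (np.arange(max_len)[None, :] < target_lengths[:, None]).astype(np.float32)
print(loss_mask)  # [[1. 1. 1.] [1. 1. 0.]]
```

And a sketch of the inference-time loop described in `_build_inference_decoder` (greedy case): each predicted token is fed back as the next input until EOS or the maximum length. The `fake_step` transition table is purely hypothetical, standing in for one decoder step:

```python
SOS, EOS = 1, 2  # assumed toy token ids
max_output_sequence_length = 10

def fake_step(token):
    # Stand-in for one decoder step: returns the next predicted token id.
    return {1: 5, 5: 6, 6: EOS}.get(token, EOS)

token, output = SOS, []
for _ in range(max_output_sequence_length):
    token = fake_step(token)
    if token == EOS:
        break
    output.append(token)
print(output)  # [5, 6]
```
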
/chatbot_screenshots/seq1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq1.png
--------------------------------------------------------------------------------
/chatbot_screenshots/seq10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq10.png
--------------------------------------------------------------------------------
/chatbot_screenshots/seq11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq11.png
--------------------------------------------------------------------------------
/chatbot_screenshots/seq2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq2.png
--------------------------------------------------------------------------------
/chatbot_screenshots/seq3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq3.png
--------------------------------------------------------------------------------
/chatbot_screenshots/seq4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq4.png
--------------------------------------------------------------------------------
/chatbot_screenshots/seq5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq5.png
--------------------------------------------------------------------------------
/chatbot_screenshots/seq6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq6.png
--------------------------------------------------------------------------------
/chatbot_screenshots/seq7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq7.png
--------------------------------------------------------------------------------
/chatbot_screenshots/seq8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq8.png
--------------------------------------------------------------------------------
/chatbot_screenshots/seq9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/chatbot_screenshots/seq9.png
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset class
3 | """
4 | import math
5 | import random
6 | import numpy as np
7 | from os import path
8 |
9 | class Dataset(object):
10 | """Class representing a chatbot dataset with questions, answers, and vocabulary.
11 | """
12 |
13 | def __init__(self, questions, answers, input_vocabulary, output_vocabulary):
14 | """Initializes a Dataset instance with a list of questions, answers, and input/output vocabularies.
15 |
16 | Args:
17 | questions: Can be a list of questions as space delimited sentence(s) of words
18 | or a list of lists of integer encoded words
19 |
20 | answers: Can be a list of answers as space delimited sentence(s) of words
21 | or a list of lists of integer encoded words
22 |
23 | input_vocabulary: The Vocabulary instance to use for encoding questions
24 |
25 | output_vocabulary: The Vocabulary instance to use for encoding answers
26 | """
27 | if len(questions) != len(answers):
28 | raise RuntimeError("questions and answers lists must be the same length, as they are lists of input-output pairs.")
29 |
30 | self.input_vocabulary = input_vocabulary
31 | self.output_vocabulary = output_vocabulary
32 | #If the questions and answers are already integer encoded, accept them as is.
33 | #Otherwise use the Vocabulary instances to encode the question and answer sequences.
34 | if len(questions) > 0 and isinstance(questions[0], str):
35 | self.questions_into_int = [self.input_vocabulary.words2ints(q) for q in questions]
36 | self.answers_into_int = [self.output_vocabulary.words2ints(a) for a in answers]
37 | else:
38 | self.questions_into_int = questions
39 | self.answers_into_int = answers
40 |
41 | def size(self):
42 | """ The size (number of samples) of the Dataset.
43 | """
44 | return len(self.questions_into_int)
45 |
46 | def train_val_split(self, val_percent = 20, random_split = True, move_samples = True):
47 | """Splits the dataset into training and validation sets.
48 |
49 | Args:
50 | val_percent: the percentage of the dataset to use as validation data.
51 |
52 | random_split: True to split the dataset randomly.
53 | False to split the dataset sequentially (validation samples are the last N samples, where N = samples * (val_percent / 100))
54 |
55 | move_samples: True to physically move the samples into the returned training and validation dataset objects (saves memory).
56 | False to copy the samples into the returned training and validation dataset objects, and preserve this dataset instance.
57 | """
58 |
59 | if move_samples:
60 | questions = self.questions_into_int
61 | answers = self.answers_into_int
62 | else:
63 | questions = self.questions_into_int[:]
64 | answers = self.answers_into_int[:]
65 |
66 | num_validation_samples = int(len(questions) * (val_percent / 100))
67 | num_training_samples = len(questions) - num_validation_samples
68 |
69 | training_questions = []
70 | training_answers = []
71 | validation_questions = []
72 | validation_answers = []
73 | if random_split:
74 | for _ in range(num_validation_samples):
75 | random_index = random.randint(0, len(questions) - 1)
76 | validation_questions.append(questions.pop(random_index))
77 | validation_answers.append(answers.pop(random_index))
78 |
79 | for _ in range(num_training_samples):
80 | training_questions.append(questions.pop(0))
81 | training_answers.append(answers.pop(0))
82 | else:
83 | for _ in range(num_training_samples):
84 | training_questions.append(questions.pop(0))
85 | training_answers.append(answers.pop(0))
86 |
87 | for _ in range(num_validation_samples):
88 | validation_questions.append(questions.pop(0))
89 | validation_answers.append(answers.pop(0))
90 |
91 | training_dataset = Dataset(training_questions, training_answers, self.input_vocabulary, self.output_vocabulary)
92 | validation_dataset = Dataset(validation_questions, validation_answers, self.input_vocabulary, self.output_vocabulary)
93 |
94 | return training_dataset, validation_dataset
95 |
96 | def sort(self):
97 | """Sorts the dataset by the lengths of the questions. This can speed up training by reducing the
98 | amount of padding the input sequences need.
99 | """
100 | if self.size() > 0:
101 | self.questions_into_int, self.answers_into_int = [list(seq) for seq in zip(*sorted(zip(self.questions_into_int, self.answers_into_int),
102 | key = lambda qa_pair: len(qa_pair[0])))]
103 |
104 | def save(self, filepath):
105 | """Saves the dataset questions & answers exactly as represented by input_vocabulary and output_vocabulary.
106 | """
107 | filename, ext = path.splitext(filepath)
108 | questions_filepath = "{0}_questions{1}".format(filename, ext)
109 | answers_filepath = "{0}_answers{1}".format(filename, ext)
110 |
111 | with open(questions_filepath, mode="w", encoding="utf-8") as file:
112 | for question_into_int in self.questions_into_int:
113 | question = self.input_vocabulary.ints2words(question_into_int, is_punct_discrete_word = True, capitalize_i = False)
114 | file.write(question)
115 | file.write('\n')
116 |
117 | with open(answers_filepath, mode="w", encoding="utf-8") as file:
118 | for answer_into_int in self.answers_into_int:
119 | answer = self.output_vocabulary.ints2words(answer_into_int, is_punct_discrete_word = True, capitalize_i = False)
120 | file.write(answer)
121 | file.write('\n')
122 |
123 |
124 |
125 | def batches(self, batch_size):
126 | """Provide the dataset as an enumerable collection of batches of size batch_size.
127 | Each batch will be a matrix of a fixed shape (batch_size, max_seq_length_in_batch).
128 | Sequences that are shorter than the largest one are padded at the end with the PAD token.
129 | Padding is largely just used as a placeholder since the dynamic encoder and decoder RNNs
130 | will never see the padded timesteps.
131 |
132 | Args:
133 | batch_size: size of each batch.
134 | If the total number of samples is not evenly divisible by batch_size, the last batch will contain the remainder
135 | which will be less than batch_size.
136 |
137 | Returns:
138 | padded_questions_in_batch: A list of padded, integer-encoded question sequences.
139 |
140 | padded_answers_in_batch: A list of padded, integer-encoded answer sequences.
141 |
142 | seqlen_questions_in_batch: A list of actual sequence lengths for each question in the batch.
143 |
144 | seqlen_answers_in_batch: A list of actual sequence lengths for each answer in the batch.
145 |
146 | """
147 | for batch_index in range(0, math.ceil(len(self.questions_into_int) / batch_size)):
148 | start_index = batch_index * batch_size
149 | questions_in_batch = self.questions_into_int[start_index : start_index + batch_size]
150 | answers_in_batch = self.answers_into_int[start_index : start_index + batch_size]
151 |
152 | seqlen_questions_in_batch = np.array([len(q) for q in questions_in_batch])
153 | seqlen_answers_in_batch = np.array([len(a) for a in answers_in_batch])
154 |
155 | padded_questions_in_batch = np.array(self._apply_padding(questions_in_batch, self.input_vocabulary))
156 | padded_answers_in_batch = np.array(self._apply_padding(answers_in_batch, self.output_vocabulary))
157 |
158 | yield padded_questions_in_batch, padded_answers_in_batch, seqlen_questions_in_batch, seqlen_answers_in_batch
159 |
160 |
161 | def _apply_padding(self, batch_of_sequences, vocabulary):
162 | """Padding the sequences with the token to ensure all sequences in the batch are the same physical size.
163 |
164 | Input and target sequences can be any length, but each batch that is fed into the model
165 | must be defined as a fixed size matrix of shape (batch_size, max_seq_length_in_batch).
166 |
167 | Padding allows for this dynamic sequence length within a fixed matrix.
168 | However, the actual padded sequence timesteps are never seen by the encoder or decoder nor are they
169 | counted toward the softmax loss. This is possible since we provide the actual sequence lengths to the model
170 | as a separate vector of shape (batch_size), where the RNNs are instructed to only unroll for the number of
171 | timesteps that have real (unpadded) values. The sequence loss also accepts a masking weight matrix where
172 | we can specify that loss values for padded timesteps should be ignored.
173 |
174 | Args:
175 | batch_of_sequences: list of integer-encoded sequences to pad
176 |
177 | vocabulary: Vocabulary instance to use to look up the integer encoding of the PAD token
178 | """
179 | max_sequence_length = max([len(sequence) for sequence in batch_of_sequences])
180 | return [sequence + ([vocabulary.pad_int()] * (max_sequence_length - len(sequence))) for sequence in batch_of_sequences]
--------------------------------------------------------------------------------
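
A minimal sketch of the batching and padding behavior documented in `Dataset.batches` and `Dataset._apply_padding` (standalone Python with made-up sequences; `PAD = 0` is an assumption, the real id comes from the Vocabulary):

```python
import math

PAD = 0  # assumption: integer id of the PAD token

def apply_padding(batch_of_sequences, pad_int=PAD):
    # Pad each sequence to the length of the longest sequence in the batch.
    max_len = max(len(seq) for seq in batch_of_sequences)
    return [seq + [pad_int] * (max_len - len(seq)) for seq in batch_of_sequences]

questions = [[4, 5], [6, 7, 8, 9], [10]]
batch_size = 2
for i in range(math.ceil(len(questions) / batch_size)):
    batch = questions[i * batch_size : (i + 1) * batch_size]
    print(apply_padding(batch))
# Output:
# [[4, 5, 0, 0], [6, 7, 8, 9]]
# [[10]]            (the last batch holds the remainder)
```
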
/dataset_readers/cornell_dataset_reader.py:
--------------------------------------------------------------------------------
1 | """
2 | Reader class for the Cornell movie dialog dataset
3 | """
4 | from os import path
5 |
6 | from dataset_readers.dataset_reader import DatasetReader
7 |
8 | class CornellDatasetReader(DatasetReader):
9 | """Reader implementation for the Cornell movie dialog dataset
10 | """
11 | def __init__(self):
12 | super(CornellDatasetReader, self).__init__("cornell_movie_dialog")
13 |
14 | def _get_dialog_lines_and_conversations(self, dataset_dir):
15 | """Get dialog lines and conversations. See base class for explanation.
16 |
17 | Args:
18 | See base class
19 | """
20 | movie_lines_filepath = path.join(dataset_dir, "movie_lines.txt")
21 | movie_conversations_filepath = path.join(dataset_dir, "movie_conversations.txt")
22 |
23 | # Importing the dataset
24 | with open(movie_lines_filepath, encoding="utf-8", errors="ignore") as file:
25 | lines = file.read().split("\n")
26 |
27 | with open(movie_conversations_filepath, encoding="utf-8", errors="ignore") as file:
28 | conversations = file.read().split("\n")
29 |
30 | # Create a dictionary that maps each line id to its text
31 | id2line = {}
32 | for line in lines:
33 | _line = line.split(" +++$+++ ")
34 | if len(_line) == 5:
35 | id2line[_line[0]] = _line[4]
36 |
37 | # Creating a list of all of the conversations
38 | conversations_ids = []
39 | for conversation in conversations[:-1]:
40 | _conversation = conversation.split(" +++$+++ ")[-1][1:-1].replace("'", "").replace(" ", "")
41 | conv_ids = _conversation.split(",")
42 | conversations_ids.append(conv_ids)
43 |
44 | return id2line, conversations_ids
45 |
--------------------------------------------------------------------------------
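
For reference, each movie_lines.txt record is a " +++$+++ "-delimited row; a tiny sketch of the parsing above (the record shown is a representative example, not read from the file):

```python
# Representative movie_lines.txt record: line id, character id, movie id,
# character name, and the dialog text itself.
line = "L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!"
fields = line.split(" +++$+++ ")
id2line = {fields[0]: fields[4]} if len(fields) == 5 else {}
print(id2line)  # {'L1045': 'They do not!'}
```
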
/dataset_readers/csv_dataset_reader.py:
--------------------------------------------------------------------------------
1 | """
2 | Reader class for generic CSV question-answer datasets
3 | """
4 | from os import path
5 | import pandas as pd
6 |
7 | from dataset_readers.dataset_reader import DatasetReader
8 |
9 | class CSVDatasetReader(DatasetReader):
10 | """Reader implementation for generic CSV question-answer datasets
11 | """
12 | def __init__(self):
13 | super(CSVDatasetReader, self).__init__("csv")
14 |
15 | def _get_dialog_lines_and_conversations(self, dataset_dir):
16 | """Get dialog lines and conversations. See base class for explanation.
17 | Args:
18 | See base class
19 | """
20 | csv_filepath = path.join(dataset_dir, "csv_data.csv")
21 |
22 | # Importing the dataset
23 | dataset = pd.read_csv(csv_filepath, dtype=str, na_filter=False)
24 | questions = dataset.iloc[:, 0].values
25 | answers = dataset.iloc[:, 1].values
26 |
27 | # Build line ids for each question/answer pair and the conversation list
28 | conversations_ids = []
29 | id2line = {}
30 | for i in range(len(questions)):
31 | question = questions[i].strip()
32 | answer = answers[i].strip()
33 | if question != '' and answer != '':
34 | q_line_id = "{}_q".format(i)
35 | a_line_id = "{}_a".format(i)
36 | id2line[q_line_id] = question
37 | id2line[a_line_id] = answer
38 | conversations_ids.append([q_line_id, a_line_id])
39 |
40 | return id2line, conversations_ids
--------------------------------------------------------------------------------
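
Given the two-row sample `datasets/csv/csv_data.csv` in this repo, the reader above would produce approximately the following structures (written out as literals for illustration):

```python
# Approximate return values of _get_dialog_lines_and_conversations for the sample CSV.
id2line = {
    "0_q": "your question",     "0_a": "your answer",
    "1_q": "what is your name", "1_a": "bob",
}
conversations_ids = [["0_q", "0_a"], ["1_q", "1_a"]]
```
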
/dataset_readers/dataset_reader.py:
--------------------------------------------------------------------------------
1 | """
2 | Base class for dataset readers
3 | """
4 | import abc
5 | from os import path
6 |
7 | from vocabulary_importers import vocabulary_importer_factory
8 | from vocabulary_importers.vocabulary_importer import VocabularyImportMode
9 | from vocabulary import Vocabulary
10 | from dataset import Dataset
11 |
12 | class DatasetReadStats(object):
13 | """Contains information about the read dataset.
14 | """
15 |
16 | def __init__(self):
17 | self.input_vocabulary_import_stats = None
18 | self.output_vocabulary_import_stats = None
19 |
20 | class DatasetReader(object):
21 | """Base class for dataset readers
22 | """
23 |
24 | def __init__(self, dataset_name):
25 | """Initialize the DatasetReader.
26 |
27 | Args:
28 | dataset_name: Name of the dataset. Subclass must pass this in.
29 | """
30 | self.dataset_name = dataset_name
31 |
32 | @abc.abstractmethod
33 | def _get_dialog_lines_and_conversations(self, dataset_dir):
34 | """Subclass must implement this
35 |
36 | Read the raw dataset files and extract a dictionary of dialog lines and a list of conversations.
37 | A conversation is a list of dictionary keys for dialog lines that sequentially form a conversation.
38 |
39 | Args:
40 | dataset_dir: directory to load the raw dataset file(s) from
41 | """
42 | pass
43 |
44 | def read_dataset(self, dataset_dir, model_dir, training_hparams, share_vocab = True, encoder_embeddings_dir = None, decoder_embeddings_dir = None):
45 | """Read and return a chatbot dataset based on the specified dataset
46 |
47 | Args:
48 | dataset_dir: directory to load the raw dataset file(s) from
49 |
50 | model_dir: directory to save the vocabulary to
51 |
52 | training_hparams: training parameters which determine how the dataset will be read.
53 | See hparams.py for in-depth comments.
54 |
55 | share_vocab: True to generate a single vocabulary file from the question and answer words.
56 | False to generate separate input and output vocabulary files, from the question and answer words respectively.
57 | (If training_hparams.conv_history_length > 0, share_vocab should be set to True since previous answers will be appended to the questions.
58 | This could cause many of these previous answer words to map to the unknown (out-of-vocabulary) token when looking up against the input vocabulary.
59 | An exception to this is if the output vocabulary is a subset of the input vocabulary.)
60 |
61 | encoder_embeddings_dir: Path to directory containing external embeddings to import for the encoder.
62 | If this is specified, the input vocabulary will be loaded from this source and optionally joined with the generated
63 | dataset vocabulary (see training_hparams.input_vocab_import_mode)
64 | If share_vocab is True, the imported vocabulary is used for both input and output.
65 |
66 | decoder_embeddings_dir: Path to directory containing external embeddings to import for the decoder.
67 | If this is specified, the output vocabulary will be loaded from this source and optionally joined with the generated
68 | dataset vocabulary (see training_hparams.output_vocab_import_mode)
69 | If share_vocab is True, this argument must be None or the same as encoder_embeddings_dir (both are equivalent).
70 | """
71 |
72 | if share_vocab:
73 | if training_hparams.input_vocab_threshold != training_hparams.output_vocab_threshold and (encoder_embeddings_dir is None or training_hparams.input_vocab_import_mode != VocabularyImportMode.External):
74 | raise ValueError("Cannot share generated or joined imported vocabulary when the input and output vocab thresholds are different.")
75 | if encoder_embeddings_dir is not None:
76 | if training_hparams.input_vocab_import_mode != training_hparams.output_vocab_import_mode:
77 | raise ValueError("Cannot share imported vocabulary when input and output vocab import modes are different.")
78 | if training_hparams.input_vocab_import_normalized != training_hparams.output_vocab_import_normalized:
79 | raise ValueError("Cannot share imported vocabulary when input and output normalization modes are different.")
80 | if decoder_embeddings_dir is not None and decoder_embeddings_dir != encoder_embeddings_dir:
81 | raise ValueError("Cannot share imported vocabulary from two different sources or share import and generated vocabulary.")
82 |
83 |
84 | read_stats = DatasetReadStats()
85 |
86 | #Get dialog line and conversation collections
87 | id2line, conversations_ids = self._get_dialog_lines_and_conversations(dataset_dir)
88 |
89 | #Clean dialog lines
90 | for line_id in id2line:
91 | id2line[line_id] = Vocabulary.clean_text(id2line[line_id], training_hparams.max_question_answer_words, training_hparams.normalize_words)
92 |
93 | #Output cleaned lines for debugging purposes
94 | if training_hparams.log_cleaned_dataset:
95 | self._log_cleaned_dataset(model_dir, id2line.values())
96 |
97 | # Getting separately the questions and the answers
98 | questions_for_count = []
99 | questions = []
100 | answers = []
101 | for conversation in conversations_ids[:training_hparams.max_conversations]:
102 | for i in range(len(conversation) - 1):
103 | conv_up_to_question = ''
104 | for j in range(max(0, i - training_hparams.conv_history_length), i):
105 | conv_up_to_question += id2line[conversation[j]] + " {0} ".format(Vocabulary.EOS)
106 | question = id2line[conversation[i]]
107 | question_with_history = conv_up_to_question + question
108 | answer = id2line[conversation[i+1]]
109 | if training_hparams.min_question_words <= len(question_with_history.split()):
110 | questions.append(conv_up_to_question + question)
111 | questions_for_count.append(question)
112 | answers.append(answer)
113 |
114 | # Create the vocabulary object & add the question & answer words
115 | if share_vocab:
116 | questions_and_answers = []
117 | for i in range(len(questions_for_count)):
118 | question = questions_for_count[i]
119 | answer = answers[i]
120 | if i == 0 or question != answers[i - 1]:
121 | questions_and_answers.append(question)
122 | questions_and_answers.append(answer)
123 |
124 | input_vocabulary, read_stats.input_vocabulary_import_stats = self._create_and_save_vocab(questions_and_answers,
125 | training_hparams.input_vocab_threshold,
126 | model_dir,
127 | Vocabulary.SHARED_VOCAB_FILENAME,
128 | encoder_embeddings_dir,
129 | training_hparams.input_vocab_import_normalized,
130 | training_hparams.input_vocab_import_mode)
131 | output_vocabulary = input_vocabulary
132 | read_stats.output_vocabulary_import_stats = read_stats.input_vocabulary_import_stats
133 | else:
134 | input_vocabulary, read_stats.input_vocabulary_import_stats = self._create_and_save_vocab(questions_for_count,
135 | training_hparams.input_vocab_threshold,
136 | model_dir,
137 | Vocabulary.INPUT_VOCAB_FILENAME,
138 | encoder_embeddings_dir,
139 | training_hparams.input_vocab_import_normalized,
140 | training_hparams.input_vocab_import_mode)
141 |
142 | output_vocabulary, read_stats.output_vocabulary_import_stats = self._create_and_save_vocab(answers,
143 | training_hparams.output_vocab_threshold,
144 | model_dir,
145 | Vocabulary.OUTPUT_VOCAB_FILENAME,
146 | decoder_embeddings_dir,
147 | training_hparams.output_vocab_import_normalized,
148 | training_hparams.output_vocab_import_mode)
149 |
150 | # Append the End Of String (EOS) token to every answer
151 | for i in range(len(answers)):
152 | answers[i] += " {0}".format(Vocabulary.EOS)
153 |
154 | #Create the Dataset object from the questions / answers lists and the vocab object.
155 | dataset = Dataset(questions, answers, input_vocabulary, output_vocabulary)
156 |
157 | return dataset, read_stats
158 |
159 | def _create_and_save_vocab(self, word_sequences, vocab_threshold, model_dir, vocab_filename, embeddings_dir, normalize_imported_vocab, vocab_import_mode):
160 | """Create a Vocabulary instance from a list of word sequences, and save it to disk.
161 |
162 | Args:
163 | word_sequences: List of word sequences (sentence(s)) to use as basis for the vocabulary.
164 |
165 | vocab_threshold: Minimum number of times any word must appear within word_sequences
166 | in order to be included in the vocabulary.
167 |
168 | model_dir: directory to save the vocabulary file to
169 |
170 | vocab_filename: file name of the vocabulary file
171 |
172 | embeddings_dir: Optional directory to import external vocabulary & embeddings
173 | If provided, the external vocabulary will be imported and processed according to the vocab_import_mode.
174 | If None, only the generated vocabulary will be used.
175 |
176 | normalize_imported_vocab: See VocabularyImporter.import_vocabulary
177 |
178 | vocab_import_mode: If embeddings_dir is specified, this flag indicates if the dataset vocabulary should be generated
179 | and used in combination with the external vocabulary according to the rules of VocabularyImportMode.
180 | """
181 | vocabulary = None
182 | if embeddings_dir is None or vocab_import_mode != VocabularyImportMode.External:
183 | vocabulary = Vocabulary()
184 | for i in range(len(word_sequences)):
185 | word_seq = word_sequences[i]
186 | vocabulary.add_words(word_seq.split())
187 | vocabulary.compile(vocab_threshold)
188 |
189 | vocabulary_import_stats = None
190 | if embeddings_dir is not None:
191 | vocabulary_importer = vocabulary_importer_factory.get_vocabulary_importer(embeddings_dir)
192 | vocabulary, vocabulary_import_stats = vocabulary_importer.import_vocabulary(embeddings_dir,
193 | normalize_imported_vocab,
194 | vocab_import_mode,
195 | vocabulary)
196 |
197 | vocab_filepath = path.join(model_dir, vocab_filename)
198 | vocabulary.save(vocab_filepath)
199 | return vocabulary, vocabulary_import_stats
200 |
201 | def _log_cleaned_dataset(self, model_dir, lines):
202 | """Write the cleaned dataset to disk
203 | """
204 | log_filepath = path.join(model_dir, "cleaned_dataset.txt")
205 | with open(log_filepath, mode="w", encoding="utf-8") as file:
206 | for line in lines:
207 | file.write(line)
208 | file.write('\n')
--------------------------------------------------------------------------------
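
The conversation-history handling in `read_dataset` folds up to `conv_history_length` previous lines into each question, separating them with the EOS token. A standalone sketch (the conversation and the token string are made-up values):

```python
EOS = "<EOS>"  # assumption: the EOS token rendered as a string
conv_history_length = 2

conversation = ["hi", "hello", "how are you", "fine"]
pairs = []
for i in range(len(conversation) - 1):
    history = ""
    for j in range(max(0, i - conv_history_length), i):
        history += conversation[j] + " {0} ".format(EOS)
    pairs.append((history + conversation[i], conversation[i + 1]))

for question, answer in pairs:
    print(repr(question), "->", repr(answer))
# 'hi' -> 'hello'
# 'hi <EOS> hello' -> 'how are you'
# 'hi <EOS> hello <EOS> how are you' -> 'fine'
```
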
/dataset_readers/dataset_reader_factory.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset reader implementation factory
3 | """
4 | from os import path
5 | from dataset_readers.cornell_dataset_reader import CornellDatasetReader
6 | from dataset_readers.csv_dataset_reader import CSVDatasetReader
7 |
8 | def get_dataset_reader(dataset_dir):
9 | """Gets the appropriate reader implementation for the specified dataset name.
10 |
11 | Args:
12 | dataset_dir: The directory of the dataset to get a reader implementation for.
13 | """
14 | dataset_name = path.basename(dataset_dir)
15 |
16 | #When adding support for new datasets, add an instance of their reader class to the reader array below.
17 | readers = [CornellDatasetReader(), CSVDatasetReader()]
18 |
19 | for reader in readers:
20 | if reader.dataset_name == dataset_name:
21 | return reader
22 |
23 | raise ValueError("There is no dataset reader implementation for '{0}'. If this is a new dataset, please add one!".format(dataset_name))
24 |
--------------------------------------------------------------------------------
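
The factory keys off the dataset directory's base name; a short usage sketch:

```python
from dataset_readers import dataset_reader_factory

# "cornell_movie_dialog" is the base name of the directory,
# so the CornellDatasetReader instance is returned.
reader = dataset_reader_factory.get_dataset_reader("datasets/cornell_movie_dialog")
print(reader.dataset_name)  # cornell_movie_dialog
```
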
/datasets/cornell_movie_dialog/README.md:
--------------------------------------------------------------------------------
1 | # Cornell Movie Dialog Corpus
2 |
3 | "This corpus contains a large metadata-rich collection of fictional conversations extracted from raw movie scripts:
4 |
5 | - 220,579 conversational exchanges between 10,292 pairs of movie characters
6 |
7 | - involves 9,035 characters from 617 movies
8 |
9 | - in total 304,713 utterances"
10 |
11 | Source: [https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html](https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html)
--------------------------------------------------------------------------------
/datasets/cornell_movie_dialog/movie_lines.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/datasets/cornell_movie_dialog/movie_lines.txt
--------------------------------------------------------------------------------
/datasets/cornell_movie_dialog/train_with_dependency_based_embeddings.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\cornell_movie_dialog --encoderembeddingsdir=embeddings\dependency_based --decoderembeddingsdir=embeddings\dependency_based
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/cornell_movie_dialog/train_with_dependency_based_embeddings_decoder_only.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\cornell_movie_dialog --decoderembeddingsdir=embeddings\dependency_based
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/cornell_movie_dialog/train_with_dependency_based_embeddings_encoder_only.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\cornell_movie_dialog --encoderembeddingsdir=embeddings\dependency_based
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/cornell_movie_dialog/train_with_nnlm_en_embeddings.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\cornell_movie_dialog --encoderembeddingsdir=embeddings\nnlm_en --decoderembeddingsdir=embeddings\nnlm_en
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/cornell_movie_dialog/train_with_nnlm_en_embeddings_decoder_only.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\cornell_movie_dialog --decoderembeddingsdir=embeddings\nnlm_en
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/cornell_movie_dialog/train_with_nnlm_en_embeddings_encoder_only.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\cornell_movie_dialog --encoderembeddingsdir=embeddings\nnlm_en
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/cornell_movie_dialog/train_with_random_embeddings.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\cornell_movie_dialog
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/cornell_movie_dialog/train_with_word2vec_wikipedia_embeddings.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\cornell_movie_dialog --encoderembeddingsdir=embeddings\word2vec_wikipedia --decoderembeddingsdir=embeddings\word2vec_wikipedia
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/cornell_movie_dialog/train_with_word2vec_wikipedia_embeddings_decoder_only.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\cornell_movie_dialog --decoderembeddingsdir=embeddings\word2vec_wikipedia
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/cornell_movie_dialog/train_with_word2vec_wikipedia_embeddings_encoder_only.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\cornell_movie_dialog --encoderembeddingsdir=embeddings\word2vec_wikipedia
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/csv/README.md:
--------------------------------------------------------------------------------
1 | # CSV question-answer dataset
2 | A generic CSV dataset with two columns - question and answer.
3 |
4 | ## Instructions:
5 |
6 | 1) Drop a CSV file here with the name "csv_data.csv" (replace the sample included in this repo with your own data).
7 |
8 | 2) The CSV must have two columns - the first one is for questions and the second one is for answers (responses).
9 |
10 | 3) Train using the included batch files or the command line.
11 |
--------------------------------------------------------------------------------
/datasets/csv/csv_data.csv:
--------------------------------------------------------------------------------
1 | question,answer
2 | your question,your answer
3 | what is your name,bob
--------------------------------------------------------------------------------
/datasets/csv/train_with_dependency_based_embeddings.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\csv --encoderembeddingsdir=embeddings\dependency_based --decoderembeddingsdir=embeddings\dependency_based
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/csv/train_with_dependency_based_embeddings_decoder_only.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\csv --decoderembeddingsdir=embeddings\dependency_based
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/csv/train_with_dependency_based_embeddings_encoder_only.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\csv --encoderembeddingsdir=embeddings\dependency_based
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/csv/train_with_nnlm_en_embeddings.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\csv --encoderembeddingsdir=embeddings\nnlm_en --decoderembeddingsdir=embeddings\nnlm_en
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/csv/train_with_nnlm_en_embeddings_decoder_only.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\csv --decoderembeddingsdir=embeddings\nnlm_en
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/csv/train_with_nnlm_en_embeddings_encoder_only.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\csv --encoderembeddingsdir=embeddings\nnlm_en
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/csv/train_with_random_embeddings.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\csv
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/csv/train_with_word2vec_wikipedia_embeddings.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\csv --encoderembeddingsdir=embeddings\word2vec_wikipedia --decoderembeddingsdir=embeddings\word2vec_wikipedia
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/csv/train_with_word2vec_wikipedia_embeddings_decoder_only.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\csv --decoderembeddingsdir=embeddings\word2vec_wikipedia
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/datasets/csv/train_with_word2vec_wikipedia_embeddings_encoder_only.bat:
--------------------------------------------------------------------------------
1 | call C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3
2 | cd ..\..
3 | python train.py --datasetdir=datasets\csv --encoderembeddingsdir=embeddings\word2vec_wikipedia
4 |
5 | cmd /k
--------------------------------------------------------------------------------
/general_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | General utility methods
3 | """
4 | import os
5 | import argparse
6 | import datetime
7 | import platform
8 | from shutil import copyfile
9 |
10 | from hparams import Hparams
11 |
12 | def initialize_session(mode):
13 | """Helper method for initializing a chatbot training session
14 | by loading the model dir from command line args and reading the hparams in
15 |
16 | Args:
17 | mode: "train" or "chat"
18 | """
19 | parser = argparse.ArgumentParser("Train a chatbot model" if mode == "train" else "Chat with a trained chatbot model")
20 | if mode == "train":
21 | ex_group = parser.add_mutually_exclusive_group(required=True)
22 | ex_group.add_argument("--datasetdir", "-d", help="Path structured as datasets/dataset_name. A new model will be trained using the dataset contained in this directory.")
23 | ex_group.add_argument("--checkpointfile", "-c", help="Path structured as 'models/dataset_name/model_name/checkpoint_name.ckpt'. Training will resume from the selected checkpoint. The hparams.json file should exist in the same directory as the checkpoint.")
24 | em_group = parser.add_argument_group()
25 | em_group.add_argument("--encoderembeddingsdir", "--embeddingsdir", "-e", help="Path structured as embeddings/embeddings_name. Encoder (& Decoder if shared) vocabulary and embeddings will be initialized from the checkpoint file and tokens file contained in this directory.")
26 | em_group.add_argument("--decoderembeddingsdir", help="Path structured as embeddings/embeddings_name. Decoder vocabulary and embeddings will be initialized from the checkpoint file and tokens file contained in this directory.")
27 | elif mode == "chat":
28 | parser.add_argument("checkpointfile", help="Path structured as 'models/dataset_name/model_name/checkpoint_name.ckpt'. The hparams.json file and the vocabulary file(s) should exist in the same directory as the checkpoint.")
29 | else:
30 | raise ValueError("Unsupported session mode. Choose 'train' or 'chat'.")
31 | args = parser.parse_args()
32 |
33 | #Make sure script was run in the correct working directory
34 | models_dir = "models"
35 | datasets_dir = "datasets"
36 | if not os.path.isdir(models_dir) or not os.path.isdir(datasets_dir):
37 | raise NotADirectoryError("Cannot find models directory 'models' and datasets directory 'datasets' within working directory '{0}'. Make sure to set the working directory to the chatbot root folder."
38 | .format(os.getcwd()))
39 |
40 | encoder_embeddings_dir = decoder_embeddings_dir = None
41 | if mode == "train":
42 | #If provided, make sure the embeddings exist
43 | if args.encoderembeddingsdir:
44 | encoder_embeddings_dir = os.path.relpath(args.encoderembeddingsdir)
45 | if not os.path.isdir(encoder_embeddings_dir):
46 | raise NotADirectoryError("Cannot find embeddings directory '{0}'".format(os.path.realpath(encoder_embeddings_dir)))
47 | if args.decoderembeddingsdir:
48 | decoder_embeddings_dir = os.path.relpath(args.decoderembeddingsdir)
49 | if not os.path.isdir(decoder_embeddings_dir):
50 | raise NotADirectoryError("Cannot find embeddings directory '{0}'".format(os.path.realpath(decoder_embeddings_dir)))
51 |
52 | if mode == "train" and args.datasetdir:
53 | #Make sure dataset exists
54 | dataset_dir = os.path.relpath(args.datasetdir)
55 | if not os.path.isdir(dataset_dir):
56 | raise NotADirectoryError("Cannot find dataset directory '{0}'".format(os.path.realpath(dataset_dir)))
57 | #Create the new model directory
58 | dataset_name = os.path.basename(dataset_dir)
59 | model_dir = os.path.join("models", dataset_name, datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
60 | os.makedirs(model_dir, exist_ok=True)
61 | copyfile("hparams.json", os.path.join(model_dir, "hparams.json"))
62 | checkpoint = None
63 | elif args.checkpointfile:
64 | #Make sure checkpoint file & hparams file exists
65 | checkpoint_filepath = os.path.relpath(args.checkpointfile)
66 | if not os.path.isfile(checkpoint_filepath + ".meta"):
67 | raise FileNotFoundError("The checkpoint file '{0}' was not found.".format(os.path.realpath(checkpoint_filepath)))
68 | #Get the checkpoint model directory
69 | checkpoint = os.path.basename(checkpoint_filepath)
70 | model_dir = os.path.dirname(checkpoint_filepath)
71 | dataset_name = os.path.basename(os.path.dirname(model_dir))
72 | dataset_dir = os.path.join(datasets_dir, dataset_name)
73 | else:
74 | raise ValueError("Invalid arguments. Use --help for proper usage.")
75 |
76 | #Load the hparams from file
77 | hparams_filepath = os.path.join(model_dir, "hparams.json")
78 | hparams = Hparams.load(hparams_filepath)
79 |
80 | return dataset_dir, model_dir, hparams, checkpoint, encoder_embeddings_dir, decoder_embeddings_dir
81 |
82 | def initialize_session_server(checkpointfile):
83 | #Make sure checkpoint file & hparams file exists
84 | checkpoint_filepath = os.path.relpath(checkpointfile)
85 | if not os.path.isfile(checkpoint_filepath + ".meta"):
86 | raise FileNotFoundError("The checkpoint file '{0}' was not found.".format(os.path.realpath(checkpoint_filepath)))
87 | #Get the checkpoint model directory
88 | checkpoint = os.path.basename(checkpoint_filepath)
89 | model_dir = os.path.dirname(checkpoint_filepath)
90 |
91 | #Load the hparams from file
92 | hparams_filepath = os.path.join(model_dir, "hparams.json")
93 | hparams = Hparams.load(hparams_filepath)
94 |
95 | return model_dir, hparams, checkpoint
96 |
97 | def create_batch_files(model_dir, checkpoint_training, checkpoint_val, encoder_embeddings_dir, decoder_embeddings_dir):
98 | os_type = platform.system().lower()
99 | if os_type == "windows":
100 | if checkpoint_training is not None:
101 | create_windows_batch_files(model_dir, checkpoint_training, encoder_embeddings_dir, decoder_embeddings_dir)
102 | if checkpoint_val is not None:
103 | create_windows_batch_files(model_dir, checkpoint_val, encoder_embeddings_dir, decoder_embeddings_dir)
104 | elif os_type == "darwin":
105 | pass
106 | elif os_type == "linux":
107 | pass
108 | else:
109 | pass
110 |
111 | def create_windows_batch_files(model_dir, checkpoint, encoder_embeddings_dir, decoder_embeddings_dir):
112 | if "CONDA_PREFIX" in os.environ:
113 | conda_prefix = os.environ["CONDA_PREFIX"]
114 | conda_activate = os.path.join(conda_prefix, r"scripts\activate.bat")
115 | checkpoint_file = os.path.join(model_dir, checkpoint)
116 | checkpoint_name = os.path.splitext(checkpoint)[0]
117 |
118 | #Resume training batch file
119 | batch_file = os.path.join(model_dir, "resume_training_{0}.bat".format(checkpoint_name))
120 | with open(batch_file, mode="w", encoding="utf-8") as file:
121 | file.write("\n".join([
122 | "call {0} {1}".format(conda_activate, conda_prefix),
123 | r"cd ..\..\..",
124 | "python train.py --checkpointfile=\"{0}\"{1}{2}".format(checkpoint_file,
125 | " --encoderembeddingsdir={0}".format(encoder_embeddings_dir) if encoder_embeddings_dir is not None else "",
126 | " --decoderembeddingsdir={0}".format(decoder_embeddings_dir) if decoder_embeddings_dir is not None else ""),
127 | "",
128 | "cmd /k"
129 | ]))
130 |
131 | #Chat batch file
132 | batch_file = os.path.join(model_dir, "chat_console_{0}.bat".format(checkpoint_name))
133 | with open(batch_file, mode="w", encoding="utf-8") as file:
134 | file.write("\n".join([
135 | "call {0} {1}".format(conda_activate, conda_prefix),
136 | r"cd ..\..\..",
137 | "python chat.py \"{0}\"".format(checkpoint_file),
138 | "",
139 | "cmd /k"
140 | ]))
141 |
142 | #Chat web batch file
143 | batch_file = os.path.join(model_dir, "chat_web_{0}.bat".format(checkpoint_name))
144 | with open(batch_file, mode="w", encoding="utf-8") as file:
145 | file.write("\n".join([
146 | "call {0} {1}".format(conda_activate, conda_prefix),
147 | r"cd ..\..\..",
148 | "set FLASK_APP=chat_web.py",
149 | "flask serve_chat \"{0}\" -p 8080".format(checkpoint_file),
150 | "",
151 | "cmd /k"
152 | ]))
153 |
154 | #Tensorboard batch file
155 | batch_file = os.path.join(model_dir, "tensorboard_{0}.bat".format(checkpoint_name))
156 | with open(batch_file, mode="w", encoding="utf-8") as file:
157 | file.write("\n".join([
158 | "call {0} {1}".format(conda_activate, conda_prefix),
159 | "tensorboard --logdir=."
160 | ]))
161 |
--------------------------------------------------------------------------------
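
The checkpoint-path handling above assumes the `models/dataset_name/model_name/checkpoint_name.ckpt` layout; a standalone sketch of that decomposition (the path is a hypothetical example):

```python
import os

checkpoint_filepath = "models/cornell_movie_dialog/trained_model/best_weights_training.ckpt"

checkpoint = os.path.basename(checkpoint_filepath)           # best_weights_training.ckpt
model_dir = os.path.dirname(checkpoint_filepath)             # models/cornell_movie_dialog/trained_model
dataset_name = os.path.basename(os.path.dirname(model_dir))  # cornell_movie_dialog
dataset_dir = os.path.join("datasets", dataset_name)         # datasets/cornell_movie_dialog
```
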
/hparams.json:
--------------------------------------------------------------------------------
1 | {
2 | "py/object": "hparams.Hparams",
3 | "model_hparams" : {
4 | "py/object": "hparams.ModelHparams",
5 | "rnn_cell_type": "lstm",
6 | "rnn_size": 1024,
7 | "use_bidirectional_encoder": true,
8 | "encoder_num_layers": 4,
9 | "decoder_num_layers": 4,
10 | "encoder_embedding_size": 128,
11 | "decoder_embedding_size": 128,
12 | "encoder_embedding_trainable": true,
13 | "decoder_embedding_trainable": true,
14 | "share_embedding": true,
15 | "attention_type": "normed_bahdanau",
16 | "beam_width": 20,
17 | "enable_sampling": false,
18 | "optimizer": "sgd",
19 | "max_gradient_norm": 5.0,
20 | "gpu_dynamic_memory_growth": true
21 | },
22 | "training_hparams": {
23 | "py/object": "hparams.TrainingHparams",
24 | "min_question_words": 1,
25 | "max_question_answer_words": 30,
26 | "max_conversations": -1,
27 | "conv_history_length": 6,
28 | "normalize_words": false,
29 | "input_vocab_threshold": 1,
30 | "output_vocab_threshold": 1,
31 | "input_vocab_import_normalized": true,
32 | "output_vocab_import_normalized": true,
33 | "input_vocab_import_mode": "Dataset",
34 | "output_vocab_import_mode": "Dataset",
35 | "validation_set_percent": 0.0,
36 | "random_train_val_split": true,
37 | "validation_metric": "loss",
38 | "epochs": 500,
39 | "early_stopping_epochs": 500,
40 | "batch_size": 128,
41 | "learning_rate": 2.0,
42 | "learning_rate_decay": 0.9975,
43 | "min_learning_rate": 0.1,
44 | "dropout": 0.2,
45 | "checkpoint_on_training": true,
46 | "checkpoint_on_validation": false,
47 | "log_summary": true,
48 | "log_cleaned_dataset": true,
49 | "log_training_data": true,
50 | "stats_after_n_batches": 100,
51 | "backup_on_training_loss": [3.0, 2.5, 2.0, 1.5, 1.0, 0.5]
52 | },
53 | "inference_hparams": {
54 | "py/object": "hparams.InferenceHparams",
55 | "beam_length_penalty_weight": 1.25,
56 | "sampling_temperature": 0.5,
57 | "max_answer_words": 100,
58 | "conv_history_length": 20,
59 | "normalize_words": false,
60 | "log_summary": true,
61 | "log_chat": true
62 | }
63 | }
--------------------------------------------------------------------------------
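
Because the JSON carries `py/object` type hints, it round-trips through jsonpickle. A minimal load sketch (mirroring `Hparams.load` in hparams.py below; it assumes hparams.py is importable so jsonpickle can reconstruct the classes):

```python
import jsonpickle

with open("hparams.json", "r") as file:
    hparams = jsonpickle.decode(file.read())

print(hparams.model_hparams.rnn_size)       # 1024
print(hparams.training_hparams.batch_size)  # 128
```

Note that `Hparams.load` additionally converts the vocab import mode strings into `VocabularyImportMode` enum values.
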
/hparams.py:
--------------------------------------------------------------------------------
1 | """
2 | Hyperparameters class
3 | """
4 |
5 | import jsonpickle
6 | from vocabulary_importers.vocabulary_importer import VocabularyImportMode
7 |
8 | class Hparams(object):
9 | """Container for model, training, and inference hyperparameters.
10 |
11 | Members:
12 | model_hparams: ModelHparams instance
13 |
14 | training_hparams: TrainingHparams instance
15 |
16 | inference_hparams: InferenceHparams instance
17 | """
18 | def __init__(self):
19 | """Initializes the Hparams instance.
20 | """
21 | self.model_hparams = ModelHparams()
22 | self.training_hparams = TrainingHparams()
23 | self.inference_hparams = InferenceHparams()
24 |
25 | @staticmethod
26 | def load(filepath):
27 | """Loads the hyperparameters from a JSON file.
28 |
29 | Args:
30 | filepath: path of the JSON file.
31 | """
32 | with open(filepath, "r") as file:
33 | json = file.read()
34 | hparams = jsonpickle.decode(json)
35 | hparams.training_hparams.input_vocab_import_mode = VocabularyImportMode[hparams.training_hparams.input_vocab_import_mode]
36 | hparams.training_hparams.output_vocab_import_mode = VocabularyImportMode[hparams.training_hparams.output_vocab_import_mode]
37 | return hparams
38 |
39 |
40 | class ModelHparams(object):
41 | """Hyperparameters which determine the architecture and complexity of the chatbot model.
42 |
43 | Members:
44 | rnn_cell_type: The architecture of RNN cell: "lstm" or "gru"
45 |             LSTM: "Long Short-Term Memory"
46 | GRU: "Gated Recurrent Unit"
47 |
48 | rnn_size: the number of units (neurons) in each RNN cell. Applies to the encoder and decoder.
49 |
50 | use_bidirectional_encoder: True to use a bi-directional encoder.
51 | Bi-directional encoder: Two separate RNN cells (or stacks of cells) are used -
52 |             one receives the input sequence (question) in forward order and the other receives it in reverse order.
53 | When creating stacked RNN layers, each direction is stacked separately, with one stack for forward cells
54 | and one stack for reverse cells.
55 | Uni-directional encoder: One RNN cell (or stack of cells) is used in the forward direction (traditional RNN)
56 |
57 | encoder_num_layers: the number of RNN cells to stack in the encoder.
58 | If use_bidirectional_encoder is set to true, this number is divided in half and applied to
59 |             each direction. For example: 4 layers with bidirectional encoder means 2 forward & 2 backward cells.
60 |
61 | decoder_num_layers: the number of RNN cells to stack in the decoder.
62 |             The encoder state can only be passed in to the decoder as its initial state if this value
63 | is the same as encoder_num_layers.
64 |
65 | encoder_embedding_size: the number of dimensions for each vector in the encoder embedding matrix.
66 | This matrix will be shaped (input_vocabulary.size(), encoder_embedding_size)
67 |
68 | decoder_embedding_size: the number of dimensions for each vector in the decoder embedding matrix.
69 | This matrix will be shaped (output_vocabulary.size(), decoder_embedding_size)
70 |
71 | encoder_embedding_trainable: True to allow gradient updates to be applied to the encoder embedding matrix.
72 | False to freeze the embedding matrix and only train the encoder & decoder RNNs, enabling greater training
73 | efficiency when loading pre-trained embeddings such as Word2Vec.
74 |
75 | decoder_embedding_trainable: True to allow gradient updates to be applied to the decoder embedding matrix.
76 | False to freeze the embedding matrix and only train the encoder & decoder RNNs, enabling greater training
77 | efficiency when loading pre-trained embeddings such as Word2Vec.
78 |
79 | share_embedding: True to reuse the same embedding matrix for the encoder and decoder.
80 | If the vocabulary is identical between input questions and output answers (as in a chatbot), then this should be True.
81 | If the vocabulary is different between input questions and output answers (as in a domain-specific Q&A system), then this should be False.
82 | If True -
83 | 1) input_vocabulary.size() & output_vocabulary.size() must have the same value
84 | 2) encoder_embedding_size & decoder_embedding_size must have the same value
85 | 3) encoder_embedding_trainable & decoder_embedding_trainable must have the same value
86 | 4) If loading pre-trained embeddings, --encoderembeddingsdir & --decoderembeddingsdir args
87 | must be supplied with the same value (or --embeddingsdir can be used instead)
88 |             If any of the above conditions is not met, an error is raised.
89 |
90 | attention_type: Type of attention mechanism to use.
91 | ("bahdanau", "normed_bahdanau", "luong", "scaled_luong")
92 |
93 | beam_width: If mode is "infer", the number of beams to generate with the BeamSearchDecoder.
94 | Set to 0 for greedy / sampling decoding.
95 | This value is ignored if mode is "train".
96 | NOTE: this parameter should ideally be in InferenceHparams instead of ModelHparams, but is here for now
97 | because the graph of the model physically changes based on the beam width.
98 |
99 | enable_sampling: If True while beam_width = 0, the sampling decoder is used instead of the greedy decoder.
100 |
101 | optimizer: Type of optimizer to use when training.
102 | ("sgd", "adam")
103 | NOTE: this parameter should ideally be in TrainingHparams instead of ModelHparams, but is here for now
104 | because the graph of the model physically changes based on which optimizer is used.
105 |
106 | max_gradient_norm: max value to clip the gradients if gradient clipping is enabled.
107 | Set to 0 to disable gradient clipping. Defaults to 5.
108 | This value is ignored if mode is "infer".
109 | NOTE: this parameter should ideally be in TrainingHparams instead of ModelHparams, but is here for now
110 | because the graph of the model physically changes based on whether or not gradient clipping is used.
111 |
112 | gpu_dynamic_memory_growth: Configures the TensorFlow session to only allocate GPU memory as needed,
113 |             instead of the default behavior of trying to aggressively allocate as much memory as possible.
114 | Defaults to True.
115 | """
116 | def __init__(self):
117 | """Initializes the ModelHparams instance.
118 | """
119 | self.rnn_cell_type = "lstm"
120 |
121 | self.rnn_size = 256
122 |
123 | self.use_bidirectional_encoder = True
124 |
125 | self.encoder_num_layers = 2
126 |
127 | self.decoder_num_layers = 2
128 |
129 | self.encoder_embedding_size = 256
130 |
131 | self.decoder_embedding_size = 256
132 |
133 | self.encoder_embedding_trainable = True
134 |
135 | self.decoder_embedding_trainable = True
136 |
137 | self.share_embedding = True
138 |
139 | self.attention_type = "normed_bahdanau"
140 |
141 | self.beam_width = 10
142 |
143 | self.enable_sampling = False
144 |
145 | self.optimizer = "adam"
146 |
147 | self.max_gradient_norm = 5.
148 |
149 | self.gpu_dynamic_memory_growth = True
150 |
151 | class TrainingHparams(object):
152 | """Hyperparameters used when training the chatbot model.
153 |
154 | Members:
155 | min_question_words: minimum length (in words) for a question.
156 |             set this to a higher number if you wish to exclude shorter questions, which
157 |             can sometimes lead to higher training error.
158 |
159 | max_question_answer_words: maximum length (in words) for a question or answer.
160 | any questions or answers longer than this are truncated to fit. The higher this number, the more
161 | timesteps the encoder RNN will need to be unrolled.
162 |
163 | max_conversations: number of conversations to use from the cornell dataset. Specify -1 for no limit.
164 | pick a lower limit if training on the whole dataset is too slow (for lower-end GPUs)
165 |
166 | conv_history_length: number of conversation steps to prepend every question.
167 | For example, a length of 2 would output:
168 | "hello how are you ? i am fine thank you how is the new job?"
169 | where "how is the new job?" is the question and the rest is the prepended conversation history.
170 | the intent is to let the attention mechanism be able to pick up context clues from earlier in the
171 | conversation in order to determine the best way to respond.
172 | pick a lower limit if training is too slow or causes out of memory errors. The higher this number,
173 | the more timesteps the encoder RNN will need to be unrolled.
174 |
175 | normalize_words: True to preprocess the words in the training dataset by replacing word contractions
176 | with their full forms (e.g. i'm -> i am) and then stripping out any remaining apostrophes.
177 |
178 | input_vocab_threshold: the minimum number of times a word must appear in the questions in order to be included
179 | in the vocabulary embedding. Any words that are not included in the vocabulary
180 |             get replaced with an <OUT> token before training and inference.
181 |             if model_hparams.share_embedding = True, this must equal output_vocab_threshold.
182 |
183 | output_vocab_threshold: the minimum number of times a word must appear in the answers in order to be included
184 | For more info see input_vocab_threshold.
185 |             if model_hparams.share_embedding = True, this must equal input_vocab_threshold.
186 |
187 | input_vocab_import_normalized: True to normalize external word vocabularies and embeddings before import as input vocabulary.
188 |             In this context, normalization means converting all word tokens to lower case and then averaging the embedding vectors for any duplicate words.
189 | For example, "JOHN", "John", and "john" will be converted to "john" and it will take the mean of all three embedding vectors.
190 |
191 | output_vocab_import_normalized: True to normalize external word vocabularies and embeddings before import as output vocabulary.
192 | For more info see input_vocab_import_normalized.
193 |
194 | input_vocab_import_mode: Mode to govern how external vocabularies and embeddings are imported as input vocabulary.
195 | Ignored if no external vocabulary specified.
196 | See VocabularyImportMode.
197 |
198 | output_vocab_import_mode: Mode to govern how external vocabularies and embeddings are imported as output vocabulary.
199 | Ignored if no external vocabulary specified.
200 | See VocabularyImportMode.
201 | This should be set to 'Dataset' or 'ExternalIntersectDataset' for large vocabularies, since the size of the
202 | decoder output layer is the vocabulary size. For example, an external embedding may have a vocabulary size
203 | of 1 million, but only 30k words appear in the dataset and having an output layer of 30k dimensions is
204 | much more efficient than an output layer of 1m dimensions.
205 |
206 | validation_set_percent: the percentage of the training dataset to use as the validation set.
207 |
208 | random_train_val_split:
209 | True to split the dataset randomly.
210 | False to split the dataset sequentially
211 | (validation samples are the last N samples, where N = samples * (val_percent / 100))
212 |
213 | validation_metric: the metric to use to measure the model during validation.
214 | "loss" - cross-entropy loss between predictions and targets
215 | "accuracy" (coming soon)
216 | "bleu" (coming soon)
217 |
218 | epochs: Number of epochs to train (1 epoch = all samples in dataset)
219 |
220 | early_stopping_epochs: stop early if no improvement in the validation metric
221 | after training for the given number of epochs in a row.
222 |
223 | batch_size: Training batch size
224 |
225 | learning_rate: learning rate used by SGD.
226 |
227 | learning_rate_decay: rate at which the learning rate drops.
228 | for each epoch, current_lr = starting_lr * (decay) ^ (epoch - 1)
229 |
230 | min_learning_rate: lowest value that the learning rate can go.
231 |
232 | dropout: probability that any neuron will be temporarily disabled during any training iteration.
233 | this is a regularization technique that helps the model learn more independent correlations in the data
234 | and can reduce overfitting.
235 |
236 | checkpoint_on_training: Write a checkpoint after an epoch if the training loss improved.
237 |
238 | checkpoint_on_validation: Write a checkpoint after an epoch if the validation metric improved.
239 |
240 | log_summary: True to log training stats & graph for visualization in tensorboard.
241 |
242 | log_cleaned_dataset: True to save a copy of the cleaned dataset to disk for debugging purposes before training begins.
243 |
244 | log_training_data: True to save a copy of the training question-answer pairs as represented by their vocabularies to disk.
245 |             this is useful to see how frequently words are replaced by <OUT> and also how dialog context is prepended to questions.
246 |
247 | stats_after_n_batches: Output training statistics (loss, time, etc.) after every N batches.
248 |
249 | backup_on_training_loss: List of training loss values upon which to backup the model
250 | Backups are full copies of the latest checkpoint files to another directory, also including vocab and hparam files.
251 | """
252 | def __init__(self):
253 | """Initializes the TrainingHparams instance.
254 | """
255 | self.min_question_words = 1
256 |
257 | self.max_question_answer_words = 30
258 |
259 | self.max_conversations = -1
260 |
261 | self.conv_history_length = 6
262 |
263 | self.normalize_words = True
264 |
265 | self.input_vocab_threshold = 2
266 |
267 | self.output_vocab_threshold = 2
268 |
269 | self.input_vocab_import_normalized = True
270 |
271 | self.output_vocab_import_normalized = True
272 |
273 | self.input_vocab_import_mode = VocabularyImportMode.External
274 |
275 | self.output_vocab_import_mode = VocabularyImportMode.Dataset
276 |
277 | self.validation_set_percent = 0
278 |
279 | self.random_train_val_split = True
280 |
281 | self.validation_metric = "loss"
282 |
283 | self.epochs = 500
284 |
285 | self.early_stopping_epochs = 500
286 |
287 | self.batch_size = 128
288 |
289 | self.learning_rate = 2.0
290 |
291 | self.learning_rate_decay = 0.99
292 |
293 | self.min_learning_rate = 0.1
294 |
295 | self.dropout = 0.2
296 |
297 | self.checkpoint_on_training = True
298 |
299 | self.checkpoint_on_validation = True
300 |
301 | self.log_summary = True
302 |
303 | self.log_cleaned_dataset = True
304 |
305 | self.log_training_data = True
306 |
307 | self.stats_after_n_batches = 100
308 |
309 | self.backup_on_training_loss = []
310 |
311 | class InferenceHparams(object):
312 | """Hyperparameters used when chatting with the chatbot model (a.k.a prediction or inference).
313 |
314 | Members:
315 | beam_length_penalty_weight: higher values mean longer beams are scored better
316 | while lower (or negative) values mean shorter beams are scored better.
317 | Ignored if beam_width = 0
318 |
319 | sampling_temperature: This value sets the softmax temperature of the sampling decoder, if enabled.
320 |
321 | max_answer_words: Max length (in words) for an answer.
322 |
323 | conv_history_length: number of conversation steps to prepend every question.
324 | This can be different from the value used during training.
325 |
326 | normalize_words: True to preprocess the words in the input question by replacing word contractions
327 | with their full forms (e.g. i'm -> i am) and then stripping out any remaining apostrophes.
328 |
329 | log_summary: True to log attention alignment images and inference graph for visualization in tensorboard.
330 |
331 | log_chat: True to log conversation history (chatlog) to a file.
332 | """
333 | def __init__(self):
334 | """Initializes the InferenceHparams instance.
335 | """
336 | self.beam_length_penalty_weight = 1.25
337 |
338 | self.sampling_temperature = 0.5
339 |
340 | self.max_answer_words = 100
341 |
342 | self.conv_history_length = 6
343 |
344 | self.normalize_words = True
345 |
346 | self.log_summary = True
347 |
348 | self.log_chat = True
349 |
350 |
--------------------------------------------------------------------------------
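As a worked example of the learning rate schedule documented in the TrainingHparams docstring above (current_lr = starting_lr * decay ^ (epoch - 1), floored at min_learning_rate), here is a small sketch using the values from hparams.json:

```python
# Learning rate decay schedule from the TrainingHparams docstring, using the
# hparams.json values: starting_lr=2.0, decay=0.9975, min_learning_rate=0.1.
starting_lr, decay, min_lr = 2.0, 0.9975, 0.1

for epoch in (1, 100, 500):
    current_lr = max(starting_lr * decay ** (epoch - 1), min_lr)
    print(epoch, round(current_lr, 3))
# 1 2.0
# 100 1.561 (approximately)
# 500 0.573 (approximately)
```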
/models/cornell_movie_dialog/README.md:
--------------------------------------------------------------------------------
1 | # Trained Models for Cornell Movie Dialog Corpus
2 |
3 | ## Trained Model v2 (Jul 21 2018)
4 | This model was initialized using the [nnlm_en](../../embeddings/nnlm_en/README.md) pre-trained word embedding vectors. Embeddings were updated during training.
5 |
6 | You can download it from [here](https://drive.google.com/uc?id=1y1b1vXeSti5lpBACNdYlo8HbVJDUO3ir&export=download).
7 |
8 | After download, unzip the folder **trained_model_v2** into this directory.
9 |
10 | ## Trained Model v1 (Mar 29 2018)
11 | You can download it from [here](https://drive.google.com/uc?id=1Ig-sgdka5QpgE-b9g4ZQqnGGCrE-2f4p&export=download).
12 |
13 | After download, unzip the folder **trained_model_v1** into this directory.
14 |
15 | To chat with the trained cornell movie dialog model **trained_model_v2**:
16 |
17 | 1. Download and unzip **trained_model_v2** (download link above) into this directory (**models/cornell_movie_dialog**)
18 |
19 | 2. Set console working directory to the **seq2seq-chatbot** directory
20 |
21 | 3. Run:
22 | ```shell
23 | run chat.py models\cornell_movie_dialog\trained_model_v2\best_weights_training.ckpt
24 | ```
25 |
--------------------------------------------------------------------------------
/roadmap.md:
--------------------------------------------------------------------------------
1 | # Roadmap
2 | For bugs and other issues see the repository issues on GitHub.
3 |
4 | ## General Improvements
5 | - Load training data from a file instead of storing it in memory to allow for quick resume of training without needing to regenerate the dataset from the source corpus.
6 |
7 | - Read training data from a file in chunks to allow for datasets beyond the size of available RAM.
8 |
9 | - Output attention alignments correctly in tensorboard summary logs.
10 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for training the chatbot model
3 | """
4 | import time
5 | import math
6 | from os import path
7 | from shutil import copytree
8 |
9 | import general_utils
10 | import train_console_helper
11 | from dataset_readers import dataset_reader_factory
12 | from vocabulary_importers import vocabulary_importer_factory
13 | from vocabulary import Vocabulary
14 | from chatbot_model import ChatbotModel
15 | from training_stats import TrainingStats
16 |
17 | #Read the hyperparameters and paths
18 | dataset_dir, model_dir, hparams, resume_checkpoint, encoder_embeddings_dir, decoder_embeddings_dir = general_utils.initialize_session("train")
19 | training_stats_filepath = path.join(model_dir, "training_stats.json")
20 |
21 | #Read the chatbot dataset and generate / import the vocabulary
22 | dataset_reader = dataset_reader_factory.get_dataset_reader(dataset_dir)
23 |
24 | print()
25 | print("Reading dataset '{0}'...".format(dataset_reader.dataset_name))
26 | dataset, dataset_read_stats = dataset_reader.read_dataset(dataset_dir = dataset_dir,
27 | model_dir = model_dir,
28 | training_hparams = hparams.training_hparams,
29 | share_vocab = hparams.model_hparams.share_embedding,
30 | encoder_embeddings_dir = encoder_embeddings_dir,
31 | decoder_embeddings_dir = decoder_embeddings_dir)
32 | if encoder_embeddings_dir is not None:
33 | print()
34 | print("Imported {0} vocab '{1}'...".format("shared" if hparams.model_hparams.share_embedding else "input", encoder_embeddings_dir))
35 | train_console_helper.write_vocabulary_import_stats(dataset_read_stats.input_vocabulary_import_stats)
36 |
37 | if decoder_embeddings_dir is not None and not hparams.model_hparams.share_embedding:
38 | print()
39 | print("Imported output vocab '{0}'...".format(decoder_embeddings_dir))
40 | train_console_helper.write_vocabulary_import_stats(dataset_read_stats.output_vocabulary_import_stats)
41 |
42 | print()
43 | print("Final {0} vocab size: {1}".format("shared" if hparams.model_hparams.share_embedding else "input", dataset.input_vocabulary.size()))
44 | if not hparams.model_hparams.share_embedding:
45 | print("Final output vocab size: {0}".format(dataset.output_vocabulary.size()))
46 |
47 | #Split the chatbot dataset into training & validation datasets
48 | print()
49 | print("Splitting {0} samples into training & validation sets ({1}% used for validation)..."
50 | .format(dataset.size(), hparams.training_hparams.validation_set_percent))
51 |
52 | training_dataset, validation_dataset = dataset.train_val_split(val_percent = hparams.training_hparams.validation_set_percent,
53 | random_split = hparams.training_hparams.random_train_val_split)
54 | training_dataset_size = training_dataset.size()
55 | validation_dataset_size = validation_dataset.size()
56 | print("Training set: {0} samples. Validation set: {1} samples."
57 | .format(training_dataset_size, validation_dataset_size))
58 |
59 | print("Sorting training & validation sets to increase training efficiency...")
60 | training_dataset.sort()
61 | validation_dataset.sort()
62 |
63 | #Log the final training dataset if configured to do so
64 | if hparams.training_hparams.log_training_data:
65 | training_data_log_filepath = path.join(model_dir, "training_data.txt")
66 | training_dataset.save(training_data_log_filepath)
67 |
68 | #Create the model
69 | print("Initializing model...")
70 | print()
71 | with ChatbotModel(mode = "train",
72 | model_hparams = hparams.model_hparams,
73 | input_vocabulary = dataset.input_vocabulary,
74 | output_vocabulary = dataset.output_vocabulary,
75 | model_dir = model_dir) as model:
76 |
77 | print()
78 |
79 | #Restore from checkpoint if specified
80 | best_train_checkpoint = "best_weights_training.ckpt"
81 | best_val_checkpoint = "best_weights_validation.ckpt"
82 | training_stats = TrainingStats(hparams.training_hparams)
83 | if resume_checkpoint is not None:
84 | print("Resuming training from checkpoint {0}...".format(resume_checkpoint))
85 | model.load(resume_checkpoint)
86 | training_stats.load(training_stats_filepath)
87 | else:
88 | print("Creating checkpoint batch files...")
89 | general_utils.create_batch_files(model_dir,
90 | best_train_checkpoint if hparams.training_hparams.checkpoint_on_training else None,
91 | best_val_checkpoint if hparams.training_hparams.checkpoint_on_validation else None,
92 | encoder_embeddings_dir,
93 | decoder_embeddings_dir)
94 |
95 | print("Initializing training...")
96 |
97 | print("Epochs: {0}".format(hparams.training_hparams.epochs))
98 | print("Batch Size: {0}".format(hparams.training_hparams.batch_size))
99 | print("Optimizer: {0}".format(hparams.model_hparams.optimizer))
100 |
101 | backup_on_training_loss = sorted(hparams.training_hparams.backup_on_training_loss.copy(), reverse=True)
102 |
103 | #Train on all batches in epoch
104 | for epoch in range(1, hparams.training_hparams.epochs + 1):
105 | batch_counter = 0
106 | batches_starting_time = time.time()
107 | batches_total_train_loss = 0
108 | epoch_starting_time = time.time()
109 | epoch_total_train_loss = 0
110 | train_batches = training_dataset.batches(hparams.training_hparams.batch_size)
111 | for batch_index, (questions, answers, seqlen_questions, seqlen_answers) in enumerate(train_batches):
112 | batch_train_loss = model.train_batch(inputs = questions,
113 | targets = answers,
114 | input_sequence_length = seqlen_questions,
115 | target_sequence_length = seqlen_answers,
116 | learning_rate = training_stats.learning_rate,
117 | dropout = hparams.training_hparams.dropout,
118 | global_step = training_stats.global_step,
119 | log_summary = hparams.training_hparams.log_summary)
120 | batches_total_train_loss += batch_train_loss
121 | epoch_total_train_loss += batch_train_loss
122 | batch_counter += 1
123 | training_stats.global_step += 1
124 | if batch_counter == hparams.training_hparams.stats_after_n_batches or batch_index == (training_dataset_size // hparams.training_hparams.batch_size):
125 | batches_average_train_loss = batches_total_train_loss / batch_counter
126 | epoch_average_train_loss = epoch_total_train_loss / (batch_index + 1)
127 | print('Epoch: {:>3}/{}, Batch: {:>4}/{}, Stats for last {} batches: (Training Loss: {:>6.3f}, Training Time: {:d} seconds), Stats for epoch: (Training Loss: {:>6.3f}, Training Time: {:d} seconds)'.format(
128 | epoch,
129 | hparams.training_hparams.epochs,
130 | batch_index + 1,
131 | math.ceil(training_dataset_size / hparams.training_hparams.batch_size),
132 | batch_counter,
133 | batches_average_train_loss,
134 | int(time.time() - batches_starting_time),
135 | epoch_average_train_loss,
136 | int(time.time() - epoch_starting_time)))
137 | batches_total_train_loss = 0
138 | batch_counter = 0
139 | batches_starting_time = time.time()
140 |
141 | #End of epoch activities
142 | #Run validation
143 | if validation_dataset_size > 0:
144 | total_val_metric_value = 0
145 | batches_starting_time = time.time()
146 | val_batches = validation_dataset.batches(hparams.training_hparams.batch_size)
147 | for batch_index_validation, (questions, answers, seqlen_questions, seqlen_answers) in enumerate(val_batches):
148 | batch_val_metric_value = model.validate_batch(inputs = questions,
149 | targets = answers,
150 | input_sequence_length = seqlen_questions,
151 | target_sequence_length = seqlen_answers,
152 | metric = hparams.training_hparams.validation_metric)
153 | total_val_metric_value += batch_val_metric_value
154 | average_val_metric_value = total_val_metric_value / math.ceil(validation_dataset_size / hparams.training_hparams.batch_size)
155 | print('Epoch: {:>3}/{}, Validation {}: {:>6.3f}, Batch Validation Time: {:d} seconds'.format(
156 | epoch,
157 | hparams.training_hparams.epochs,
158 | hparams.training_hparams.validation_metric,
159 | average_val_metric_value,
160 | int(time.time() - batches_starting_time)))
161 |
162 | #Apply learning rate decay
163 | if hparams.training_hparams.learning_rate_decay > 0:
164 | prev_learning_rate, learning_rate = training_stats.decay_learning_rate()
165 | print('Learning rate decay: adjusting from {:>6.3f} to {:>6.3f}'.format(prev_learning_rate, learning_rate))
166 |
167 | #Checkpoint - training
168 | if training_stats.compare_training_loss(epoch_average_train_loss):
169 | if hparams.training_hparams.checkpoint_on_training:
170 | model.save(best_train_checkpoint)
171 | training_stats.save(training_stats_filepath)
172 | print('Training loss improved!')
173 |
174 | #Checkpoint - validation
175 | if validation_dataset_size > 0:
176 | if training_stats.compare_validation_metric(average_val_metric_value):
177 | if hparams.training_hparams.checkpoint_on_validation:
178 | model.save(best_val_checkpoint)
179 | training_stats.save(training_stats_filepath)
180 | print('Validation {0} improved!'.format(hparams.training_hparams.validation_metric))
181 | else:
182 | if training_stats.early_stopping_check == hparams.training_hparams.early_stopping_epochs:
183 | print("Early stopping checkpoint reached - validation loss has not improved in {0} epochs. Terminating training...".format(hparams.training_hparams.early_stopping_epochs))
184 | break
185 |
186 | #Backup
187 | do_backup = False
188 | while len(backup_on_training_loss) > 0 and epoch_average_train_loss <= backup_on_training_loss[0]:
189 | backup_on_training_loss.pop(0)
190 | do_backup = True
191 | if do_backup:
192 | backup_dir = "{0}_backup_{1}".format(model_dir, "{:0.3f}".format(epoch_average_train_loss).replace(".", "_"))
193 | copytree(model_dir, backup_dir)
194 | general_utils.create_batch_files(backup_dir,
195 | best_train_checkpoint if hparams.training_hparams.checkpoint_on_training else None,
196 | best_val_checkpoint if hparams.training_hparams.checkpoint_on_validation else None,
197 | encoder_embeddings_dir,
198 | decoder_embeddings_dir)
199 | print('Backup to {0} complete!'.format(backup_dir))
200 |
201 | #Training is complete... if no checkpointing was turned on, save the final model state
202 | if not hparams.training_hparams.checkpoint_on_training and not hparams.training_hparams.checkpoint_on_validation:
203 | model.save(best_train_checkpoint)
204 | model.save(best_val_checkpoint)
205 | training_stats.save(training_stats_filepath)
206 | print('Model saved.')
207 | print("Training Complete!")
208 |
--------------------------------------------------------------------------------
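The backup block near the end of the training loop above pops every threshold that the epoch loss has crossed and then takes at most one backup. A small worked example with a hypothetical loss value:

```python
# Thresholds are sorted descending; one epoch that crosses several of them
# still triggers only a single backup.
backup_on_training_loss = sorted([3.0, 2.5, 2.0, 1.5, 1.0, 0.5], reverse=True)

epoch_average_train_loss = 1.8  # hypothetical epoch loss
do_backup = False
while len(backup_on_training_loss) > 0 and epoch_average_train_loss <= backup_on_training_loss[0]:
    backup_on_training_loss.pop(0)
    do_backup = True

print(do_backup)               # True (thresholds 3.0, 2.5, and 2.0 were consumed)
print(backup_on_training_loss) # [1.5, 1.0, 0.5]
```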
/train_console_helper.py:
--------------------------------------------------------------------------------
1 | """
2 | Console helper for training session
3 | """
4 |
5 | def write_vocabulary_import_stats(vocabulary_import_stats):
6 | print(" Stats:")
7 | print(" External vocab size: {0}".format(vocabulary_import_stats.external_vocabulary_size))
8 | if vocabulary_import_stats.dataset_vocabulary_size is not None:
9 | print(" Dataset vocab size: {0}".format(vocabulary_import_stats.dataset_vocabulary_size))
10 | if vocabulary_import_stats.intersection_size is not None:
11 | print(" Intersection size: {0}".format(vocabulary_import_stats.intersection_size))
--------------------------------------------------------------------------------
/training_stats.py:
--------------------------------------------------------------------------------
1 | """
2 | TrainingStats class
3 | """
4 | import jsonpickle
5 |
6 | class TrainingStats(object):
7 | """Class that contains a set of metrics & stats that represent a model at a point in time.
8 | """
9 |
10 | def __init__(self, training_hparams):
11 | """Initializes the TrainingStats instance.
12 |
13 | Args:
14 | training_hparams: the training hyperparameters.
15 | """
16 | self.training_hparams = training_hparams
17 | self.best_validation_metric_value = self._get_metric_baseline(self.training_hparams.validation_metric)
18 | self.best_training_loss = self._get_metric_baseline("loss")
19 | self.learning_rate = self.training_hparams.learning_rate
20 | self.early_stopping_check = 0
21 | self.global_step = 0
22 |
23 | def __getstate__(self):
24 | state = self.__dict__.copy()
25 | del state["training_hparams"]
26 | state["best_validation_metric_value"] = float(self.best_validation_metric_value)
27 | state["best_training_loss"] = float(self.best_training_loss)
28 | return state
29 |
30 | def __setstate__(self, state):
31 | self.__dict__.update(state)
32 |
33 | def compare_training_loss(self, new_value):
34 | """Compare the best training loss against a new value.
35 |
36 | Args:
37 | new_value: the new training loss value to compare against.
38 |
39 | Returns:
40 | True if new_value is better than the best training loss.
41 | False if the best training loss is better than or equal to new_value
42 | """
43 | if self._compare_metric("loss", self.best_training_loss, new_value):
44 | self.best_training_loss = new_value
45 | return True
46 | else:
47 | return False
48 |
49 | def compare_validation_metric(self, new_value):
50 | """Compare the best validation metric value against a new value.
51 | Validation metric is specified in training_hparams.
52 |
53 | Args:
54 | new_value: the new validation metric value to compare against.
55 |
56 | Returns:
57 | True if new_value is better than the best validation metric value.
58 | False if the best validation metric value is better than or equal to new_value
59 | """
60 | if self._compare_metric(self.training_hparams.validation_metric, self.best_validation_metric_value, new_value):
61 | self.best_validation_metric_value = new_value
62 | self.early_stopping_check = 0
63 | return True
64 | else:
65 | self.early_stopping_check += 1
66 | return False
67 |
68 | def decay_learning_rate(self):
69 | """Multiply the current learning rate by the decay coefficient specified in training_hparams.
70 |
71 | If the learning rate falls below the minimum learning rate, it is set to the minimum.
72 | """
73 | prev_learning_rate = self.learning_rate
74 | self.learning_rate *= self.training_hparams.learning_rate_decay
75 | if self.learning_rate < self.training_hparams.min_learning_rate:
76 | self.learning_rate = self.training_hparams.min_learning_rate
77 | return prev_learning_rate, self.learning_rate
78 |
79 | def save(self, filepath):
80 | """Saves the TrainingStats to disk.
81 |
82 | Args:
83 | filepath: The path of the file to save to
84 | """
85 | json = jsonpickle.encode(self)
86 | with open(filepath, "w") as file:
87 | file.write(json)
88 |
89 | def load(self, filepath):
90 | """Loads the TrainingStats from a JSON file.
91 |
92 | Args:
93 | filepath: path of the JSON file.
94 | """
95 | with open(filepath) as file:
96 | json = file.read()
97 | training_stats = jsonpickle.decode(json)
98 | self.best_validation_metric_value = training_stats.best_validation_metric_value
99 | self.best_training_loss = training_stats.best_training_loss
100 | self.learning_rate = training_stats.learning_rate
101 | self.early_stopping_check = training_stats.early_stopping_check
102 | self.global_step = training_stats.global_step
103 |
104 |
105 | def _compare_metric(self, metric, previous_value, new_value):
106 | """Compare a new metric value with its previous known value and determine which value is better.
107 |
108 | Which value is better is specific to the metric.
109 | For instance, loss is a lower-is-better metric while accuracy is a higher-is-better metric.
110 |
111 | Args:
112 | metric: The metric being compared
113 |
114 | previous_value: The previous known value for the metric.
115 |
116 | new_value: The new value to compare against the previous value.
117 |
118 | Returns:
119 | True if new_value is better than previous_value
120 | False if previous_value is better than or equal to new_value
121 | """
122 | if metric == "loss":
123 | return new_value < previous_value
124 | else:
125 | raise ValueError("Unsupported metric: '{0}'".format(metric))
126 |
127 | def _get_metric_baseline(self, metric):
128 | """Gets a baseline value for a metric that can be used to compare the first measurement against.
129 |
130 | For lower-is-better metrics such as loss, this will be a very large number (99999)
131 |
132 | For higher-is-better metrics such as accuracy, this will be 0.
133 |
134 | Args:
135 | metric: The metric for which to get a baseline value
136 | """
137 | if metric == "loss":
138 | return 99999
139 | else:
140 | raise ValueError("Unsupported metric: '{0}'".format(metric))
--------------------------------------------------------------------------------
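A brief usage sketch of TrainingStats, assuming the default TrainingHparams values defined in hparams.py above (learning_rate=2.0, learning_rate_decay=0.99):

```python
from hparams import TrainingHparams
from training_stats import TrainingStats

stats = TrainingStats(TrainingHparams())

# The loss baseline is 99999, so the first measurement always "improves".
assert stats.compare_training_loss(3.2)      # True; best loss is now 3.2
assert not stats.compare_training_loss(3.5)  # 3.5 is worse than 3.2

prev_lr, new_lr = stats.decay_learning_rate()
print(prev_lr, new_lr)  # 2.0 1.98
```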
/vocabulary.py:
--------------------------------------------------------------------------------
1 | """
2 | Vocabulary class
3 | """
4 | import re
5 |
6 | class Vocabulary(object):
7 | """Class representing a chatbot vocabulary.
8 |
9 | The Vocabulary class is responsible for encoding words into integers and decoding integers into words.
10 | The number of times each word occurs in the source corpus is also tracked for visualization purposes.
11 |
12 | Special tokens that exist in every vocabulary instance:
13 |     - PAD ("<PAD>"): The token used for extra sequence timesteps in a batch
14 |     - SOS ("<SOS>"): Start Of Sequence token is used as the input of the first decoder timestep
15 |     - EOS ("<EOS>"): End Of Sequence token is used to signal that the decoder should stop generating a sequence.
16 |         It is also used to separate conversation history (context) questions prepended to the current input question.
17 |     - OUT ("<OUT>"): If a word does not exist in the vocabulary, it is substituted with this token.
18 | """
19 |
20 | SHARED_VOCAB_FILENAME = "shared_vocab.tsv"
21 | INPUT_VOCAB_FILENAME = "input_vocab.tsv"
22 | OUTPUT_VOCAB_FILENAME = "output_vocab.tsv"
23 |
24 |     PAD = "<PAD>"
25 |     SOS = "<SOS>"
26 |     EOS = "<EOS>"
27 |     OUT = "<OUT>"
28 | special_tokens = [PAD, SOS, EOS, OUT]
29 |
30 | def __init__(self, external_embeddings = None):
31 | """Initializes the Vocabulary instance in an non-compiled state.
32 | Compile must be called before the Vocab instance can be used to integer encode/decode words.
33 |
34 | Args:
35 | external_embeddings: An optional 2d numpy array (matrix) containing external embedding vectors
36 | """
37 | self._word2count = {}
38 | self._words2int = {}
39 | self._ints2word = {}
40 | self._compiled = False
41 | self.external_embeddings = external_embeddings
42 |
43 | def load_word(self, word, word_int, count = 1):
44 | """Load a word and its integer encoding into the vocabulary instance.
45 |
46 | Args:
47 | word: The word to load.
48 |
49 | word_int: The integer encoding of the word to load.
50 |
51 | count: (Optional) The number of times the word occurs in the source corpus.
52 | """
53 | self._validate_compile(False)
54 |
55 | self._word2count[word] = count
56 | self._words2int[word] = word_int
57 | self._ints2word[word_int] = word
58 |
59 | def add_words(self, words):
60 | """Add a sequence of words to the vocabulary instance.
61 | If a word occurs more than once, its count will be incremented accordingly.
62 |
63 | Args:
64 | words: The sequence of words to add.
65 | """
66 | self._validate_compile(False)
67 |
68 | for i in range(len(words)):
69 | word = words[i]
70 | if word in self._word2count:
71 | self._word2count[word] += 1
72 | else:
73 | self._word2count[word] = 1
74 |
75 | def compile(self, vocab_threshold = 1, loading = False):
76 | """Compile the internal lookup dictionaries that enable words to be integer encoded / decoded.
77 |
78 | Args:
79 | vocab_threshold: Minimum number of times any word must appear within word_sequences in order to be included in the vocabulary.
80 | This is useful for filtering out rarely used words in order to reduce the size of the vocabulary
81 | (which consequently reduces the size of the model's embedding matrices & reduces the dimensionality of the output softmax)
82 | This value is ignored if loading is True.
83 |
84 | loading: Indicates if the vocabulary is being loaded from disk, in which case the compilation is already done and this method
85 | only needs to set the flag to indicate as such.
86 |
87 |
88 | """
89 | self._validate_compile(False)
90 |
91 | if not loading:
92 | #Add the special tokens to the lookup dictionaries
93 | for i, special_token in enumerate(Vocabulary.special_tokens):
94 | self._words2int[special_token] = i
95 | self._ints2word[i] = special_token
96 |
97 | #Add the words in _word2count to the lookup dictionaries if their count meets the threshold.
98 | #Any words that don't meet the threshold are removed.
99 | word_int = len(self._words2int)
100 | for word, count in sorted(self._word2count.items()):
101 | if count >= vocab_threshold:
102 | self._words2int[word] = word_int
103 | self._ints2word[word_int] = word
104 | word_int += 1
105 | else:
106 | del self._word2count[word]
107 |
108 | #Add the special tokens to _word2count so they have count values for saving to disk
109 | self.add_words(Vocabulary.special_tokens)
110 |
111 | #The Vocabulary instance may now be used for integer encoding / decoding
112 | self._compiled = True
113 |
114 |
115 |
116 | def size(self):
117 | """The size (number of words) of the Vocabulary
118 | """
119 | self._validate_compile(True)
120 | return len(self._word2count)
121 |
122 | def word_exists(self, word):
123 | """Check if the given word exists in the vocabulary.
124 |
125 | Args:
126 | word: The word to check.
127 | """
128 | self._validate_compile(True)
129 | return word in self._words2int
130 |
131 | def words2ints(self, words):
132 | """Encode a sequence of space delimited words into a sequence of integers
133 |
134 | Args:
135 | words: The sequence of space delimited words to encode
136 | """
137 | return [self.word2int(w) for w in words.split()]
138 |
139 | def word2int(self, word):
140 | """Encode a word into an integer
141 |
142 | Args:
143 | word: The word to encode
144 | """
145 | self._validate_compile(True)
146 | return self._words2int[word] if word in self._words2int else self.out_int()
147 |
148 | def ints2words(self, words_ints, is_punct_discrete_word = False, capitalize_i = True):
149 | """Decode a sequence of integers into a sequence of space delimited words
150 |
151 | Args:
152 | words_ints: The sequence of integers to decode
153 |
154 | is_punct_discrete_word: True to output a space before punctuation
155 |             False to place punctuation immediately after the end of the preceding word (normal usage).
156 | """
157 | words = ""
158 | for i in words_ints:
159 | word = self.int2word(i, capitalize_i)
160 | if is_punct_discrete_word or word not in ['.', '!', '?']:
161 | words += " "
162 | words += word
163 | words = words.strip()
164 | return words
165 |
166 | def int2word(self, word_int, capitalize_i = True):
167 | """Decode an integer into a word
168 |
169 | Args:
170 |           word_int: The integer to decode
171 | """
172 | self._validate_compile(True)
173 | word = self._ints2word[word_int]
174 | if capitalize_i and word == 'i':
175 | word = 'I'
176 | return word
177 |
178 | def pad_int(self):
179 | """Get the integer encoding of the PAD token
180 | """
181 | return self.word2int(Vocabulary.PAD)
182 |
183 | def sos_int(self):
184 | """Get the integer encoding of the SOS token
185 | """
186 | return self.word2int(Vocabulary.SOS)
187 |
188 | def eos_int(self):
189 | """Get the integer encoding of the EOS token
190 | """
191 | return self.word2int(Vocabulary.EOS)
192 |
193 | def out_int(self):
194 | """Get the integer encoding of the OUT token
195 | """
196 | return self.word2int(Vocabulary.OUT)
197 |
198 | def save(self, filepath):
199 | """Saves the vocabulary to disk.
200 |
201 | Args:
202 | filepath: The path of the file to save to
203 | """
204 | total_words = self.size()
205 | with open(filepath, "w", encoding="utf-8") as file:
206 | file.write('\t'.join(["word", "count"]))
207 | file.write('\n')
208 | for i in range(total_words):
209 | word = self._ints2word[i]
210 | count = self._word2count[word]
211 | file.write('\t'.join([word, str(count)]))
212 | if i < total_words - 1:
213 | file.write('\n')
214 |
215 | def _validate_compile(self, expected_status):
216 | """Validate that the vocabulary is compiled or not based on the needs of the attempted operation
217 |
218 | Args:
219 | expected_status: The compilation status expected by the attempted operation
220 | """
221 | if self._compiled and not expected_status:
222 | raise ValueError("This vocabulary instance has already been compiled.")
223 | if not self._compiled and expected_status:
224 | raise ValueError("This vocabulary instance has not been compiled yet.")
225 |
226 | @staticmethod
227 | def load(filepath):
228 | """Loads the vocabulary from disk.
229 |
230 | Args:
231 | filepath: The path of the file to load from
232 | """
233 | vocabulary = Vocabulary()
234 |
235 | with open(filepath, encoding="utf-8") as file:
236 | for index, line in enumerate(file):
237 | if index > 0: #Skip header line
238 | word, count = line.split('\t')
239 | word_int = index - 1
240 | vocabulary.load_word(word, word_int, int(count))
241 |
242 | vocabulary.compile(loading = True)
243 | return vocabulary
244 |
245 | @staticmethod
246 | def clean_text(text, max_words = None, normalize_words = True):
247 | """Clean text to prepare for training and inference.
248 |
249 | Clean by removing unsupported special characters & extra whitespace,
250 |         and by normalizing common word permutations (e.g. can't and cannot both become can not)
251 |
252 | Args:
253 | text: the text to clean
254 |
255 | max_words: maximum number of words to output (assuming words are separated by spaces).
256 | any words beyond this limit are truncated.
257 | Defaults to None (unlimited number of words)
258 |
259 | normalize_words: True to replace word contractions with their full forms (e.g. i'm -> i am)
260 | and then strip out any remaining apostrophes.
261 | """
262 | text = text.lower()
263 | text = re.sub(r"'+", "'", text)
264 | if normalize_words:
265 | text = re.sub(r"i'm", "i am", text)
266 | text = re.sub(r"he's", "he is", text)
267 | text = re.sub(r"she's", "she is", text)
268 | text = re.sub(r"that's", "that is", text)
269 | text = re.sub(r"there's", "there is", text)
270 | text = re.sub(r"what's", "what is", text)
271 | text = re.sub(r"where's", "where is", text)
272 | text = re.sub(r"who's", "who is", text)
273 | text = re.sub(r"how's", "how is", text)
274 | text = re.sub(r"it's", "it is", text)
275 | text = re.sub(r"let's", "let us", text)
276 | text = re.sub(r"\'ll", " will", text)
277 | text = re.sub(r"\'ve", " have", text)
278 | text = re.sub(r"\'re", " are", text)
279 | text = re.sub(r"\'d", " would", text)
280 | text = re.sub(r"won't", "will not", text)
281 | text = re.sub(r"shan't", "shall not", text)
282 | text = re.sub(r"can't", "can not", text)
283 | text = re.sub(r"cannot", "can not", text)
284 | text = re.sub(r"n't", " not", text)
285 | text = re.sub(r"'", "", text)
286 | else:
287 | text = re.sub(r"(\W)'", r"\1", text)
288 | text = re.sub(r"'(\W)", r"\1", text)
289 | text = re.sub(r"[()\"#/@;:<>{}`+=~|$&*%\[\]_]", "", text)
290 | text = re.sub(r"[.]+", " . ", text)
291 | text = re.sub(r"[!]+", " ! ", text)
292 | text = re.sub(r"[?]+", " ? ", text)
293 | text = re.sub(r"[,-]+", " ", text)
294 | text = re.sub(r"[\t]+", " ", text)
295 | text = re.sub(r" +", " ", text)
296 | text = text.strip()
297 |
298 | #Truncate words beyond the limit, if provided. Remove partial sentences from the end if punctuation exists within the limit.
299 | if max_words is not None:
300 | text_parts = text.split()
301 | if len(text_parts) > max_words:
302 | truncated_text_parts = text_parts[:max_words]
303 | while len(truncated_text_parts) > 0 and not re.match("[.!?]", truncated_text_parts[-1]):
304 | truncated_text_parts.pop(-1)
305 | if len(truncated_text_parts) == 0:
306 | truncated_text_parts = text_parts[:max_words]
307 | text = " ".join(truncated_text_parts)
308 |
309 | return text
310 |
311 | @staticmethod
312 | def auto_punctuate(text):
313 | """Automatically apply punctuation to text that does not end with any punctuation marks.
314 |
315 | Args:
316 | text: the text to apply punctuation to.
317 | """
318 | text = text.strip()
319 | if not (text.endswith(".") or text.endswith("?") or text.endswith("!") or text.startswith("--")):
320 | tmp = re.sub(r"'", "", text.lower())
321 | if (tmp.startswith("who") or tmp.startswith("what") or tmp.startswith("when") or
322 | tmp.startswith("where") or tmp.startswith("why") or tmp.startswith("how") or
323 | tmp.endswith("who") or tmp.endswith("what") or tmp.endswith("when") or
324 | tmp.endswith("where") or tmp.endswith("why") or tmp.endswith("how") or
325 | tmp.startswith("are") or tmp.startswith("will") or tmp.startswith("wont") or tmp.startswith("can")):
326 | text = "{}?".format(text)
327 | else:
328 | text = "{}.".format(text)
329 | return text
--------------------------------------------------------------------------------
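Two quick examples of the static text helpers above:

```python
from vocabulary import Vocabulary

# clean_text lowercases, expands contractions (normalize_words=True by default),
# strips unsupported characters, and isolates sentence-ending punctuation:
print(Vocabulary.clean_text("I'm here, don't worry!"))
# -> "i am here do not worry !"

# auto_punctuate appends "?" to question-like input and "." otherwise:
print(Vocabulary.auto_punctuate("how are you"))   # -> "how are you?"
print(Vocabulary.auto_punctuate("see you soon"))  # -> "see you soon."
```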
/vocabulary_importers/__pycache__/vocabulary_importer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunalBhashkar/seq2seq_chatbot_tensorflow/39b01cf99c8d206504619a97254d8fb53d935206/vocabulary_importers/__pycache__/vocabulary_importer.cpython-36.pyc
--------------------------------------------------------------------------------
/vocabulary_importers/checkpoint_vocabulary_importer.py:
--------------------------------------------------------------------------------
1 | """
2 | Base class for TensorFlow Checkpoint vocabulary importers
3 | """
4 | import tensorflow as tf
5 | from collections import OrderedDict
6 | from os import path
7 | from vocabulary_importers.vocabulary_importer import VocabularyImporter
8 |
9 | class CheckpointVocabularyImporter(VocabularyImporter):
10 | """Base class for TensorFlow Checkpoint vocabulary importers
11 | """
12 |
13 | def __init__(self, vocabulary_name, tokens_filename, embeddings_variable_name):
14 |         """Initialize the CheckpointVocabularyImporter.
15 |
16 |         Args:
17 |             vocabulary_name: See base class
18 |
19 |             tokens_filename: Name of the file containing the token/word list. Subclass must pass this in.
20 |
21 |             embeddings_variable_name: Name of the variable to read out of the checkpoint. Subclass must pass this in.
22 |         """
23 |         super(CheckpointVocabularyImporter, self).__init__(vocabulary_name)
24 |
25 | self.tokens_filename = tokens_filename
26 |
27 | self.embeddings_variable_name = embeddings_variable_name
28 |
29 | def _read_vocabulary_and_embeddings(self, vocabulary_dir):
30 | """Read the raw vocabulary file(s) and return the tokens list with corresponding word vectors
31 |
32 | Args:
33 | vocabulary_dir: See base class
34 | """
35 |
36 | #Import embeddings
37 | tf.reset_default_graph()
38 | embeddings = tf.Variable(tf.contrib.framework.load_variable(vocabulary_dir, self.embeddings_variable_name), name = "embeddings")
39 | with tf.Session() as sess:
40 | sess.run(tf.global_variables_initializer())
41 | loaded_embeddings_matrix = sess.run(embeddings)
42 |
43 | #Import vocabulary
44 | tokens_filepath = path.join(vocabulary_dir, self.tokens_filename)
45 | tokens_with_embeddings = OrderedDict()
46 | with open(tokens_filepath, encoding="utf-8") as file:
47 | for index, line in enumerate(file):
48 | token = line.strip()
49 | if token != "":
50 | token = self._process_token(token)
51 | tokens_with_embeddings[token] = loaded_embeddings_matrix[index]
52 |
53 | return tokens_with_embeddings
--------------------------------------------------------------------------------
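To find the embeddings_variable_name a subclass should pass in, the checkpoint's variables can be listed first. A sketch assuming TensorFlow 1.x and a hypothetical checkpoint directory:

```python
import tensorflow as tf  # TF 1.x, as used throughout this repo

# List every variable stored in a checkpoint to discover the name to pass
# as embeddings_variable_name (the directory path here is hypothetical).
for name, shape in tf.train.list_variables("embeddings/nnlm_en"):
    print(name, shape)
```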
/vocabulary_importers/dependency_based_vocabulary_importer.py:
--------------------------------------------------------------------------------
1 | """
2 | Importer class for the Dependency-Based vocabulary (Levy & Goldberg, 2014)
3 | https://levyomer.wordpress.com/2014/04/25/dependency-based-word-embeddings/
4 | https://levyomer.files.wordpress.com/2014/04/dependency-based-word-embeddings-acl-2014.pdf
5 | """
6 | from os import path
7 |
8 | from vocabulary_importers.flatfile_vocabulary_importer import FlatFileVocabularyImporter
9 | from vocabulary import Vocabulary
10 |
11 | class DependencyBasedVocabularyImporter(FlatFileVocabularyImporter):
12 | """Importer implementation for the Dependency-Based vocabulary
13 | """
14 | def __init__(self):
15 | super(DependencyBasedVocabularyImporter, self).__init__("dependency_based", "deps.words", " ")
16 |
17 | def _process_token(self, token):
18 | """Perform token preprocessing (See base class for explanation)
19 |
20 | Args:
21 | See base class
22 |
23 | Returns:
24 | See base class
25 | """
26 |
27 | if token == "[":
28 | token = Vocabulary.SOS
29 | elif token == "]":
30 | token = Vocabulary.EOS
31 | elif token == "iz":
32 | token = Vocabulary.OUT
33 | elif token == "--":
34 | token = Vocabulary.PAD
35 | elif token == "''":
36 | token = "?"
37 |
38 | return token
--------------------------------------------------------------------------------
/vocabulary_importers/flatfile_vocabulary_importer.py:
--------------------------------------------------------------------------------
1 | """
2 | Base class for Flat File vocabulary importers
3 | """
4 | import numpy as np
5 | from collections import OrderedDict
6 | from os import path
7 | from vocabulary_importers.vocabulary_importer import VocabularyImporter
8 |
9 | class FlatFileVocabularyImporter(VocabularyImporter):
10 | """Base class for Flat File vocabulary importers
11 | """
12 |
13 | def __init__(self, vocabulary_name, tokens_and_embeddings_filename, delimiter):
14 |         """Initialize the FlatFileVocabularyImporter.
15 |
16 |         Args:
17 |             vocabulary_name: See base class
18 |
19 |             tokens_and_embeddings_filename: Name of the file containing the token/word list and embeddings.
20 |                 Format should be one line per word, where the word is at the beginning of the line and the embedding vector follows,
21 |                 separated by a delimiter.
22 |
23 |             delimiter: Character that separates the word and the values of the embedding vector.
24 |         """
25 |         super(FlatFileVocabularyImporter, self).__init__(vocabulary_name)
26 |
27 | self.tokens_and_embeddings_filename = tokens_and_embeddings_filename
28 |
29 | self.delimiter = delimiter
30 |
31 | def _read_vocabulary_and_embeddings(self, vocabulary_dir):
32 | """Read the raw vocabulary file(s) and return the tokens list with corresponding word vectors
33 |
34 | Args:
35 | vocabulary_dir: See base class
36 | """
37 |
38 | tokens_and_embeddings_filepath = path.join(vocabulary_dir, self.tokens_and_embeddings_filename)
39 | tokens_with_embeddings = OrderedDict()
40 | with open(tokens_and_embeddings_filepath, encoding="utf-8") as file:
41 | for _, line in enumerate(file):
42 | values = line.split(self.delimiter)
43 | token = values[0].strip()
44 | if token != "":
45 | token = self._process_token(token)
46 | tokens_with_embeddings[token] = np.array(values[1:], dtype=np.float32)
47 |
48 | return tokens_with_embeddings
--------------------------------------------------------------------------------
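For reference, a hedged illustration of the flat-file layout this importer expects: the token first, then the embedding components, joined by the delimiter (the line below is made up):

```python
import numpy as np

sample_line = "hello 0.21 -0.04 0.13 0.56"  # hypothetical deps.words-style line

values = sample_line.split(" ")
token = values[0].strip()                        # "hello"
vector = np.array(values[1:], dtype=np.float32)  # [0.21, -0.04, 0.13, 0.56]
```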
/vocabulary_importers/nnlm_en_vocabulary_importer.py:
--------------------------------------------------------------------------------
1 | """
2 | Importer class for the nnlm english vocabulary (Bengio et al, 2003)
3 | https://www.tensorflow.org/hub/modules/google/nnlm-en-dim128/1
4 | http://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf
5 | """
6 | from os import path
7 |
8 | from vocabulary_importers.checkpoint_vocabulary_importer import CheckpointVocabularyImporter
9 | from vocabulary import Vocabulary
10 |
11 | class NnlmEnVocabularyImporter(CheckpointVocabularyImporter):
12 | """Importer implementation for the nnlm english vocabulary
13 | """
14 | def __init__(self):
15 | super(NnlmEnVocabularyImporter, self).__init__("nnlm_en", "tokens.txt", "embeddings")
16 |
17 | def _process_token(self, token):
18 | """Perform token preprocessing (See base class for explanation)
19 |
20 | Args:
21 | See base class
22 |
23 | Returns:
24 | See base class
25 | """
26 |
27 |         if token == "<S>":
28 |             token = Vocabulary.SOS
29 |         elif token == "</S>":
30 |             token = Vocabulary.EOS
31 |         elif token == "<UNK>":
32 |             token = Vocabulary.OUT
33 |         elif token == "--":
34 |             token = Vocabulary.PAD
35 |
36 | return token
--------------------------------------------------------------------------------
/vocabulary_importers/vocabulary_importer.py:
--------------------------------------------------------------------------------
1 | """
2 | Base class for all vocabulary importers
3 | """
4 | import abc
5 | import numpy as np
6 | from enum import Enum
7 | from collections import OrderedDict
8 | from vocabulary import Vocabulary
9 |
10 | class VocabularyImportMode(Enum):
11 | """Vocabulary import modes
12 |
13 | Members:
14 | External: Vocabulary is all the external words. Entire vocabulary has pre-trained embeddings.
15 |             Dataset words not in vocabulary are mapped to <OUT> during training and chatting.
16 |
17 | ExternalIntersectDataset: Vocabulary is only the external words that also exist in the dataset.
18 |             Entire vocabulary has pre-trained embeddings. Dataset words not in vocabulary are mapped to <OUT> during training and chatting.
19 |
20 | ExternalUnionDataset: Vocabulary is the unique set of all external words and all dataset words.
21 | Vocabulary has partially pre-trained and partially randomized embeddings to be learned during training.
22 | Dataset words are always in vocabulary.
23 |
24 | Dataset: Vocabulary is all the dataset words. Vocabulary has partially pre-trained and partially randomized embeddings
25 | to be learned during training. Dataset words are always in vocabulary.
26 | """
27 | External = 1
28 | ExternalIntersectDataset = 2
29 | ExternalUnionDataset = 3
30 | Dataset = 4
31 |
32 | class VocabularyImportStats(object):
33 | """Contains information about the imported vocabulary.
34 | """
35 |
36 | def __init__(self):
37 | self.external_vocabulary_size = None
38 | self.dataset_vocabulary_size = None
39 | self.intersection_size = None
40 |
41 | class VocabularyImporter(object):
42 | """Base class for all vocabulary importers
43 | """
44 |
45 | def __init__(self, vocabulary_name):
46 | """Initialize the VocabularyImporter.
47 |
48 | Args:
49 | vocabulary_name: Name of the vocabulary. Subclass must pass this in.
50 | """
51 | self.vocabulary_name = vocabulary_name
52 |
53 | @abc.abstractmethod
54 | def _process_token(self, token):
55 | """Subclass must implement this.
56 |
57 | Perform any preprocessing (e.g. replacement) to each token before importing it.
58 |
59 | Args:
60 | token: the token/word to process
61 |
62 | Returns:
63 | the processed token/word
64 | """
65 | pass
66 |
67 | @abc.abstractmethod
68 | def _read_vocabulary_and_embeddings(self, vocabulary_dir):
69 | """Subclass must implement this.
70 |
71 | Read the raw vocabulary file(s) and return the tokens list with corresponding word vectors
72 |
73 | Args:
74 | vocabulary_dir: See import_vocabulary
75 |
76 | Returns:
77 | tokens_with_embeddings: OrderedDict of words in the vocabulary with corresponding word vectors as numpy arrays
78 | """
79 | pass
80 |
81 | def import_vocabulary(self, vocabulary_dir, normalize = True, import_mode = VocabularyImportMode.External, dataset_vocab = None):
82 | """Read the raw vocabulary file(s) and use it to initialize a Vocabulary object
83 |
84 | Args:
85 | vocabulary_dir: directory to load the raw vocabulary file from
86 |
87 | normalize: True to convert all word tokens to lower case and then average the
88 | embedding vectors for any duplicate words before import.
89 |
90 | import_mode: indicates import behavior. See VocabularyImportMode class for more info.
91 |
92 | dataset_vocab: Optionally provide the vocabulary instance generated from the dataset.
93 | This must be provided if import_mode is not 'External'.
94 |
95 | Returns:
96 | vocabulary: The final imported vocabulary instance
97 |
98 | import_stats: Informational statistics on the external vocabulary, the dataset vocabulary, and their intersection.
99 | """
100 |
101 | if dataset_vocab is None and import_mode != VocabularyImportMode.External:
102 | raise ValueError("dataset_vocab must be provided if import_mode is not 'External'.")
103 |
104 | import_stats = VocabularyImportStats()
105 |
106 | #Read the external vocabulary tokens and embeddings
107 | tokens_with_embeddings = self._read_vocabulary_and_embeddings(vocabulary_dir)
108 |
109 | #If normalize flag is true, normalize casing of the external vocabulary and average embeddings for any resulting duplicate tokens
110 | if normalize:
111 | tokens_with_embeddings = self._normalize_tokens_with_embeddings(tokens_with_embeddings)
112 |
113 | import_stats.external_vocabulary_size = len(tokens_with_embeddings)
114 |
115 | #Apply dataset filters if applicable
116 | if dataset_vocab is not None:
117 | import_stats.dataset_vocabulary_size = dataset_vocab.size()
118 |
119 | if import_mode == VocabularyImportMode.ExternalIntersectDataset or import_mode == VocabularyImportMode.Dataset:
120 | #Get rid of all tokens that exist in the external vocabulary but don't exist in the dataset
121 | for token in list(tokens_with_embeddings.keys()):
122 | if not dataset_vocab.word_exists(token):
123 | del tokens_with_embeddings[token]
124 | import_stats.intersection_size = len(tokens_with_embeddings)
125 |
126 | if import_mode == VocabularyImportMode.ExternalUnionDataset or import_mode == VocabularyImportMode.Dataset:
127 | #Add any tokens that exist in the dataset but don't exist in the external vocabulary.
128 | #These added tokens will get word vectors sampled from the gaussian distributions of their components:
129 | # where the mean of each component is the mean of that component in the external embedding matrix
130 | # and the standard deviation of each component is the standard deviation of that component in the external embedding matrix
131 | embeddings_matrix = np.array(list(tokens_with_embeddings.values()), dtype=np.float32)
132 | emb_size = embeddings_matrix.shape[1]
133 | emb_mean = np.mean(embeddings_matrix, axis=0)
134 | emb_stdev = np.std(embeddings_matrix, axis=0)
135 | for i in range(dataset_vocab.size()):
136 | dataset_token = dataset_vocab.int2word(i, capitalize_i=False)
137 | if dataset_token not in tokens_with_embeddings:
138 | tokens_with_embeddings[dataset_token] = np.random.normal(emb_mean, emb_stdev, emb_size)
139 |
140 | if len(tokens_with_embeddings) == 0:
141 | raise ValueError("Imported vocabulary size is 0. Try a different VocabularyImportMode (currently {0})".format(
142 | VocabularyImportMode(import_mode).name))
143 |
144 | tokens, embeddings_matrix = zip(*tokens_with_embeddings.items())
145 | embeddings_matrix = np.array(embeddings_matrix, dtype=np.float32)
146 |
147 | #Create the vocabulary instance
148 | vocabulary = Vocabulary(external_embeddings = embeddings_matrix)
149 | for i in range(len(tokens)):
150 | vocabulary.load_word(tokens[i], i)
151 | vocabulary.compile(loading = True)
152 | return vocabulary, import_stats
153 |
154 | def _normalize_tokens_with_embeddings(self, tokens_with_embeddings):
155 | """Convert all word tokens to lower case and then average the embedding vectors for any duplicate words
156 | """
157 | norm_tokens_with_embeddings = OrderedDict()
158 | for token, embedding in tokens_with_embeddings.items():
159 | if token not in Vocabulary.special_tokens:
160 | token = token.lower()
161 | if token in norm_tokens_with_embeddings:
162 | norm_tokens_with_embeddings[token].append(embedding)
163 | else:
164 | norm_tokens_with_embeddings[token] = [embedding]
165 |
166 | for token, embedding in norm_tokens_with_embeddings.items():
167 | norm_tokens_with_embeddings[token] = np.mean(embedding, axis=0)
168 |
169 | return norm_tokens_with_embeddings
--------------------------------------------------------------------------------
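
The four import modes reduce to set operations over the external token dictionary, plus random initialization for any dataset words the external vocabulary lacks. The following self-contained sketch (toy words and 2-D vectors, all hypothetical) mirrors the intersect and union branches of `import_vocabulary` without depending on the `Vocabulary` class:

```python
from collections import OrderedDict
import numpy as np

# Toy external vocabulary with 2-D embeddings (hypothetical values).
external = OrderedDict([
    ("hello", np.array([0.1, 0.2], dtype=np.float32)),
    ("world", np.array([0.3, 0.4], dtype=np.float32)),
    ("foo",   np.array([0.5, 0.6], dtype=np.float32)),
])
dataset_words = {"hello", "bar"}  # words observed in the training dataset

# ExternalIntersectDataset (and Dataset): keep only external tokens that
# also occur in the dataset.
intersected = OrderedDict(
    (t, e) for t, e in external.items() if t in dataset_words)

# ExternalUnionDataset (and Dataset): add dataset words missing from the
# external vocabulary, sampling each new vector component-wise from a
# gaussian whose mean and stdev come from the external embedding matrix.
matrix = np.array(list(external.values()), dtype=np.float32)
emb_mean, emb_stdev = np.mean(matrix, axis=0), np.std(matrix, axis=0)
union = OrderedDict(external)
for word in dataset_words:
    if word not in union:
        union[word] = np.random.normal(emb_mean, emb_stdev, matrix.shape[1])

print(sorted(intersected))  # ['hello']
print(sorted(union))        # ['bar', 'foo', 'hello', 'world']
```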
/vocabulary_importers/vocabulary_importer_factory.py:
--------------------------------------------------------------------------------
1 | """
2 | Vocabulary importer implementation factory
3 | """
4 | from os import path
5 | from vocabulary_importers.word2vec_wikipedia_vocabulary_importer import Word2vecWikipediaVocabularyImporter
6 | from vocabulary_importers.nnlm_en_vocabulary_importer import NnlmEnVocabularyImporter
7 | from vocabulary_importers.dependency_based_vocabulary_importer import DependencyBasedVocabularyImporter
8 |
9 | def get_vocabulary_importer(vocabulary_dir):
10 | """Gets the appropriate importer implementation for the specified vocabulary name.
11 |
12 | Args:
13 | vocabulary_dir: The directory of the vocabulary to get an importer implementation for.
14 | """
15 | vocabulary_name = path.basename(vocabulary_dir)
16 |
17 | #When adding support for new vocabularies, add an instance of their importer class to the importers list below.
18 | importers = [Word2vecWikipediaVocabularyImporter(),
19 | NnlmEnVocabularyImporter(),
20 | DependencyBasedVocabularyImporter()]
21 |
22 | for importer in importers:
23 | if importer.vocabulary_name == vocabulary_name:
24 | return importer
25 |
26 | raise ValueError("There is no vocabulary importer implementation for '{0}'. If this is a new vocabulary, please add one!".format(vocabulary_name))
--------------------------------------------------------------------------------
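
Since the factory matches on `path.basename(vocabulary_dir)`, the vocabulary directory's name must equal an importer's `vocabulary_name` exactly. A hypothetical usage sketch (the directory path below is illustrative only, not a real layout from this repo):

```python
from vocabulary_importers.vocabulary_importer_factory import get_vocabulary_importer

# The basename "word2vec_wikipedia" selects Word2vecWikipediaVocabularyImporter.
importer = get_vocabulary_importer("embeddings/word2vec_wikipedia")
vocabulary, import_stats = importer.import_vocabulary("embeddings/word2vec_wikipedia")
```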
/vocabulary_importers/word2vec_wikipedia_vocabulary_importer.py:
--------------------------------------------------------------------------------
1 | """
2 | Importer class for the word2vec wikipedia vocabulary (Mikolov et al., 2013)
3 | https://www.tensorflow.org/hub/modules/google/Wiki-words-250/1
4 | https://arxiv.org/abs/1301.3781
5 | """
6 | from os import path
7 |
8 | from vocabulary_importers.checkpoint_vocabulary_importer import CheckpointVocabularyImporter
9 | from vocabulary import Vocabulary
10 |
11 | class Word2vecWikipediaVocabularyImporter(CheckpointVocabularyImporter):
12 | """Importer implementation for the word2vec wikipedia vocabulary
13 | """
14 | def __init__(self):
15 | super(Word2vecWikipediaVocabularyImporter, self).__init__("word2vec_wikipedia", "tokens.txt", "embeddings")
16 |
17 | def _process_token(self, token):
18 | """Perform token preprocessing (See base class for explanation)
19 |
20 | Args:
21 | See base class
22 |
23 | Returns:
24 | See base class
25 | """
26 |
27 | if token == "<S>":
28 | token = Vocabulary.SOS
29 | elif token == "</S>":
30 | token = Vocabulary.EOS
31 | elif token == "<UNK>":
32 | token = Vocabulary.OUT
33 | elif token == "#!#":
34 | token = "!"
35 | elif token == "#.#":
36 | token = "."
37 | elif token == "#?#":
38 | token = "?"
39 |
40 | return token
--------------------------------------------------------------------------------
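
Assuming the `<S>` / `</S>` / `<UNK>` sentinels above, `_process_token` maps the checkpoint's reserved tokens onto this project's special tokens and unescapes the `#!#` / `#.#` / `#?#` punctuation placeholders. A quick illustration (a sketch, run from the repo root):

```python
from vocabulary_importers.word2vec_wikipedia_vocabulary_importer import Word2vecWikipediaVocabularyImporter
from vocabulary import Vocabulary

importer = Word2vecWikipediaVocabularyImporter()
assert importer._process_token("<UNK>") == Vocabulary.OUT  # OOV sentinel
assert importer._process_token("#?#") == "?"               # punctuation placeholder
assert importer._process_token("dog") == "dog"             # ordinary words pass through
```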