├── LICENSE.txt ├── README.md ├── chatbot.py ├── data └── scotus │ └── scotus.bz2 ├── img └── chatbot-training.png ├── model.py ├── models └── models-folder-readme.txt ├── reddit-parse ├── parser_config_standard.json └── reddit_parse.py ├── train.py └── utils.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) Copyright (c) 2017 Brian Walker 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | of the Software, and to permit persons to whom the Software is furnished to do 8 | so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # chatbot-rnn 2 | A toy chatbot powered by deep learning and trained on data from Reddit. 3 | 4 | Here is a sample chat transcript (not cherry-picked). Lines that start with ">" are me; the lines without ">" are the chatbot's responses. 5 | 6 | ``` 7 | $ python chatbot.py 8 | Creating model... 9 | Restoring weights... 10 | 11 | > Hi, how are you? 12 | It's all good, I'm going to bed now but I'll see if I can get home tomorrow. I'll get back to you. 13 | > Don't go to bed just yet 14 | That's great, what do you think of me? 15 | > I think you're a pretty interesting conversationalist 16 | Thank you, I appreciate it. 17 | > No problem. What do you think of me? 18 | Well, you know, that's fine. I don't think I could get behind any of them. 19 | > I see. Is anyone asking you to get behind them? 20 | Of course not. I don't want to make myself look bad. 21 | > Yeah, I understand. So what's new? 22 | Anything that isn't childish and ignorant of yourself. 23 | > Do you think I'm childish and ignorant? 24 | Yes. Yes you are. 25 | > Why would you say that? 26 | I don't think you know what you are talking about. 27 | > But we've only just started talking! 28 | Yeah.. I guess I shouldn't have gotten over it... 29 | > It's all right. 30 | That's not what I meant. 31 | ``` 32 | 33 | ## Get Started 34 | 35 | - Clone this project to your computer. 36 | 37 | - **Install [Tensorflow](http://www.tensorflow.org/)**. To run a pretrained model, the CPU-only installation should suffice. If you want to train your own models, you'll need the GPU installation of Tensorflow (and a powerful CUDA-compatible GPU). 38 | 39 | ### Run my pre-trained model 40 | 41 | - **Download [my pre-trained model](https://drive.google.com/uc?export=download&id=0B6noVJLTV1jCT29uMzliMVVRWWM)** (201 MB). The zip file extracts into a folder named "reddit". 
Place that folder into the "models" directory of this project. 42 | 43 | - **Run the chatbot**. Open a terminal session and run `python chatbot.py`. Warning: this pre-trained model was trained on a diverse set of frequently off-color Reddit comments. It can (and eventually will) say things that are offensive, disturbing, bizarre or sexually explicit. It may insult minorities, it may call you names, it may accuse you of being a pedophile, it may try to seduce you. Please don't use the chatbot if these possibilities would distress you! 44 | 45 | Try playing around with the arguments to `chatbot.py` to obtain better samples: 46 | 47 | - **beam_width**: By default, `chatbot.py` will use beam search with a beam width of 2 to sample responses. Set this higher for more careful, more conservative (and slower) responses, or set it to 1 to disable beam search. 48 | 49 | - **temperature**: At each step, the model ascribes a certain probability to each character. Temperature can adjust the probability distribution. 1.0 is neutral (and the default), lower values increase high probability values and decrease lower probability values to make the choices more conservative, and higher values will do the reverse. Values outside of the range of 0.5-1.5 are unlikely to give coherent results. 50 | 51 | - **relevance**: Two models are run in parallel: the primary model and the mask model. The mask model is scaled by the relevance value, and then the probabilities of the primary model are combined according to equation 9 in [Li, Jiwei, et al. "A diversity-promoting objective function for neural conversation models." arXiv preprint arXiv:1510.03055 (2015)](https://arxiv.org/abs/1510.03055). The state of the mask model is reset upon each newline character. The net effect is that the model is encouraged to choose a line of dialogue that is most relevant to the prior line of dialogue, even if a more generic response (e.g. "I don't know anything about that") may be more absolutely probable. Higher relevance values put more pressure on the model to produce relevant responses, at the cost of the coherence of the responses. Going much above 0.4 compromises the quality of the responses. Setting it to a negative value disables relevance, and this is the default, because I'm not confident that it qualitatively improves the outputs and it halves the speed of sampling. 52 | 53 | These values can also be manipulated during a chat, and the model state can be reset, without restarting the chatbot: 54 | 55 | ``` 56 | $ python chatbot.py 57 | Creating model... 58 | Restoring weights... 59 | 60 | > --temperature 1.3 61 | [Temperature set to 1.3] 62 | 63 | > --relevance 0.3 64 | [Relevance set to 0.3] 65 | 66 | > --relevance -1 67 | [Relevance disabled] 68 | 69 | > --beam_width 5 70 | [Beam width set to 5] 71 | 72 | > --reset 73 | [Model state reset] 74 | ``` 75 | 76 | ### Get training data 77 | 78 | If you'd like to train your own model, you'll need training data. There are a few options here. 79 | 80 | - **Provide your own training data.** Training data should be one or more newline-delimited text files. Each line of dialogue should begin with "> " and end with a newline. You'll need a lot of it. Several megabytes of uncompressed text is probably the minimum, and even that may not suffice if you want to train a large model. Text can be provided as raw .txt files or as bzip2-compressed (.bz2) files. 
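For example, a few lines of a qualifying training file would look like this (the dialogue itself is made up for illustration; what matters is the "> " prefix and one utterance per line):

```
> Hi, how was the game last night?
> Pretty good, we won in the final minute.
> Nice! Who ended up scoring?
> I honestly don't remember, I fell asleep at halftime.
```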
81 | 82 | - **Simulate the United States Supreme Court.** I've included a corpus of United States Supreme Court oral argument transcripts (2.7 MB compressed) in the project under the `data/scotus` directory. 83 | 84 | - **Use Reddit data.** This is what the pre-trained model was trained on: 85 | 86 | First, download a torrent of Reddit comments from a torrent link [listed here](https://www.reddit.com/r/datasets/comments/3bxlg7/i_have_every_publicly_available_reddit_comment/). You can use the single month of comments (~5 GB compressed), or you can download the entire archive (~160 GB compressed). Do not extract the individual bzip2 (.bz2) files contained in these archives. 87 | 88 | Once you have your raw reddit data, place it in the `reddit-parse/reddit_data` subdirectory and use the `reddit-parse.py` script included in the project file to convert them into compressed text files of appropriately formatted conversations. This script chooses qualifying comments (must be under 200 characters, can't contain certain substrings such as 'http://', can't have been posted on certain subreddits) and assembles them into conversations of at least four lines. Coming up with good rules to curate conversations from raw reddit data is more art than science. I encourage you to play around with the parameters in the included `parser_config_standard.json` file, or to mess around with the parsing script itself, to come up with an interesting data set. 89 | 90 | Please be aware that there is a *lot* of Reddit data included in the torrents. It is very easy to run out of memory or hard drive space. I used the entire archive (~160 GB compressed, although some files appear to be corrupted and are skipped by `reddit-parse.py`), and ran the `reddit-parse.py` script with the configuration I included as the default, which holds a million comments (several GB) in memory at a time, takes about 12 hours to run on the entire archive, and produces 2.2 GB of bzip2-compressed output. When training the model, this raw data will be converted into numpy tensors, compressed, and saved back to disk, which consumes another ~5 GB of hard drive space. I acknowledge that this may be overkill relative to the size of the model. 91 | 92 | Once you have training data in hand (and located in a subdirectory of the `data` directory): 93 | 94 | ### Train your own model 95 | 96 | - **Train.** Use `train.py` to train the model. The default hyperparameters (four layers of 1500 GRUs per layer) are the best that I've found, and are what I used to train the pre-trained model for about 37 days. These hyperparameters will just about fill the memory of a Titan X GPU, so if you have a smaller GPU, you will need to adjust them accordingly -- most straightforwardly, by reducing the batch size. 97 | 98 | Training can be interrupted with crtl-c at any time, and will immediately save the model when interrupted. Training can be resumed on a saved model and will automatically carry on from where it was interrupted. 99 | 100 | ![Alt text](/img/chatbot-training.png?raw=true) 101 | 102 | ## Thanks 103 | 104 | Thanks to Andrej Karpathy for his excellent [char-rnn](https://github.com/karpathy/char-rnn) repo, and to Sherjil Ozair for his [tensorflow port](https://github.com/sherjilozair/char-rnn-tensorflow) of char-rnn, which this repo is based on. 
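For reference, a training run on the bundled Supreme Court transcripts with a smaller-than-default model could be launched like this (the flag names come from `train.py` below; the directory name and layer sizes are only illustrative and should be scaled to fit your GPU's memory):

```
$ python train.py --data_dir data/scotus --save_dir models/scotus_small --rnn_size 768 --num_layers 3 --batch_size 32
```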
105 | -------------------------------------------------------------------------------- /chatbot.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | import argparse 7 | import os 8 | import cPickle 9 | import copy 10 | import sys 11 | import string 12 | 13 | from utils import TextLoader 14 | from model import Model 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--save_dir', type=str, default='models/reddit', 19 | help='model directory to store checkpointed models') 20 | parser.add_argument('-n', type=int, default=500, 21 | help='number of characters to sample') 22 | parser.add_argument('--prime', type=str, default=' ', 23 | help='prime text') 24 | parser.add_argument('--beam_width', type=int, default=2, 25 | help='Width of the beam for beam search, default 2') 26 | parser.add_argument('--temperature', type=float, default=1.0, 27 | help='sampling temperature' 28 | '(lower is more conservative, default is 1.0, which is neutral)') 29 | parser.add_argument('--relevance', type=float, default=-1., 30 | help='amount of "relevance masking/MMI (disabled by default):"' 31 | 'higher is more pressure, 0.4 is probably as high as it can go without' 32 | 'noticeably degrading coherence;' 33 | 'set to <0 to disable relevance masking') 34 | args = parser.parse_args() 35 | sample_main(args) 36 | 37 | def get_paths(input_path): 38 | if os.path.isfile(input_path): 39 | # Passed a model rather than a checkpoint directory 40 | model_path = input_path 41 | save_dir = os.path.dirname(model_path) 42 | elif os.path.exists(input_path): 43 | # Passed a checkpoint directory 44 | save_dir = input_path 45 | checkpoint = tf.train.get_checkpoint_state(save_dir) 46 | if checkpoint: 47 | model_path = checkpoint.model_checkpoint_path 48 | else: 49 | raise ValueError('checkpoint not found in {}.'.format(save_dir)) 50 | else: 51 | raise ValueError('save_dir is not a valid path.') 52 | return model_path, os.path.join(save_dir, 'config.pkl'), os.path.join(save_dir, 'chars_vocab.pkl') 53 | 54 | def sample_main(args): 55 | model_path, config_path, vocab_path = get_paths(args.save_dir) 56 | # Arguments passed to sample.py direct us to a saved model. 57 | # Load the separate arguments by which that model was previously trained. 58 | # That's saved_args. Use those to load the model. 59 | with open(config_path) as f: 60 | saved_args = cPickle.load(f) 61 | # Separately load chars and vocab from the save directory. 62 | with open(vocab_path) as f: 63 | chars, vocab = cPickle.load(f) 64 | # Create the model from the saved arguments, in inference mode. 65 | print("Creating model...") 66 | net = Model(saved_args, True) 67 | config = tf.ConfigProto() 68 | config.gpu_options.allow_growth = True 69 | with tf.Session(config=config) as sess: 70 | # tf.initialize_all_variables().run() 71 | tf.global_variables_initializer().run() 72 | tf.local_variables_initializer().run() 73 | saver = tf.train.Saver(net.save_variables_list()) 74 | # Restore the saved variables, replacing the initialized values. 75 | print("Restoring weights...") 76 | saver.restore(sess, model_path) 77 | chatbot(net, sess, chars, vocab, args.n, args.beam_width, args.relevance, args.temperature) 78 | #beam_sample(net, sess, chars, vocab, args.n, args.prime, 79 | #args.beam_width, args.relevance, args.temperature) 80 | 81 | def initial_state(net, sess): 82 | # Return freshly initialized model states. 
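    # (Batch size is 1 because the chatbot feeds one character at a time during inference.
    #  When relevance masking is enabled, two of these states are kept side by side:
    #  one for the primary net and one for the mask net.)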
83 | return sess.run(net.cell.zero_state(1, tf.float32)) 84 | 85 | def forward_text(net, sess, states, vocab, prime_text=None): 86 | if prime_text is not None: 87 | for char in prime_text: 88 | if len(states) == 2: 89 | # Automatically forward the primary net. 90 | _, states[0] = net.forward_model(sess, states[0], vocab[char]) 91 | # If the token is newline, reset the mask net state; else, forward it. 92 | if vocab[char] == '\n': 93 | states[1] = initial_state(net, sess) 94 | else: 95 | _, states[1] = net.forward_model(sess, states[1], vocab[char]) 96 | else: 97 | _, states = net.forward_model(sess, states, vocab[char]) 98 | return states 99 | 100 | def scale_prediction(prediction, temperature): 101 | if (temperature == 1.0): return prediction # Temperature 1.0 makes no change 102 | np.seterr(divide='ignore') 103 | scaled_prediction = np.log(prediction) / temperature 104 | scaled_prediction = scaled_prediction - np.logaddexp.reduce(scaled_prediction) 105 | scaled_prediction = np.exp(scaled_prediction) 106 | np.seterr(divide='warn') 107 | return scaled_prediction 108 | 109 | def beam_sample(net, sess, chars, vocab, max_length=200, prime='The ', 110 | beam_width = 2, relevance=3.0, temperature=1.0): 111 | states = [initial_state(net, sess), initial_state(net, sess)] 112 | states = forward_text(net, sess, states, vocab, prime) 113 | computer_response_generator = beam_search_generator(sess, net, states, vocab[' '], 114 | None, beam_width, forward_with_mask, (temperature, vocab['\n'])) 115 | for i, char_token in enumerate(computer_response_generator): 116 | print(chars[char_token], end='') 117 | states = forward_text(net, sess, states, vocab, chars[char_token]) 118 | sys.stdout.flush() 119 | if i >= max_length: break 120 | print() 121 | 122 | def sanitize_text(vocab, text): 123 | return ''.join(i for i in text if i in vocab) 124 | 125 | def initial_state_with_relevance_masking(net, sess, relevance): 126 | if relevance <= 0.: return initial_state(net, sess) 127 | else: return [initial_state(net, sess), initial_state(net, sess)] 128 | 129 | def chatbot(net, sess, chars, vocab, max_length, beam_width, relevance, temperature): 130 | states = initial_state_with_relevance_masking(net, sess, relevance) 131 | while True: 132 | user_input = sanitize_text(vocab, raw_input('\n> ')) 133 | user_command_entered, reset, states, relevance, temperature, beam_width = process_user_command( 134 | user_input, states, relevance, temperature, beam_width) 135 | if reset: states = initial_state_with_relevance_masking(net, sess, relevance) 136 | if user_command_entered: continue 137 | states = forward_text(net, sess, states, vocab, '> ' + user_input + "\n>") 138 | computer_response_generator = beam_search_generator(sess=sess, net=net, 139 | initial_state=copy.deepcopy(states), initial_sample=vocab[' '], 140 | early_term_token=vocab['\n'], beam_width=beam_width, forward_model_fn=forward_with_mask, 141 | forward_args=(relevance, vocab['\n']), temperature=temperature) 142 | for i, char_token in enumerate(computer_response_generator): 143 | print(chars[char_token], end='') 144 | states = forward_text(net, sess, states, vocab, chars[char_token]) 145 | sys.stdout.flush() 146 | if i >= max_length: break 147 | states = forward_text(net, sess, states, vocab, '\n> ') 148 | 149 | def process_user_command(user_input, states, relevance, temperature, beam_width): 150 | user_command_entered = False 151 | reset = False 152 | try: 153 | if user_input.startswith('--temperature '): 154 | user_command_entered = True 155 | temperature = 
max(0.001, float(user_input[len('--temperature '):])) 156 | print("[Temperature set to {}]".format(temperature)) 157 | elif user_input.startswith('--relevance '): 158 | user_command_entered = True 159 | new_relevance = float(user_input[len('--relevance '):]) 160 | if relevance <= 0. and new_relevance > 0.: 161 | states = [states, copy.deepcopy(states)] 162 | elif relevance > 0. and new_relevance <= 0.: 163 | states = states[0] 164 | relevance = new_relevance 165 | print("[Relevance disabled]" if relevance < 0. else "[Relevance set to {}]".format(relevance)) 166 | elif user_input.startswith('--beam_width '): 167 | user_command_entered = True 168 | beam_width = max(1, int(user_input[len('--beam_width '):])) 169 | print("[Beam width set to {}]".format(beam_width)) 170 | elif user_input.startswith('--reset'): 171 | user_command_entered = True 172 | reset = True 173 | print("[Model state reset]") 174 | except ValueError: 175 | print("[Value error with provided argument.]") 176 | return user_command_entered, reset, states, relevance, temperature, beam_width 177 | 178 | def consensus_length(beam_outputs, early_term_token): 179 | for l in xrange(len(beam_outputs[0])): 180 | if l > 0 and beam_outputs[0][l-1] == early_term_token: 181 | return l-1, True 182 | for b in beam_outputs[1:]: 183 | if beam_outputs[0][l] != b[l]: return l, False 184 | return l, False 185 | 186 | def forward_with_mask(sess, net, states, input_sample, forward_args): 187 | if len(states) != 2: 188 | # No relevance masking. 189 | prob, states = net.forward_model(sess, states, input_sample) 190 | return prob / sum(prob), states 191 | # states should be a 2-length list: [primary net state, mask net state]. 192 | # forward_args should be a 2-length list/tuple: [relevance, mask_reset_token] 193 | relevance, mask_reset_token = forward_args 194 | if input_sample == mask_reset_token: 195 | # Reset the mask probs when reaching mask_reset_token (newline). 196 | states[1] = initial_state(net, sess) 197 | primary_prob, states[0] = net.forward_model(sess, states[0], input_sample) 198 | primary_prob /= sum(primary_prob) 199 | mask_prob, states[1] = net.forward_model(sess, states[1], input_sample) 200 | mask_prob /= sum(mask_prob) 201 | combined_prob = np.exp(np.log(primary_prob) - relevance * np.log(mask_prob)) 202 | # Normalize probabilities so they sum to 1. 203 | return combined_prob / sum(combined_prob), states 204 | 205 | def beam_search_generator(sess, net, initial_state, initial_sample, 206 | early_term_token, beam_width, forward_model_fn, forward_args, temperature): 207 | '''Run beam search! Yield consensus tokens sequentially, as a generator; 208 | return when reaching early_term_token (newline). 209 | 210 | Args: 211 | sess: tensorflow session reference 212 | net: tensorflow net graph (must be compatible with the forward_net function) 213 | initial_state: initial hidden state of the net 214 | initial_sample: single token (excluding any seed/priming material) 215 | to start the generation 216 | early_term_token: stop when the beam reaches consensus on this token 217 | (but do not return this token). 218 | beam_width: how many beams to track 219 | forward_model_fn: function to forward the model, must be of the form: 220 | probability_output, beam_state = 221 | forward_model_fn(sess, net, beam_state, beam_sample, forward_args) 222 | (Note: probability_output has to be a valid probability distribution!) 
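        forward_args: extra arguments passed through unchanged to forward_model_fn
            (for forward_with_mask, this is the (relevance, mask-reset token) pair).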
223 | temperature: how conservatively to sample tokens from each distribution 224 | (1.0 = neutral, lower means more conservative) 225 | tot_steps: how many tokens to generate before stopping, 226 | unless already stopped via early_term_token. 227 | Returns: a generator to yield a sequence of beam-sampled tokens.''' 228 | # Store state, outputs and probabilities for up to args.beam_width beams. 229 | # Initialize with just the one starting entry; it will branch to fill the beam 230 | # in the first step. 231 | beam_states = [initial_state] # Stores the best activation states 232 | beam_outputs = [[initial_sample]] # Stores the best generated output sequences so far. 233 | beam_probs = [1.] # Stores the cumulative normalized probabilities of the beams so far. 234 | 235 | while True: 236 | # Keep a running list of the best beam branches for next step. 237 | # Don't actually copy any big data structures yet, just keep references 238 | # to existing beam state entries, and then clone them as necessary 239 | # at the end of the generation step. 240 | new_beam_indices = [] 241 | new_beam_probs = [] 242 | new_beam_samples = [] 243 | 244 | # Iterate through the beam entries. 245 | for beam_index, beam_state in enumerate(beam_states): 246 | beam_prob = beam_probs[beam_index] 247 | beam_sample = beam_outputs[beam_index][-1] 248 | 249 | # Forward the model. 250 | prediction, beam_states[beam_index] = forward_model_fn( 251 | sess, net, beam_state, beam_sample, forward_args) 252 | prediction = scale_prediction(prediction, temperature) 253 | 254 | # Sample best_tokens from the probability distribution. 255 | # Sample from the scaled probability distribution beam_width choices 256 | # (but not more than the number of positive probabilities in scaled_prediction). 257 | count = min(beam_width, sum(1 if p > 0. else 0 for p in prediction)) 258 | best_tokens = np.random.choice(len(prediction), size=count, 259 | replace=False, p=prediction) 260 | for token in best_tokens: 261 | prob = prediction[token] * beam_prob 262 | if len(new_beam_indices) < beam_width: 263 | # If we don't have enough new_beam_indices, we automatically qualify. 264 | new_beam_indices.append(beam_index) 265 | new_beam_probs.append(prob) 266 | new_beam_samples.append(token) 267 | else: 268 | # Sample a low-probability beam to possibly replace. 269 | np_new_beam_probs = np.array(new_beam_probs) 270 | inverse_probs = -np_new_beam_probs + max(np_new_beam_probs) + min(np_new_beam_probs) 271 | inverse_probs = inverse_probs / sum(inverse_probs) 272 | sampled_beam_index = np.random.choice(beam_width, p=inverse_probs) 273 | if new_beam_probs[sampled_beam_index] <= prob: 274 | # Replace it. 275 | new_beam_indices[sampled_beam_index] = beam_index 276 | new_beam_probs[sampled_beam_index] = prob 277 | new_beam_samples[sampled_beam_index] = token 278 | # Replace the old states with the new states, first by referencing and then by copying. 
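        # (The same surviving beam index may be selected more than once; the first use keeps a
        #  reference to the existing state object, and subsequent uses deep-copy it so the
        #  branched beams do not share mutable state.)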
279 | already_referenced = [False] * beam_width 280 | new_beam_states = [] 281 | new_beam_outputs = [] 282 | for i, new_index in enumerate(new_beam_indices): 283 | if already_referenced[new_index]: 284 | new_beam = copy.deepcopy(beam_states[new_index]) 285 | else: 286 | new_beam = beam_states[new_index] 287 | already_referenced[new_index] = True 288 | new_beam_states.append(new_beam) 289 | new_beam_outputs.append(beam_outputs[new_index] + [new_beam_samples[i]]) 290 | # Normalize the beam probabilities so they don't drop to zero 291 | beam_probs = new_beam_probs / sum(new_beam_probs) 292 | beam_states = new_beam_states 293 | beam_outputs = new_beam_outputs 294 | # Prune the agreed portions of the outputs 295 | # and yield the tokens on which the beam has reached consensus. 296 | l, early_term = consensus_length(beam_outputs, early_term_token) 297 | if l > 0: 298 | for token in beam_outputs[0][:l]: yield token 299 | beam_outputs = [output[l:] for output in beam_outputs] 300 | if early_term: return 301 | 302 | if __name__ == '__main__': 303 | main() 304 | -------------------------------------------------------------------------------- /data/scotus/scotus.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pierian-Data/chatbot-rnn/123c1cc53b111ec16a96d9b115d25dbcd5cb2aa7/data/scotus/scotus.bz2 -------------------------------------------------------------------------------- /img/chatbot-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pierian-Data/chatbot-rnn/123c1cc53b111ec16a96d9b115d25dbcd5cb2aa7/img/chatbot-training.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | class Model(): 5 | def __init__(self, args, infer=False): # infer is set to true during sampling. 6 | self.args = args 7 | if infer: 8 | # Worry about one character at a time during sampling; no batching or BPTT. 9 | args.batch_size = 1 10 | args.seq_length = 1 11 | 12 | # Set cell_fn to the type of network cell we're creating -- RNN, GRU or LSTM. 13 | if args.model == 'rnn': 14 | # cell_fn = tf.nn.rnn_cell.BasicRNNCell 15 | cell = tf.nn.rnn_cell.BasicRNNCell(args.rnn_size) 16 | elif args.model == 'gru': 17 | # cell_fn = tf.nn.rnn_cell.GRUCell 18 | cell = tf.nn.rnn_cell.GRUCell(args.rnn_size) 19 | elif args.model == 'lstm': 20 | # cell_fn = tf.nn.rnn_cell.BasicLSTMCell 21 | cell = tf.nn.rnn_cell.BasicLSTMCell(args.rnn_size, state_is_tuple=True) 22 | elif args.model == 'nas': 23 | cell = tf.contrib.rnn.NASCell(args.rnn_size) 24 | else: 25 | raise Exception("model type not supported: {}".format(args.model)) 26 | 27 | # Call tensorflow library tensorflow-master/tensorflow/python/ops/rnn_cell 28 | # to create a layer of rnn_size cells of the specified basic type (RNN/GRU/LSTM). 29 | # cell = cell_fn(args.rnn_size, state_is_tuple=True) 30 | 31 | # Use the same rnn_cell library to create a stack of these cells 32 | # of num_layers layers. Pass in a python list of these cells. 33 | # (The [cell] * arg.num_layers syntax literally duplicates cell multiple times in 34 | # a list. The syntax is such that [5, 6] * 3 would return [5, 6, 5, 6, 5, 6].) 
35 | # self.cell = cell = tf.nn.rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True) 36 | self.cell = cell = tf.nn.rnn_cell.MultiRNNCell([cell for _ in range(args.num_layers)], state_is_tuple=True) 37 | 38 | # Create two TF placeholder nodes of 32-bit ints (NOT floats!), 39 | # each of shape batch_size x seq_length. This shape matches the batches 40 | # (listed in x_batches and y_batches) constructed in create_batches in utils.py. 41 | # input_data will receive input batches, and targets will be what it compares against 42 | # to calculate loss. 43 | self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length]) 44 | self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length]) 45 | 46 | # Using the zero_state function in the RNNCell master class in rnn_cell library, 47 | # create a tensor of zeros such that we can swap it in for the network state at any time 48 | # to zero out the network's state. 49 | # State dimensions are: cell_fn state size (2 for LSTM) x rnn_size x num_layers. 50 | # So an LSTM network with 100 cells per layer and 3 layers would have a state size of 600, 51 | # and initial_state would have a dimension of none x 600. 52 | self.initial_state = self.cell.zero_state(args.batch_size, tf.float32) 53 | 54 | # Scope our new variables to the scope identifier string "rnnlm". 55 | with tf.variable_scope('rnnlm'): 56 | # Create new variable softmax_w and softmax_b for output. 57 | # softmax_w is a weights matrix from the top layer of the model (of size rnn_size) 58 | # to the vocabulary output (of size vocab_size). 59 | softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size]) 60 | # softmax_b is a bias vector of the ouput characters (of size vocab_size). 61 | softmax_b = tf.get_variable("softmax_b", [args.vocab_size]) 62 | # [TODO: Why specify CPU? Same as the TF translation tutorial, but don't know why.] 63 | with tf.device("/cpu:0"): 64 | # Create new variable named 'embedding' to connect the character input to the base layer 65 | # of the RNN. Its role is the conceptual inverse of softmax_w. 66 | # It contains the trainable weights from the one-hot input vector to the lowest layer of RNN. 67 | embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size]) 68 | # Create an embedding tensor with tf.nn.embedding_lookup(embedding, self.input_data). 69 | # This tensor has dimensions batch_size x seq_length x rnn_size. 70 | # tf.split splits that embedding lookup tensor into seq_length tensors (along dimension 1). 71 | # Thus inputs is a list of seq_length different tensors, 72 | # each of dimension batch_size x 1 x rnn_size. 73 | inputs = tf.split(axis=1, num_or_size_splits=args.seq_length, value=tf.nn.embedding_lookup(embedding, self.input_data)) 74 | # Iterate through these resulting tensors and eliminate that degenerate second dimension of 1, 75 | # i.e. squeeze each from batch_size x 1 x rnn_size down to batch_size x rnn_size. 76 | # Thus we now have a list of seq_length tensors, each with dimension batch_size x rnn_size. 77 | inputs = [tf.squeeze(input_, [1]) for input_ in inputs] 78 | 79 | # THIS LOOP FUNCTION IS NEVER ACTUALLY USED. 80 | # IT IS EXPLICITLY NOT USED DURING TRAINING. 81 | # DURING INFERENCE, SEQ_LENGTH == 1, SO SEQ2SEQ.RNN_DECODER() ONLY USES THE LOOP ARGUMENT 82 | # ON SEQUENCE LENGTH ITEMS SUBSEQUENT TO THE FIRST. 83 | # This looping function is used as part of seq2seq.rnn_decoder only during sampling -- not training. 
84 | # prev is a 2D Tensor of shape [batch_size x cell.output_size]. 85 | # returns a 2D Tensor of shape [batch_size x cell.input_size]. 86 | def loop(prev, _): 87 | # prev is initially the top cell state. 88 | # Convert the top cell state into character logits. 89 | prev = tf.matmul(prev, softmax_w) + softmax_b 90 | # Pull the character with the greatest logit (no sampling, just argmaxing). 91 | # WHY IS THIS ARGMAXING WHEN ACTUAL SAMPLING IS DONE PROBABILISTICALLY? 92 | # DOESN'T THIS CAUSE OUTPUTS NOT TO MATCH INPUTS DURING SEQUENCE GENERATION? 93 | prev_symbol = tf.stop_gradient(tf.argmax(prev, 1)) 94 | # Re-embed that symbol as the next step's input, and return that. 95 | return tf.nn.embedding_lookup(embedding, prev_symbol) 96 | 97 | # Set up a seq2seq decoder from the seq2seq.py library. 98 | # This constructs the outputs and states nodes of the network. 99 | # Outputs is a list (of len seq_length, same as inputs) of tensors of shape [batch_size x rnn_size]. 100 | # These are the raw output values of the top layer of the network at each time step. 101 | # They have NOT been fed through the decoder projection; they are still in network space, 102 | # not character space. 103 | # State is a tensor of shape [batch_size x cell.state_size]. 104 | # This is also the step where all of the trainable parameters for the LSTM (weights and biases) are defined. 105 | outputs, self.final_state = tf.contrib.legacy_seq2seq.rnn_decoder(inputs, 106 | self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm') 107 | # tf.concat concatenates the output tensors along the rnn_size dimension, 108 | # to make a single tensor of shape [batch_size x (seq_length * rnn_size)]. 109 | # This gives the following 2D outputs matrix: 110 | # [(rnn output: batch 0, seq 0) (rnn output: batch 0, seq 1) ... (rnn output: batch 0, seq seq_len-1)] 111 | # [(rnn output: batch 1, seq 0) (rnn output: batch 1, seq 1) ... (rnn output: batch 1, seq seq_len-1)] 112 | # ... 113 | # [(rnn output: batch batch_size-1, seq 0) (rnn output: batch batch_size-1, seq 1) ... (rnn output: batch batch_size-1, seq seq_len-1)] 114 | # tf.reshape then reshapes it to a tensor of shape [(batch_size * seq_length) x rnn_size]. 115 | # Output will now be the following matrix: 116 | # [rnn output: batch 0, seq 0] 117 | # [rnn output: batch 0, seq 1] 118 | # ... 119 | # [rnn output: batch 0, seq seq_len-1] 120 | # [rnn output: batch 1, seq 0] 121 | # [rnn output: batch 1, seq 1] 122 | # ... 123 | # [rnn output: batch 1, seq seq_len-1] 124 | # ... 125 | # ... 126 | # [rnn output: batch batch_size-1, seq seq_len-1] 127 | # Note the following comment in rnn_cell.py: 128 | # Note: in many cases it may be more efficient to not use this wrapper, 129 | # but instead concatenate the whole sequence of your outputs in time, 130 | # do the projection on this batch-concatenated sequence, then split it 131 | # if needed or directly feed into a softmax. 132 | output = tf.reshape(tf.concat(axis=1, values=outputs), [-1, args.rnn_size]) 133 | # Obtain logits node by applying output weights and biases to the output tensor. 134 | # Logits is a tensor of shape [(batch_size * seq_length) x vocab_size]. 135 | # Recall that outputs is a 2D tensor of shape [(batch_size * seq_length) x rnn_size], 136 | # and softmax_w is a 2D tensor of shape [rnn_size x vocab_size]. 137 | # The matrix product is therefore a new 2D tensor of [(batch_size * seq_length) x vocab_size]. 
138 | # In other words, that multiplication converts a loooong list of rnn_size vectors 139 | # to a loooong list of vocab_size vectors. 140 | # Then add softmax_b (a single vocab-sized vector) to every row of that list. 141 | # That gives you the logits! 142 | self.logits = tf.matmul(output, softmax_w) + softmax_b 143 | # Convert logits to probabilities. Probs isn't used during training! That node is never calculated. 144 | # Like logits, probs is a tensor of shape [(batch_size * seq_length) x vocab_size]. 145 | # During sampling, this means it is of shape [1 x vocab_size]. 146 | self.probs = tf.nn.softmax(self.logits) 147 | # seq2seq.sequence_loss_by_example returns 1D float Tensor containing the log-perplexity 148 | # for each sequence. (Size is batch_size * seq_length.) 149 | # Targets are reshaped from a [batch_size x seq_length] tensor to a 1D tensor, of the following layout: 150 | # target character (batch 0, seq 0) 151 | # target character (batch 0, seq 1) 152 | # ... 153 | # target character (batch 0, seq seq_len-1) 154 | # target character (batch 1, seq 0) 155 | # ... 156 | # These targets are compared to the logits to generate loss. 157 | # Logits: instead of a list of character indices, it's a list of character index probability vectors. 158 | # seq2seq.sequence_loss_by_example will do the work of generating losses by comparing the one-hot vectors 159 | # implicitly represented by the target characters against the probability distrutions in logits. 160 | # It returns a 1D float tensor (a vector) where item i is the log-perplexity of 161 | # the comparison of the ith logit distribution to the ith one-hot target vector. 162 | loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example([self.logits], # logits: 1-item list of 2D Tensors of shape [batch_size x vocab_size] 163 | [tf.reshape(self.targets, [-1])], # targets: 1-item list of 1D batch-sized int32 Tensors of the same length as logits 164 | [tf.ones([args.batch_size * args.seq_length])], # weights: 1-item list of 1D batch-sized float-Tensors of the same length as logits 165 | args.vocab_size) # num_decoder_symbols: integer, number of decoder symbols (output classes) 166 | # Cost is the arithmetic mean of the values of the loss tensor 167 | # (the sum divided by the total number of elements). 168 | # It is a single-element floating point tensor. This is what the optimizer seeks to minimize. 169 | self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length 170 | # Create a summary for our cost. 171 | # tf.scalar_summary("cost", self.cost) 172 | tf.summary.scalar("cost", self.cost) 173 | # Create a node to track the learning rate as it decays through the epochs. 174 | self.lr = tf.Variable(args.learning_rate, trainable=False) 175 | self.global_epoch_fraction = tf.Variable(0.0, trainable=False) 176 | self.global_seconds_elapsed = tf.Variable(0.0, trainable=False) 177 | tvars = tf.trainable_variables() # tvars is a python list of all trainable TF Variable objects. 178 | 179 | # tf.gradients returns a list of tensors of length len(tvars) where each tensor is sum(dy/dx). 180 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), 181 | args.grad_clip) 182 | optimizer = tf.train.AdamOptimizer(self.lr) # Use ADAM optimizer with the current learning rate. 183 | # Zip creates a list of tuples, where each tuple is (variable tensor, gradient tensor). 184 | # Training op nudges the variables along the gradient, with the given learning rate, using the ADAM optimizer. 
185 | # This is the op that a training session should be instructed to perform. 186 | self.train_op = optimizer.apply_gradients(zip(grads, tvars)) 187 | # self.summary_op = tf.merge_all_summaries() 188 | self.summary_op = tf.summary.merge_all() 189 | 190 | def save_variables_list(self): 191 | # Return a list of the trainable variables created within the rnnlm model. 192 | # This consists of the two projection softmax variables (softmax_w and softmax_b), 193 | # embedding, and all of the weights and biases in the MultiRNNCell model. 194 | # Save only the trainable variables and the placeholders needed to resume training; 195 | # discard the rest, including optimizer state. 196 | save_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='rnnlm') 197 | save_vars.append(self.lr) 198 | save_vars.append(self.global_epoch_fraction) 199 | save_vars.append(self.global_seconds_elapsed) 200 | return save_vars 201 | 202 | def forward_model(self, sess, state, input_sample): 203 | '''Run a forward pass. Return the updated hidden state and the output probabilities.''' 204 | shaped_input = np.array([[input_sample]], np.float32) 205 | inputs = {self.input_data: shaped_input, 206 | self.initial_state: state} 207 | [probs, state] = sess.run([self.probs, self.final_state], feed_dict=inputs) 208 | return probs[0], state 209 | -------------------------------------------------------------------------------- /models/models-folder-readme.txt: -------------------------------------------------------------------------------- 1 | Place folders containing downloaded models in this directory. 2 | 3 | You can download my own pre-trained model here: https://drive.google.com/uc?export=download&id=0B6noVJLTV1jCT29uMzliMVVRWWM -------------------------------------------------------------------------------- /reddit-parse/parser_config_standard.json: -------------------------------------------------------------------------------- 1 | { 2 | "subreddit_whitelist": [], 3 | "subreddit_blacklist": [ 4 | "sports", 5 | "announcements", 6 | "blog", 7 | "gaming", 8 | "CasualPokemonTrades", 9 | "pokemontrades", 10 | "YamakuHighSchool", 11 | "XMenRP" 12 | ], 13 | "substring_blacklist": [ 14 | "[", 15 | "http://", 16 | "https://", 17 | " r/", 18 | " u/", 19 | "/r/", 20 | "/u/", 21 | "reddit", 22 | "Reddit", 23 | "upvote", 24 | "downvote", 25 | "upvoting", 26 | "downvoting", 27 | "OOC:" 28 | ] 29 | } -------------------------------------------------------------------------------- /reddit-parse/reddit_parse.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from bz2 import BZ2File 3 | import argparse 4 | import os 5 | import json 6 | import re 7 | import sys 8 | 9 | FILE_SUFFIX = ".bz2" 10 | OUTPUT_FILE = "output.bz2" 11 | REPORT_FILE = "RC_report.txt" 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--input_file', type=str, default='reddit_data', 16 | help='data file or directory containing bz2 archive of json reddit data') 17 | parser.add_argument('--logdir', type=str, default='standard', 18 | help='directory to save the output and report') 19 | parser.add_argument('--config_file', type=str, default='parser_config_standard.json', 20 | help='json parameters for parsing') 21 | parser.add_argument('--comment_cache_size', type=int, default=1e7, 22 | help='max number of comments to cache in memory before flushing') 23 | parser.add_argument('--output_file_size', type=int, default=2e8, 24 | help='max number of 
comments to cache in memory before flushing') 25 | parser.add_argument('--print_every', type=int, default=100, 26 | help='print an update to the screen this often') 27 | args = parser.parse_args() 28 | parse_main(args) 29 | 30 | class RedditComment(object): 31 | def __init__(self, json_object): 32 | self.body = json_object['body'] 33 | self.score = json_object['ups'] - json_object['downs'] 34 | self.author = json_object['author'] 35 | self.parent_id = json_object['parent_id'] 36 | self.child_id = None 37 | 38 | def parse_main(args): 39 | if not os.path.isfile(args.config_file): 40 | print("File not found: {}".format(args.input_file)) 41 | return 42 | with open(args.config_file, 'r') as f: 43 | config = json.load(f) 44 | subreddit_blacklist = set(config['subreddit_blacklist']) 45 | subreddit_whitelist = set(config['subreddit_whitelist']) 46 | substring_blacklist = set(config['substring_blacklist']) 47 | 48 | if not os.path.exists(args.input_file): 49 | print("File not found: {}".format(args.input_file)) 50 | return 51 | if os.path.isfile(args.logdir): 52 | print("File already exists at output directory location: {}".format(args.logdir)) 53 | return 54 | if not os.path.exists(args.logdir): 55 | os.mkdir(args.logdir) 56 | subreddit_dict = {} 57 | comment_dict = {} 58 | cache_count = 0 59 | raw_data = raw_data_generator(args.input_file) 60 | output_handler = OutputHandler(os.path.join(args.logdir, OUTPUT_FILE), args.output_file_size) 61 | for i, line in enumerate(raw_data): 62 | if len(line) > 1 and (line[-1] == '}' or line[-2] == '}'): 63 | comment = json.loads(line) 64 | if post_qualifies(comment, subreddit_blacklist, 65 | subreddit_whitelist, substring_blacklist): 66 | sub = comment['subreddit'] 67 | if sub in subreddit_dict: 68 | subreddit_dict[sub] += 1 69 | else: subreddit_dict[sub] = 1 70 | comment_dict[comment['name']] = RedditComment(comment) 71 | cache_count += 1 72 | if cache_count % args.print_every == 0: 73 | print("\rCached {} comments".format(cache_count), end='') 74 | sys.stdout.flush() 75 | if cache_count > args.comment_cache_size: 76 | print() 77 | process_comment_cache(comment_dict, args.print_every) 78 | write_comment_cache(comment_dict, output_handler, args.print_every) 79 | write_report(os.path.join(args.logdir, REPORT_FILE), subreddit_dict) 80 | comment_dict.clear() 81 | cache_count = 0 82 | print("\nRead all {} lines from {}.".format(i, args.input_file)) 83 | process_comment_cache(comment_dict, args.print_every) 84 | write_comment_cache(comment_dict, output_handler, args.print_every) 85 | write_report(os.path.join(args.logdir, REPORT_FILE), subreddit_dict) 86 | 87 | def raw_data_generator(path): 88 | if os.path.isdir(path): 89 | for walk_root, walk_dir, walk_files in os.walk(path): 90 | for file_name in walk_files: 91 | file_path = os.path.join(walk_root, file_name) 92 | if file_path.endswith(FILE_SUFFIX): 93 | print("\nReading from {}".format(file_path)) 94 | with BZ2File(file_path, "r") as raw_data: 95 | try: 96 | for line in raw_data: yield line 97 | except IOError: 98 | print("IOError from file {}".format(file_path)) 99 | continue 100 | else: print("Skipping file {} (doesn't end with {})".format(file_path, FILE_SUFFIX)) 101 | elif os.path.isfile(path): 102 | print("Reading from {}".format(path)) 103 | with BZ2File(path, "r") as raw_data: 104 | for line in raw_data: yield line 105 | 106 | class OutputHandler(): 107 | def __init__(self, path, output_file_size): 108 | if path.endswith(FILE_SUFFIX): 109 | path = path[:-len(FILE_SUFFIX)] 110 | self.base_path = path 111 | 
self.output_file_size = output_file_size 112 | self.file_reference = None 113 | 114 | def write(self, data): 115 | if self.file_reference is None: 116 | self._get_current_path() 117 | self.file_reference.write(data) 118 | self.current_file_size += len(data) 119 | if self.current_file_size >= self.output_file_size: 120 | self.file_reference.close() 121 | self.file_reference = None 122 | 123 | def _get_current_path(self): 124 | i = 1 125 | while True: 126 | path = "{} {}{}".format(self.base_path, i, FILE_SUFFIX) 127 | if not os.path.exists(path): break 128 | i += 1 129 | self.current_path = path 130 | self.current_file_size = 0 131 | self.file_reference = BZ2File(self.current_path, "w") 132 | 133 | def post_qualifies(json_object, subreddit_blacklist, 134 | subreddit_whitelist, substring_blacklist): 135 | body = json_object['body'].encode('ascii', 'ignore').strip() 136 | post_length = len(body) 137 | if post_length < 4 or post_length > 200: return False 138 | subreddit = json_object['subreddit'] 139 | if len(subreddit_whitelist) > 0 and subreddit not in subreddit_whitelist: return False 140 | if len(subreddit_blacklist) > 0 and subreddit in subreddit_blacklist: return False 141 | if len(substring_blacklist) > 0: 142 | for substring in substring_blacklist: 143 | if body.find(substring) >= 0: return False 144 | # Preprocess the comment text. 145 | body = re.sub('[ \t\n]+', ' ', body) # Replace runs of whitespace with a single space. 146 | body = re.sub('\^', '', body) # Strip out carets. 147 | body = re.sub('\\\\', '', body) # Strip out backslashes. 148 | body = re.sub('<', '<', body) # Replace '<' with '<' 149 | body = re.sub('>', '>', body) # Replace '>' with '>' 150 | body = re.sub('&', '&', body) # Replace '&' with '&' 151 | post_length = len(body) 152 | if post_length < 4 or post_length > 200: return False 153 | json_object['body'] = body # Save our changes 154 | return True 155 | 156 | def process_comment_cache(comment_dict, print_every): 157 | i = 0 158 | for my_id, my_comment in comment_dict.items(): 159 | i += 1 160 | if i % print_every == 0: 161 | print("\rProcessed {} comments".format(i), end='') 162 | sys.stdout.flush() 163 | if my_comment.parent_id is not None: # If we're not a top-level post... 164 | if my_comment.parent_id in comment_dict: # ...and the parent is in our data set... 165 | parent = comment_dict[my_comment.parent_id] 166 | if parent.child_id is None: # If my parent doesn't already have a child, adopt me! 167 | parent.child_id = my_id 168 | else: # My parent already has a child. 169 | parent_previous_child = comment_dict[parent.child_id] 170 | if parent.parent_id in comment_dict: # If my grandparent is in our data set... 171 | grandparent = comment_dict[parent.parent_id] 172 | if my_comment.author == grandparent.author: 173 | # If I share an author with grandparent, adopt me! 174 | parent.child_id = my_id 175 | elif (parent_previous_child.author != grandparent.author 176 | and my_comment.score > parent_previous_child.score): 177 | # If the existing child doesn't share an author with grandparent, 178 | # higher score prevails. 179 | parent.child_id = my_id 180 | elif my_comment.score > parent_previous_child.score: 181 | # If there's no grandparent, the higher-score child prevails. 182 | parent.child_id = my_id 183 | else: 184 | # Parent IDs that aren't in the data set get de-referenced. 
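                # (With parent_id cleared, this comment can serve as the root of a new
                #  conversation chain when the cache is written out below.)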
185 | my_comment.parent_id = None 186 | print() 187 | 188 | def write_comment_cache(comment_dict, output_file, print_every): 189 | i = 0 190 | prev_print_count = 0 191 | for k, v in comment_dict.items(): 192 | if v.parent_id is None and v.child_id is not None: 193 | comment = v 194 | depth = 0 195 | output_string = "" 196 | while comment is not None: 197 | depth += 1 198 | output_string += '> ' + comment.body + '\n' 199 | if comment.child_id in comment_dict: 200 | comment = comment_dict[comment.child_id] 201 | else: 202 | comment = None 203 | if depth > 3: 204 | output_file.write(output_string + '\n') 205 | i += depth 206 | if i > prev_print_count + print_every: 207 | prev_print_count = i 208 | print("\rWrote {} comments".format(i), end='') 209 | sys.stdout.flush() 210 | print() 211 | 212 | def write_report(report_file_path, subreddit_dict): 213 | print("Updating subreddit report file") 214 | subreddit_list = sorted(subreddit_dict.items(), key=lambda x: -x[1]) 215 | with open(report_file_path, "w") as f: 216 | for item in subreddit_list: 217 | f.write("{}: {}\n".format(*item)) 218 | 219 | if __name__ == '__main__': 220 | main() 221 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | import argparse 5 | import time 6 | import os 7 | import cPickle 8 | 9 | from utils import TextLoader 10 | from model import Model 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--data_dir', type=str, default='data/scotus', 15 | help='data directory containing input.txt') 16 | parser.add_argument('--save_dir', type=str, default='models/new_save', 17 | help='directory for checkpointed models (load from here if one is already present)') 18 | parser.add_argument('--rnn_size', type=int, default=1500, 19 | help='size of RNN hidden state') 20 | parser.add_argument('--num_layers', type=int, default=4, 21 | help='number of layers in the RNN') 22 | parser.add_argument('--model', type=str, default='gru', 23 | help='rnn, gru, lstm or nas') 24 | parser.add_argument('--batch_size', type=int, default=40, 25 | help='minibatch size') 26 | parser.add_argument('--seq_length', type=int, default=50, 27 | help='RNN sequence length') 28 | parser.add_argument('--num_epochs', type=int, default=50, 29 | help='number of epochs') 30 | parser.add_argument('--save_every', type=int, default=1000, 31 | help='save frequency') 32 | parser.add_argument('--grad_clip', type=float, default=5., 33 | help='clip gradients at this value') 34 | parser.add_argument('--learning_rate', type=float, default=6e-5, 35 | help='learning rate') 36 | parser.add_argument('--decay_rate', type=float, default=0.95, 37 | help='how much to decay the learning rate') 38 | parser.add_argument('--decay_steps', type=int, default=100000, 39 | help='how often to decay the learning rate') 40 | args = parser.parse_args() 41 | train(args) 42 | 43 | def train(args): 44 | # Create the data_loader object, which loads up all of our batches, vocab dictionary, etc. 45 | # from utils.py (and creates them if they don't already exist). 46 | # These files go in the data directory. 
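    # The loader exposes vocab_size, total_batch_count and next_batch(), all of which are used below.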
47 | data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length) 48 | args.vocab_size = data_loader.vocab_size 49 | 50 | load_model = False 51 | if not os.path.exists(args.save_dir): 52 | print("Creating directory %s" % args.save_dir) 53 | os.mkdir(args.save_dir) 54 | elif (os.path.exists(os.path.join(args.save_dir, 'config.pkl'))): 55 | # Trained model already exists 56 | ckpt = tf.train.get_checkpoint_state(args.save_dir) 57 | if ckpt and ckpt.model_checkpoint_path: 58 | with open(os.path.join(args.save_dir, 'config.pkl')) as f: 59 | saved_args = cPickle.load(f) 60 | args.rnn_size = saved_args.rnn_size 61 | args.num_layers = saved_args.num_layers 62 | args.model = saved_args.model 63 | print("Found a previous checkpoint. Overwriting model description arguments to:") 64 | print(" model: {}, rnn_size: {}, num_layers: {}".format( 65 | saved_args.model, saved_args.rnn_size, saved_args.num_layers)) 66 | load_model = True 67 | 68 | # Save all arguments to config.pkl in the save directory -- NOT the data directory. 69 | with open(os.path.join(args.save_dir, 'config.pkl'), 'w') as f: 70 | cPickle.dump(args, f) 71 | # Save a tuple of the characters list and the vocab dictionary to chars_vocab.pkl in 72 | # the save directory -- NOT the data directory. 73 | with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'w') as f: 74 | cPickle.dump((data_loader.chars, data_loader.vocab), f) 75 | 76 | # Create the model! 77 | print("Building the model") 78 | model = Model(args) 79 | 80 | config = tf.ConfigProto(log_device_placement=False) 81 | config.gpu_options.allow_growth = True 82 | with tf.Session(config=config) as sess: 83 | # tf.initialize_all_variables().run() 84 | tf.global_variables_initializer().run() 85 | tf.local_variables_initializer().run() 86 | saver = tf.train.Saver(model.save_variables_list()) 87 | if (load_model): 88 | print("Loading saved parameters") 89 | saver.restore(sess, ckpt.model_checkpoint_path) 90 | global_epoch_fraction = sess.run(model.global_epoch_fraction) 91 | global_seconds_elapsed = sess.run(model.global_seconds_elapsed) 92 | if load_model: print("Resuming from global epoch fraction {:.3f}," 93 | " total trained time: {}, learning rate: {}".format( 94 | global_epoch_fraction, global_seconds_elapsed, sess.run(model.lr))) 95 | data_loader.cue_batch_pointer_to_epoch_fraction(global_epoch_fraction) 96 | initial_batch_step = int((global_epoch_fraction 97 | - int(global_epoch_fraction)) * data_loader.total_batch_count) 98 | epoch_range = (int(global_epoch_fraction), 99 | args.num_epochs + int(global_epoch_fraction)) 100 | # writer = tf.train.SummaryWriter(args.save_dir, graph=tf.get_default_graph()) 101 | writer = tf.summary.FileWriter(args.save_dir, graph=tf.get_default_graph()) 102 | outputs = [model.cost, model.final_state, model.train_op, model.summary_op] 103 | is_lstm = args.model == 'lstm' 104 | global_step = epoch_range[0] * data_loader.total_batch_count + initial_batch_step 105 | try: 106 | for e in range(*epoch_range): 107 | # e iterates through the training epochs. 108 | # Reset the model state, so it does not carry over from the end of the previous epoch. 109 | state = sess.run(model.initial_state) 110 | batch_range = (initial_batch_step, data_loader.total_batch_count) 111 | initial_batch_step = 0 112 | for b in range(*batch_range): 113 | global_step += 1 114 | if global_step % args.decay_steps == 0: 115 | # Set the model.lr element of the model to track 116 | # the appropriately decayed learning rate. 
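                    # (Every decay_steps global steps, the stored rate is multiplied by decay_rate,
                    #  giving a stepwise exponential decay over the course of training.)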
117 | current_learning_rate = sess.run(model.lr) 118 | current_learning_rate *= args.decay_rate 119 | sess.run(tf.assign(model.lr, current_learning_rate)) 120 | print("Decayed learning rate to {}".format(current_learning_rate)) 121 | start = time.time() 122 | # Pull the next batch inputs (x) and targets (y) from the data loader. 123 | x, y = data_loader.next_batch() 124 | 125 | # feed is a dictionary of variable references and respective values for initialization. 126 | # Initialize the model's input data and target data from the batch, 127 | # and initialize the model state to the final state from the previous batch, so that 128 | # model state is accumulated and carried over between batches. 129 | feed = {model.input_data: x, model.targets: y} 130 | if is_lstm: 131 | for i, (c, h) in enumerate(model.initial_state): 132 | feed[c] = state[i].c 133 | feed[h] = state[i].h 134 | else: 135 | for i, c in enumerate(model.initial_state): 136 | feed[c] = state[i] 137 | # Run the session! Specifically, tell TensorFlow to compute the graph to calculate 138 | # the values of cost, final state, and the training op. 139 | # Cost is used to monitor progress. 140 | # Final state is used to carry over the state into the next batch. 141 | # Training op is not used, but we want it to be calculated, since that calculation 142 | # is what updates parameter states (i.e. that is where the training happens). 143 | train_loss, state, _, summary = sess.run(outputs, feed) 144 | elapsed = time.time() - start 145 | global_seconds_elapsed += elapsed 146 | writer.add_summary(summary, e * batch_range[1] + b + 1) 147 | print "{}/{} (epoch {}/{}), loss = {:.3f}, time/batch = {:.3f}s" \ 148 | .format(b, batch_range[1], e, epoch_range[1], train_loss, elapsed) 149 | # Every save_every batches, save the model to disk. 150 | # By default, only the five most recent checkpoint files are kept. 151 | if (e * batch_range[1] + b + 1) % args.save_every == 0 \ 152 | or (e == epoch_range[1] - 1 and b == batch_range[1] - 1): 153 | save_model(sess, saver, model, args.save_dir, global_step, 154 | data_loader.total_batch_count, global_seconds_elapsed) 155 | except KeyboardInterrupt: 156 | # Introduce a line break after ^C is displayed so save message 157 | # is on its own line. 158 | print() 159 | finally: 160 | writer.flush() 161 | global_step = e * data_loader.total_batch_count + b 162 | save_model(sess, saver, model, args.save_dir, global_step, 163 | data_loader.total_batch_count, global_seconds_elapsed) 164 | 165 | def save_model(sess, saver, model, save_dir, global_step, steps_per_epoch, global_seconds_elapsed): 166 | global_epoch_fraction = float(global_step) / float(steps_per_epoch) 167 | checkpoint_path = os.path.join(save_dir, 'model.ckpt') 168 | print "Saving model to {} (epoch fraction {:.3f})".format(checkpoint_path, global_epoch_fraction) 169 | sess.run(tf.assign(model.global_epoch_fraction, global_epoch_fraction)) 170 | sess.run(tf.assign(model.global_seconds_elapsed, global_seconds_elapsed)) 171 | saver.save(sess, checkpoint_path, global_step = global_step) 172 | print "Model saved." 173 | 174 | if __name__ == '__main__': 175 | main() 176 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import os 3 | import io 4 | import collections 5 | import cPickle 6 | from bz2 import BZ2File 7 | import numpy as np 8 | 9 | class TextLoader(): 10 | # Call this class to load text from a file. 
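    # It scans data_dir for .txt/.bz2 input files, builds a character-level vocabulary,
    # caches each file as a numpy tensor of character indices (data{i}.npz), and serves
    # (input, target) batches to the training loop.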
11 |     def __init__(self, data_dir, batch_size, seq_length, encoding='utf-8'):
12 |         # TextLoader remembers its initialization arguments.
13 |         self.data_dir = data_dir
14 |         self.batch_size = batch_size
15 |         self.seq_length = seq_length
16 |         self.encoding = encoding
17 |         self.tensor_sizes = []
18 | 
19 |         self.tensor_file_template = os.path.join(data_dir, "data{}.npz")
20 |         vocab_file = os.path.join(data_dir, "vocab.pkl")
21 |         sizes_file = os.path.join(data_dir, "sizes.pkl")
22 | 
23 |         self.input_files = self._get_input_file_list(data_dir)
24 |         self.input_file_count = len(self.input_files)
25 | 
26 |         if self.input_file_count < 1:
27 |             raise ValueError("Input files not found. File names must end in '.txt' or '.bz2'.")
28 | 
29 |         if self._preprocess_required(vocab_file, sizes_file, self.tensor_file_template, self.input_file_count):
30 |             # If the vocab file, sizes file, or any tensor file doesn't already exist, create them.
31 |             print("Preprocessing the following files: {}".format(self.input_files))
32 |             vocab_counter = collections.Counter()
33 |             for i in xrange(self.input_file_count):
34 |                 print("reading vocab from input file {}".format(self.input_files[i]))
35 |                 self._augment_vocab(vocab_counter, self.input_files[i])
36 |             print("saving vocab file")
37 |             self._save_vocab(vocab_counter, vocab_file)
38 | 
39 |             for i in xrange(self.input_file_count):
40 |                 print("preprocessing input file {}".format(self.input_files[i]))
41 |                 self._preprocess(self.input_files[i], self.tensor_file_template.format(i))
42 |                 self.tensor_sizes.append(self.tensor.size)
43 | 
44 |             with open(sizes_file, 'wb') as f:
45 |                 cPickle.dump(self.tensor_sizes, f)
46 | 
47 |             print("processed {} input files: {} total characters loaded".format(self.input_file_count, sum(self.tensor_sizes)))
48 |         else:
49 |             # If the vocab file and sizes file already exist, load them.
50 |             print("loading vocab file")
51 |             self._load_vocab(vocab_file)
52 |             print("loading sizes file")
53 |             with open(sizes_file, 'rb') as f:
54 |                 self.tensor_sizes = cPickle.load(f)
55 |         self.tensor_batch_counts = [n // (self.batch_size * self.seq_length) for n in self.tensor_sizes]
56 |         self.total_batch_count = sum(self.tensor_batch_counts)
57 |         print("total batch count: {}".format(self.total_batch_count))
58 | 
59 |         self.tensor_index = -1
60 | 
61 |     def _preprocess_required(self, vocab_file, sizes_file, tensor_file_template, input_file_count):
62 |         if not os.path.exists(vocab_file):
63 |             print("No vocab file found. Preprocessing...")
64 |             return True
65 |         if not os.path.exists(sizes_file):
66 |             print("No sizes file found. Preprocessing...")
67 |             return True
68 |         for i in xrange(input_file_count):
69 |             if not os.path.exists(tensor_file_template.format(i)):
70 |                 print("Couldn't find {}. Preprocessing...".format(tensor_file_template.format(i)))
71 |                 return True
72 |         return False
73 | 
74 |     def _get_input_file_list(self, data_dir):
75 |         suffixes = ['.txt', '.bz2']
76 |         input_file_list = []
77 |         if os.path.isdir(data_dir):
78 |             for walk_root, walk_dir, walk_files in os.walk(data_dir):
79 |                 for file_name in walk_files:
80 |                     if file_name.startswith("."): continue
81 |                     file_path = os.path.join(walk_root, file_name)
82 |                     if file_path.endswith(suffixes[0]) or file_path.endswith(suffixes[1]):
83 |                         input_file_list.append(file_path)
84 |         else: raise ValueError("Not a directory: {}".format(data_dir))
85 |         return sorted(input_file_list)
86 | 
87 |     def _augment_vocab(self, vocab_counter, input_file):
88 |         # Read the given input file and add its character counts to the running
89 |         # totals in vocab_counter.
90 |         if input_file.endswith(".bz2"): file_reference = BZ2File(input_file, "r")
91 |         elif input_file.endswith(".txt"): file_reference = io.open(input_file, "r")
92 |         raw_data = file_reference.read()
93 |         file_reference.close()
94 |         u_data = raw_data.encode(encoding=self.encoding)
95 |         vocab_counter.update(u_data)
96 | 
97 |     def _save_vocab(self, vocab_counter, vocab_file):
98 |         # count_pairs is a list of the (character, count) pairs in vocab_counter, sorted in descending order.
99 |         # The first item of the list is a 2-item tuple of the most common character
100 |         # and the number of times it occurs, then the second-most common, etc. -- e.g.:
101 |         # [(' ', 17), ('a', 11), ('e', 7), ('n', 7), ...]
102 |         count_pairs = sorted(vocab_counter.items(), key=lambda x: -x[1])
103 |         # self.chars is a tuple (immutable ordered list) of characters, in descending order
104 |         # from most common to least. E.g.:
105 |         # (' ', 'a', 'e', 'n', 't', ...)
106 |         # This is a lookup device to convert index number to character.
107 |         # How does this work?
108 |         # zip(*___) returns a sequence of tuples, where the i-th tuple contains
109 |         # the i-th element from each of the argument sequences or iterables.
110 |         # So zip(*count_pairs) returns two tuples, the first tuple being
111 |         # characters in descending order of frequency, and the second being the frequency
112 |         # of the same characters.
113 |         # list() then packages these two tuples into a list of the same two tuples,
114 |         # and the assignment passes the first tuple (characters in descending order) to self.chars
115 |         # and the second (character counts) to a disregarded variable.
116 |         self.chars, _ = list(zip(*count_pairs))
117 |         # self.vocab_size counts the number of distinct characters in the training data.
118 |         self.vocab_size = len(self.chars)
119 |         # self.vocab is a dictionary that maps each character to its index number. For example:
120 |         # {' ': 0, 'a': 1, 'e': 2, 'n': 3, ...}
121 |         # This is a lookup device to convert a character to its index number.
122 |         self.vocab = dict(zip(self.chars, range(len(self.chars))))
123 |         # Save the characters tuple to vocab.pkl (tiny file).
124 |         with open(vocab_file, 'wb') as f:
125 |             cPickle.dump(self.chars, f)
126 |         print("saved vocab (vocab size: {})".format(self.vocab_size))
127 | 
128 |     def _load_vocab(self, vocab_file):
129 |         # Load the character tuple (vocab.pkl) to self.chars.
130 |         # Remember that it is in descending order of character frequency in the data.
131 |         with open(vocab_file, 'rb') as f:
132 |             self.chars = cPickle.load(f)
133 |         # Use the character tuple to regenerate vocab_size and the vocab dictionary.
134 |         self.vocab_size = len(self.chars)
135 |         self.vocab = dict(zip(self.chars, range(len(self.chars))))
136 | 
137 |     def _preprocess(self, input_file, tensor_file):
138 |         if input_file.endswith(".bz2"): file_reference = BZ2File(input_file, "r")
139 |         elif input_file.endswith(".txt"): file_reference = io.open(input_file, "r")
140 |         raw_data = file_reference.read()
141 |         file_reference.close()
142 |         data = raw_data.encode(encoding=self.encoding)
143 |         # Convert the entirety of the data file from characters to indices via the vocab dictionary.
144 |         # How? map(function, iterable) returns a list of the output of the function
145 |         # executed on each member of the iterable. E.g.:
146 |         # [14, 2, 9, 2, 0, 6, 7, 0, ...]
147 |         # np.array converts the list into a numpy array.
148 |         self.tensor = np.array(map(self.vocab.get, data))
149 |         # Compress and save the numpy tensor array to data.npz.
150 |         np.savez_compressed(tensor_file, tensor_data=self.tensor)
151 | 
152 |     def _load_preprocessed(self, tensor_index):
153 |         self.reset_batch_pointer()
154 |         if tensor_index == self.tensor_index:
155 |             return
156 |         print("loading tensor data file {}".format(tensor_index))
157 |         tensor_file = self.tensor_file_template.format(tensor_index)
158 |         # Load the data tensor file to self.tensor.
159 |         with np.load(tensor_file) as loaded:
160 |             self.tensor = loaded['tensor_data']
161 |         self.tensor_index = tensor_index
162 |         # Calculate the number of batches in the data. Each batch is batch_size x seq_length,
163 |         # so this is just the input data size divided by that product, rounded down.
164 |         self.num_batches = self.tensor.size // (self.batch_size * self.seq_length)
165 |         if self.tensor_batch_counts[tensor_index] != self.num_batches:
166 |             print("Error in batch count! Expected {}; found {}".format(self.tensor_batch_counts[tensor_index],
167 |                 self.num_batches))
168 |         # Chop off the end of the data tensor so that the length of the data is a whole
169 |         # multiple of the (batch_size x seq_length) product.
170 |         # Do this with the slice operator on the numpy array.
171 |         self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length]
172 |         # Construct two numpy arrays to represent input characters (xdata)
173 |         # and target characters (ydata).
174 |         # In training, we will feed in input characters one at a time, and optimize along
175 |         # a loss function computed against the target characters.
176 |         # (We do this with batch_size characters at a time, in parallel.)
177 |         # Since this is a sequence prediction net, the target is just the input shifted
178 |         # forward by one character (each target is the next input character).
179 |         xdata = self.tensor
180 |         ydata = np.copy(self.tensor)  # Y-data starts as a copy of x-data.
181 |         ydata[:-1] = xdata[1:]  # Shift y-data forward by 1 using the numpy array slice syntax.
182 |         # Replace the very last character of y-data with the first character of the input data.
183 |         ydata[-1] = xdata[0]
184 |         # Split our one-dimensional data array into distinct batches.
185 |         # How? xdata.reshape(self.batch_size, -1) returns a 2D numpy tensor view
186 |         # with batch_size rows, in which each row holds num_batches x seq_length
187 |         # consecutive characters, so the second dimension is the index of the character
188 |         # within the row (from 0 to (num_batches x seq_length)).
189 |         # Within each row, characters follow the same sequence as in the input data.
190 |         # Then, np.split(that 2D numpy tensor, num_batches, 1) gives a list of numpy arrays.
191 |         # Say batch_size = 4, seq_length = 5, and data is the following string:
192 |         # "Here is a new string named data. It is a new string named data. It is named data."
193 |         # We truncate the string to lop off the last period (so there are now 80 characters,
194 |         # which is evenly divisible by 4 x 5). After xdata.reshape, we have:
195 |         #
196 |         # [[Here is a new string],
197 |         #  [ named data. It is a],
198 |         #  [ new string named da],
199 |         #  [ta. It is named data]]
200 |         #
201 |         # After np.split, we have:
202 |         # <[[Here ],  <[[is a ],  <[[new s],  <[[tring],
203 |         #   [ name],    [d dat],    [a. It],    [ is a],
204 |         #   [ new ],    [strin],    [g nam],    [ed da],
205 |         #   [ta. I]]>,  [t is ]]>,  [named]]>,  [ data]]>
206 |         #
207 |         # where the first item of the list is the numpy array on the left.
208 |         # Thus x_batches is a list of numpy arrays. The first dimension of each numpy array
209 |         # is the row within the batch (from 0 to batch_size), and the second dimension is the
210 |         # character index within the sequence (from 0 to seq_length).
211 |         #
212 |         # These will be fed to the model one at a time sequentially.
213 |         # State is preserved between sequential batches.
214 |         #
215 |         self.x_batches = np.split(xdata.reshape(self.batch_size, -1), self.num_batches, 1)
216 |         self.y_batches = np.split(ydata.reshape(self.batch_size, -1), self.num_batches, 1)
217 | 
218 |     def next_batch(self):
219 |         if self.tensor_index < 0:
220 |             self._load_preprocessed(0)
221 |         if self.pointer >= self.num_batches:
222 |             self._load_preprocessed((self.tensor_index + 1) % self.input_file_count)
223 |         x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
224 |         self.pointer += 1
225 |         return x, y
226 | 
227 |     def reset_batch_pointer(self):
228 |         self.pointer = 0
229 | 
230 |     def cue_batch_pointer_to_epoch_fraction(self, epoch_fraction):
231 |         step_target = (epoch_fraction - int(epoch_fraction)) * self.total_batch_count
232 |         self._cue_batch_pointer_to_step_count(step_target)
233 | 
234 |     def _cue_batch_pointer_to_step_count(self, step_target):
235 |         for i, n in enumerate(self.tensor_batch_counts):
236 |             if step_target < n:
237 |                 break
238 |             step_target -= n
239 |         # Load the tensor file we landed in, then point at the remaining offset within it.
240 |         self._load_preprocessed(i)
241 |         self.pointer = int(step_target)
242 | 
243 | 
--------------------------------------------------------------------------------
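
Note: the batching scheme documented in the `_load_preprocessed` comments can be reproduced in isolation. The sketch below is a minimal, standalone illustration (not part of the repository) that uses the same toy values from those comments (`batch_size = 4`, `seq_length = 5`) and the same example string; variable names mirror the code above.

```python
import numpy as np

# Toy values from the comment in _load_preprocessed.
batch_size, seq_length = 4, 5
text = ("Here is a new string named data. It is a new string named data."
        " It is named data.")

# Build a character vocab and map the text to indices, as _save_vocab/_preprocess do.
chars = sorted(set(text))
vocab = {c: i for i, c in enumerate(chars)}
tensor = np.array([vocab[c] for c in text])

# Truncate to a whole multiple of batch_size * seq_length (81 -> 80 characters).
num_batches = tensor.size // (batch_size * seq_length)
tensor = tensor[:num_batches * batch_size * seq_length]

# Targets are the inputs shifted forward by one character, wrapping at the end.
xdata = tensor
ydata = np.copy(tensor)
ydata[:-1] = xdata[1:]
ydata[-1] = xdata[0]

# reshape gives batch_size rows of consecutive text; np.split along axis 1 yields
# num_batches arrays, each of shape (batch_size, seq_length).
x_batches = np.split(xdata.reshape(batch_size, -1), num_batches, 1)
y_batches = np.split(ydata.reshape(batch_size, -1), num_batches, 1)

print("{} batches of shape {}".format(len(x_batches), x_batches[0].shape))  # 4 batches of shape (4, 5)
```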
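Similarly, the resume logic behind `cue_batch_pointer_to_epoch_fraction` just converts the fractional part of a saved epoch count into a data-file index and a batch offset within that file. Here is a small self-contained sketch of that mapping; the batch counts are made up for illustration.

```python
def cue_batch_pointer(tensor_batch_counts, epoch_fraction):
    # Total number of batches across all preprocessed tensor files.
    total_batch_count = sum(tensor_batch_counts)
    # Only the fractional part of the epoch determines where to resume.
    step_target = int((epoch_fraction - int(epoch_fraction)) * total_batch_count)
    # Walk through the files until the remaining steps fall inside one of them.
    for i, n in enumerate(tensor_batch_counts):
        if step_target < n:
            return i, step_target  # (tensor file index, batch pointer within that file)
        step_target -= n
    return len(tensor_batch_counts) - 1, 0

# Example: three files holding 100, 80 and 50 batches; resuming at epoch fraction 2.6
# lands 138 batches into the epoch, i.e. batch 38 of the second file.
print(cue_batch_pointer([100, 80, 50], 2.6))  # (1, 38)
```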