├── LICENSE.txt ├── README.md ├── chatbot.py ├── data └── scotus │ └── scotus.bz2 ├── model.py ├── models └── models-folder-readme.txt ├── reddit-parse ├── parser_config_standard.json └── reddit_parse.py ├── train.py └── utils.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) Copyright (c) 2018 Brian Walker 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | of the Software, and to permit persons to whom the Software is furnished to do 8 | so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # chatbot-rnn 2 | A toy chatbot powered by deep learning and trained on data from Reddit. Built on TensorFlow v1.4.0 and Python v3.5.1. 3 | 4 | Here is a sample chat transcript (not cherry-picked). Lines that start with ">" are me; the lines without ">" are the chatbot's responses. 
5 | 6 | ``` 7 | $ python3 chatbot.py 8 | Creating model... 9 | Restoring weights... 10 | 11 | > Hi, how are you? 12 | I'm not the only one here that's having fun with you, it's all me. 13 | > Who else is with you? 14 | Haha. I'm not the one who made the comment. I'm the one who said that and didn't think that much of anything. 15 | > Oh, I see. So what's new? 16 | Nothing. I just thought it was a bit of an overreaction. It's been a while since I've read it. 17 | > It's all good 18 | It's all good. It's one of those things that makes me smile :) 19 | > That's great. This is actually my demo chat, so it's nice that you're keeping things positive! 20 | Thank you :) 21 | > You're welcome! It's a nice change from the way that you sometimes are... 22 | *blushes* Sorry, I should have clarified that I wasn't clear. I apologize for my ignorance. 23 | > Sooo, where are you from? 24 | I am from South America. Where you from? 25 | > California. Do you like it in South America? 26 | I live in Australia, but south of Florida is one of the reasons I live in Florida. I've never been there, but I've never been there. 27 | > I see, a true citizen of the world 28 | Nope, I've never heard of it. I just spent a lot of time there. 29 | > Spend a lot of time where? 30 | I've never been there before, but I've heard it's about an hour and a half away. Maybe there's some sort of party called "Southern" but that's about it. 31 | > Maybe there is, that's a great observation. 32 | Not to mention it's all about being from another country. 33 | > All right, guess I should wrap this up, nice talking with you 34 | I'm not going to argue with you. You're doing it wrong. 35 | > Some might call that arguing. Any final words? 36 | Don't do that. I don't need to. 37 | ``` 38 | 39 | ## Get Started 40 | 41 | - **Install [TensorFlow](http://www.tensorflow.org/)** for Python 3. To run a pretrained model, the CPU-only installation should suffice. 
If you want to train your own models, you'll need the GPU installation of TensorFlow (and a powerful CUDA-compatible GPU). 42 | 43 | - Clone this project to your computer. 44 | 45 | ### Run my pre-trained model 46 | 47 | - **Download [my pre-trained model](https://drive.google.com/uc?id=1rRRY-y1KdVk4UB5qhu7BjQHtfadIOmMk&export=download)** (2.3 GB). The zip file extracts into a folder named "reddit". Place that folder into the "models" directory of this project. 48 | 49 | - **Run the chatbot**. Open a terminal session and run `python3 chatbot.py`. Warning: this pre-trained model was trained on a diverse set of frequently off-color Reddit comments. It can (and eventually will) say things that are offensive, disturbing, bizarre or sexually explicit. It may insult minorities, it may call you names, it may accuse you of being a pedophile, it may try to seduce you. Please don't use the chatbot if these possibilities would distress you! 50 | 51 | Try playing around with the arguments to `chatbot.py` to obtain better samples: 52 | 53 | - **beam_width**: By default, `chatbot.py` will use beam search with a beam width of 2 to sample responses. Set this higher for more careful, more conservative (and slower) responses, or set it to 1 to disable beam search. 54 | 55 | - **temperature**: At each step, the model ascribes a certain probability to each character. Temperature can adjust the probability distribution. 1.0 is neutral (and the default), lower values increase high probability values and decrease lower probability values to make the choices more conservative, and higher values will do the reverse. Values outside of the range of 0.5-1.5 are unlikely to give coherent results. 56 | 57 | - **top-n**: At each step, zero out the probability of all possible characters except the *n* most likely. Disabled by default. 58 | 59 | - **relevance**: Two models are run in parallel: the primary model and the mask model. 
The mask model is scaled by the relevance value, and then the probabilities of the primary model are combined according to equation 9 in [Li, Jiwei, et al. "A diversity-promoting objective function for neural conversation models." arXiv preprint arXiv:1510.03055 (2015)](https://arxiv.org/abs/1510.03055). The state of the mask model is reset upon each newline character. The net effect is that the model is encouraged to choose a line of dialogue that is most relevant to the prior line of dialogue, even if a more generic response (e.g. "I don't know anything about that") may be more absolutely probable. Higher relevance values put more pressure on the model to produce relevant responses, at the cost of the coherence of the responses. Going much above 0.4 compromises the quality of the responses. Setting it to a negative value disables relevance, and this is the default, because I'm not confident that it qualitatively improves the outputs and it halves the speed of sampling. 60 | 61 | These values can also be manipulated during a chat, and the model state can be reset, without restarting the chatbot: 62 | 63 | ``` 64 | $ python3 chatbot.py 65 | Creating model... 66 | Restoring weights... 67 | 68 | > --temperature 1.3 69 | [Temperature set to 1.3] 70 | 71 | > --relevance 0.3 72 | [Relevance set to 0.3] 73 | 74 | > --relevance -1 75 | [Relevance disabled] 76 | 77 | > --topn 2 78 | [Top-n filtering set to 2] 79 | 80 | > --topn -1 81 | [Top-n filtering disabled] 82 | 83 | > --beam_width 5 84 | [Beam width set to 5] 85 | 86 | > --reset 87 | [Model state reset] 88 | ``` 89 | 90 | ### Get training data 91 | 92 | If you'd like to train your own model, you'll need training data. There are a few options here. 93 | 94 | - **Use pre-formatted Reddit training data.** This is what the pre-trained model was trained on. 95 | 96 | [Download the training data](https://drive.google.com/uc?id=1s77S7COjrb3lOnfqvXYfn7sW_x5U1_l9&export=download) (2.1 GB). Unzip the monolithic zip file. 
You'll be left with a folder named "reddit" containing 34 files named "output 1.bz2", "output 2.bz2" etc. Do not extract those individual bzip2 files. Instead, place the whole "reddit" folder that contains those files inside the `data` folder of the repo. The first time you run `train.py` on this data, it will convert the raw data into numpy tensors, compress them and save them back to disk, which will create files named "data0.npz" through "data34.npz" (as well as a "sizes.pkl" file and a "vocab.pkl" file). This will fill another ~5 GB of disk space, and will take about half an hour to finish. 97 | 98 | - **Generate your own Reddit training data.** If you would like to generate training data from raw Reddit archives, download a torrent of Reddit comments from the torrent links [listed here](https://www.reddit.com/r/datasets/comments/65o7py/updated_reddit_comment_dataset_as_torrents/). The comments are available in annual archives, and you can download any or all of them (~304 GB compressed in total). Do not extract the individual bzip2 (.bz2) files contained in these archives. 99 | 100 | Once you have your raw Reddit data, place it in the `reddit-parse/reddit_data` subdirectory and use the `reddit_parse.py` script in the project's `reddit-parse` directory to convert it into compressed text files of appropriately formatted conversations. This script chooses qualifying comments (must be under 200 characters, can't contain certain substrings such as 'http://', can't have been posted on certain subreddits) and assembles them into conversations of at least five lines. Coming up with good rules to curate conversations from raw Reddit data is more art than science. I encourage you to play around with the parameters in the included `parser_config_standard.json` file, or to mess around with the parsing script itself, to come up with an interesting data set. 101 | 102 | Please be aware that there is a *lot* of Reddit data included in the torrents.
It is very easy to run out of memory or hard drive space. I used the entire archive (~304 GB compressed), and ran the `reddit_parse.py` script with the configuration I included as the default, which holds a million comments (several GB) in memory at a time, takes about a day to run on the entire archive, and produces 2.1 GB of bzip2-compressed output. When training the model, this raw data will be converted into numpy tensors, compressed, and saved back to disk, which consumes another ~5 GB of hard drive space. I acknowledge that this may be overkill relative to the size of the model. 103 | 104 | - **Provide your own training data.** Training data should be one or more newline-delimited text files. Each line of dialogue should begin with "> " and end with a newline. You'll need a lot of it. Several megabytes of uncompressed text is probably the minimum, and even that may not suffice if you want to train a large model. Text can be provided as raw .txt files or as bzip2-compressed (.bz2) files. 105 | 106 | - **Simulate the United States Supreme Court.** I've included a corpus of United States Supreme Court oral argument transcripts (2.7 MB compressed) in the project under the `data/scotus` directory. 107 | 108 | Once you have training data in hand (and located in a subdirectory of the `data` directory): 109 | 110 | ### Train your own model 111 | 112 | - **Train.** Use `train.py` to train the model. The default hyperparameters are the best that I've found, and are what I used to train the pre-trained model for a couple of months. These hyperparameters will just about fill the memory of a GTX 1080 Ti GPU (11 GB of VRAM), so if you have a smaller GPU, you will need to adjust them accordingly (for example, set --num_blocks to 2). 113 | 114 | Training can be interrupted with Ctrl-C at any time, and will immediately save the model when interrupted. Training can be resumed on a saved model and will automatically carry on from where it was interrupted.
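For the "Provide your own training data" option described above, the following is a minimal sketch of producing a file in the stated format: one line of dialogue per line, each beginning with "> " and ending with a newline, optionally bzip2-compressed. The helper name `write_training_file`, the folder name `mychat` and the file name `input.bz2` are hypothetical and chosen only for illustration; the README does not say which file names `train.py` expects, so check `utils.py` before relying on them. The sketch writes into a temporary directory so that it can be run anywhere.

```python
import bz2
import os
import tempfile


def write_training_file(lines, path):
    """Write dialogue lines as '> '-prefixed, newline-terminated text.

    Uses bzip2 compression when the (hypothetical) path ends in '.bz2',
    plain text otherwise. Returns the text that was written.
    """
    text = "".join("> " + line.strip() + "\n" for line in lines)
    if path.endswith(".bz2"):
        with bz2.open(path, "wt", encoding="utf-8") as f:
            f.write(text)
    else:
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)
    return text


# A tiny stand-in conversation; a real corpus needs several megabytes of text.
conversation = ["Hi, how are you?", "Fine, thanks. And you?", "Can't complain."]

# Stand-in for a subdirectory of the repo's `data` directory (assumed layout).
out_dir = os.path.join(tempfile.mkdtemp(), "data", "mychat")
os.makedirs(out_dir)
out_path = os.path.join(out_dir, "input.bz2")

written = write_training_file(conversation, out_path)

# Read the file back to confirm the on-disk format.
with bz2.open(out_path, "rt", encoding="utf-8") as f:
    round_trip = f.read()
```

To use real data, the resulting folder would be copied into the repo's `data` directory, matching the requirement that training data live in a subdirectory of `data`.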
115 | 116 | ## Thanks 117 | 118 | Thanks to Andrej Karpathy for his [char-rnn](https://github.com/karpathy/char-rnn) repo, and to Sherjil Ozair for his [TensorFlow port](https://github.com/sherjilozair/char-rnn-tensorflow) of char-rnn, which this repo is based on. 119 | -------------------------------------------------------------------------------- /chatbot.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | import argparse 7 | import os 8 | import pickle 9 | import copy 10 | import sys 11 | import html 12 | 13 | from utils import TextLoader 14 | from model import Model 15 | 16 | def main(): 17 | assert sys.version_info >= (3, 3), \ 18 | "Must be run in Python 3.3 or later. You are running {}".format(sys.version) 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--save_dir', type=str, default='models/reddit', 21 | help='model directory to store checkpointed models') 22 | parser.add_argument('-n', type=int, default=500, 23 | help='number of characters to sample') 24 | parser.add_argument('--prime', type=str, default=' ', 25 | help='prime text') 26 | parser.add_argument('--beam_width', type=int, default=2, 27 | help='Width of the beam for beam search, default 2') 28 | parser.add_argument('--temperature', type=float, default=1.0, 29 | help='sampling temperature ' 30 | '(lower is more conservative, default is 1.0, which is neutral)') 31 | parser.add_argument('--topn', type=int, default=-1, 32 | help='at each step, choose from only this many most likely characters; ' 33 | 'set to <0 to disable top-n filtering.') 34 | parser.add_argument('--relevance', type=float, default=-1., 35 | help='amount of "relevance masking/MMI" (disabled by default): ' 36 | 'higher is more pressure, 0.4 is probably as high as it can go without ' 37 | 'noticeably degrading coherence; ' 38 | 'set to <0 to disable relevance masking') 39 | args =
parser.parse_args() 40 | sample_main(args) 41 | 42 | def get_paths(input_path): 43 | if os.path.isfile(input_path): 44 | # Passed a model rather than a checkpoint directory 45 | model_path = input_path 46 | save_dir = os.path.dirname(model_path) 47 | elif os.path.exists(input_path): 48 | # Passed a checkpoint directory 49 | save_dir = input_path 50 | checkpoint = tf.train.get_checkpoint_state(save_dir) 51 | if checkpoint: 52 | model_path = checkpoint.model_checkpoint_path 53 | else: 54 | raise ValueError('Checkpoint not found in {}.'.format(save_dir)) 55 | else: 56 | raise ValueError('save_dir is not a valid path.') 57 | return model_path, os.path.join(save_dir, 'config.pkl'), os.path.join(save_dir, 'chars_vocab.pkl') 58 | 59 | def sample_main(args): 60 | model_path, config_path, vocab_path = get_paths(args.save_dir) 61 | # Arguments passed to chatbot.py direct us to a saved model. 62 | # Load the separate arguments by which that model was previously trained. 63 | # That's saved_args. Use those to load the model. 64 | with open(config_path, 'rb') as f: 65 | saved_args = pickle.load(f) 66 | # Separately load chars and vocab from the save directory. 67 | with open(vocab_path, 'rb') as f: 68 | chars, vocab = pickle.load(f) 69 | # Create the model from the saved arguments, in inference mode. 70 | print("Creating model...") 71 | saved_args.batch_size = args.beam_width 72 | net = Model(saved_args, True) 73 | config = tf.ConfigProto() 74 | config.gpu_options.allow_growth = True 75 | # Make tensorflow less verbose; filter out info (1+) and warnings (2+) but not errors (3). 76 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 77 | with tf.Session(config=config) as sess: 78 | tf.global_variables_initializer().run() 79 | saver = tf.train.Saver(net.save_variables_list()) 80 | # Restore the saved variables, replacing the initialized values.
81 | print("Restoring weights...") 82 | saver.restore(sess, model_path) 83 | chatbot(net, sess, chars, vocab, args.n, args.beam_width, 84 | args.relevance, args.temperature, args.topn) 85 | 86 | def initial_state(net, sess): 87 | # Return freshly initialized model states. 88 | return sess.run(net.zero_state) 89 | 90 | def forward_text(net, sess, states, relevance, vocab, prime_text=None): 91 | if prime_text is not None: 92 | for char in prime_text: 93 | if relevance > 0.: 94 | # Automatically forward the primary net. 95 | _, states[0] = net.forward_model(sess, states[0], vocab[char]) 96 | # If the character is newline, reset the mask net state; else, forward it. 97 | # (Compare the character itself: vocab[char] is an integer token and never equals '\n'.) 97 | if char == '\n': 98 | states[1] = initial_state(net, sess) 99 | else: 100 | _, states[1] = net.forward_model(sess, states[1], vocab[char]) 101 | else: 102 | _, states = net.forward_model(sess, states, vocab[char]) 103 | return states 104 | 105 | def sanitize_text(vocab, text): # Strip out characters that are not part of the net's vocab. 
106 | return ''.join(i for i in text if i in vocab) 107 | 108 | def initial_state_with_relevance_masking(net, sess, relevance): 109 | if relevance <= 0.: return initial_state(net, sess) 110 | else: return [initial_state(net, sess), initial_state(net, sess)] 111 | 112 | def possibly_escaped_char(raw_chars): 113 | if raw_chars[-1] == ';': 114 | for i, c in enumerate(reversed(raw_chars[:-1])): 115 | if c == ';' or i > 8: 116 | return raw_chars[-1] 117 | elif c == '&': 118 | escape_seq = "".join(raw_chars[-(i + 2):]) 119 | new_seq = html.unescape(escape_seq) 120 | backspace_seq = "".join(['\b'] * (len(escape_seq)-1)) 121 | diff_length = len(escape_seq) - len(new_seq) - 1 122 | return backspace_seq + new_seq + "".join([' '] * diff_length) + "".join(['\b'] * diff_length) 123 | return raw_chars[-1] 124 | 125 | def chatbot(net, sess, chars, vocab, max_length, beam_width, relevance, temperature, topn): 126 | states = initial_state_with_relevance_masking(net, sess, relevance) 127 | while True: 128 | user_input = input('\n> ') 129 | user_command_entered, reset, states, relevance, temperature, topn, beam_width = process_user_command( 130 | user_input, states, relevance, temperature, topn, beam_width) 131 | if reset: states = initial_state_with_relevance_masking(net, sess, relevance) 132 | if not user_command_entered: 133 | states = forward_text(net, sess, states, relevance, vocab, sanitize_text(vocab, "> " + user_input + "\n>")) 134 | computer_response_generator = beam_search_generator(sess=sess, net=net, 135 | initial_state=copy.deepcopy(states), initial_sample=vocab[' '], 136 | early_term_token=vocab['\n'], beam_width=beam_width, forward_model_fn=forward_with_mask, 137 | forward_args={'relevance':relevance, 'mask_reset_token':vocab['\n'], 'forbidden_token':vocab['>'], 138 | 'temperature':temperature, 'topn':topn}) 139 | out_chars = [] 140 | for i, char_token in enumerate(computer_response_generator): 141 | out_chars.append(chars[char_token]) 142 | 
print(possibly_escaped_char(out_chars), end='', flush=True) 143 | states = forward_text(net, sess, states, relevance, vocab, chars[char_token]) 144 | if i >= max_length: break 145 | states = forward_text(net, sess, states, relevance, vocab, sanitize_text(vocab, "\n> ")) 146 | 147 | def process_user_command(user_input, states, relevance, temperature, topn, beam_width): 148 | user_command_entered = False 149 | reset = False 150 | try: 151 | if user_input.startswith('--temperature '): 152 | user_command_entered = True 153 | temperature = max(0.001, float(user_input[len('--temperature '):])) 154 | print("[Temperature set to {}]".format(temperature)) 155 | elif user_input.startswith('--relevance '): 156 | user_command_entered = True 157 | new_relevance = float(user_input[len('--relevance '):]) 158 | if relevance <= 0. and new_relevance > 0.: 159 | states = [states, copy.deepcopy(states)] 160 | elif relevance > 0. and new_relevance <= 0.: 161 | states = states[0] 162 | relevance = new_relevance 163 | print("[Relevance disabled]" if relevance <= 0. 
else "[Relevance set to {}]".format(relevance)) 164 | elif user_input.startswith('--topn '): 165 | user_command_entered = True 166 | topn = int(user_input[len('--topn '):]) 167 | print("[Top-n filtering disabled]" if topn <= 0 else "[Top-n filtering set to {}]".format(topn)) 168 | elif user_input.startswith('--beam_width '): 169 | user_command_entered = True 170 | beam_width = max(1, int(user_input[len('--beam_width '):])) 171 | print("[Beam width set to {}]".format(beam_width)) 172 | elif user_input.startswith('--reset'): 173 | user_command_entered = True 174 | reset = True 175 | print("[Model state reset]") 176 | except ValueError: 177 | print("[Value error with provided argument.]") 178 | return user_command_entered, reset, states, relevance, temperature, topn, beam_width 179 | 180 | def consensus_length(beam_outputs, early_term_token): 181 | for l in range(len(beam_outputs[0])): 182 | if l > 0 and beam_outputs[0][l-1] == early_term_token: 183 | return l-1, True 184 | for b in beam_outputs[1:]: 185 | if beam_outputs[0][l] != b[l]: return l, False 186 | return l, False 187 | 188 | def scale_prediction(prediction, temperature): 189 | if (temperature == 1.0): return prediction # Temperature 1.0 makes no change 190 | np.seterr(divide='ignore') 191 | scaled_prediction = np.log(prediction) / temperature 192 | scaled_prediction = scaled_prediction - np.logaddexp.reduce(scaled_prediction) 193 | scaled_prediction = np.exp(scaled_prediction) 194 | np.seterr(divide='warn') 195 | return scaled_prediction 196 | 197 | def forward_with_mask(sess, net, states, input_sample, forward_args): 198 | # forward_args is a dictionary containing arguments for generating probabilities. 199 | relevance = forward_args['relevance'] 200 | mask_reset_token = forward_args['mask_reset_token'] 201 | forbidden_token = forward_args['forbidden_token'] 202 | temperature = forward_args['temperature'] 203 | topn = forward_args['topn'] 204 | 205 | if relevance <= 0.: 206 | # No relevance masking. 
207 | prob, states = net.forward_model(sess, states, input_sample) 208 | else: 209 | # states should be a 2-length list: [primary net state, mask net state]. 210 | if input_sample == mask_reset_token: 211 | # Reset the mask probs when reaching mask_reset_token (newline). 212 | states[1] = initial_state(net, sess) 213 | primary_prob, states[0] = net.forward_model(sess, states[0], input_sample) 214 | primary_prob /= sum(primary_prob) 215 | mask_prob, states[1] = net.forward_model(sess, states[1], input_sample) 216 | mask_prob /= sum(mask_prob) 217 | prob = np.exp(np.log(primary_prob) - relevance * np.log(mask_prob)) 218 | # Mask out the forbidden token (">") to prevent the bot from deciding the chat is over) 219 | prob[forbidden_token] = 0 220 | # Normalize probabilities so they sum to 1. 221 | prob = prob / sum(prob) 222 | # Apply temperature. 223 | prob = scale_prediction(prob, temperature) 224 | # Apply top-n filtering if enabled 225 | if topn > 0: 226 | prob[np.argsort(prob)[:-topn]] = 0 227 | prob = prob / sum(prob) 228 | return prob, states 229 | 230 | def beam_search_generator(sess, net, initial_state, initial_sample, 231 | early_term_token, beam_width, forward_model_fn, forward_args): 232 | '''Run beam search! Yield consensus tokens sequentially, as a generator; 233 | return when reaching early_term_token (newline). 234 | 235 | Args: 236 | sess: tensorflow session reference 237 | net: tensorflow net graph (must be compatible with the forward_net function) 238 | initial_state: initial hidden state of the net 239 | initial_sample: single token (excluding any seed/priming material) 240 | to start the generation 241 | early_term_token: stop when the beam reaches consensus on this token 242 | (but do not return this token). 
243 | beam_width: how many beams to track 244 | forward_model_fn: function to forward the model, must be of the form: 245 | probability_output, beam_state = 246 | forward_model_fn(sess, net, beam_state, beam_sample, forward_args) 247 | (Note: probability_output has to be a valid probability distribution!) 248 | tot_steps: how many tokens to generate before stopping, 249 | unless already stopped via early_term_token. 250 | Returns: a generator to yield a sequence of beam-sampled tokens.''' 251 | # Store state, outputs and probabilities for up to args.beam_width beams. 252 | # Initialize with just the one starting entry; it will branch to fill the beam 253 | # in the first step. 254 | beam_states = [initial_state] # Stores the best activation states 255 | beam_outputs = [[initial_sample]] # Stores the best generated output sequences so far. 256 | beam_probs = [1.] # Stores the cumulative normalized probabilities of the beams so far. 257 | 258 | while True: 259 | # Keep a running list of the best beam branches for next step. 260 | # Don't actually copy any big data structures yet, just keep references 261 | # to existing beam state entries, and then clone them as necessary 262 | # at the end of the generation step. 263 | new_beam_indices = [] 264 | new_beam_probs = [] 265 | new_beam_samples = [] 266 | 267 | # Iterate through the beam entries. 268 | for beam_index, beam_state in enumerate(beam_states): 269 | beam_prob = beam_probs[beam_index] 270 | beam_sample = beam_outputs[beam_index][-1] 271 | 272 | # Forward the model. 273 | prediction, beam_states[beam_index] = forward_model_fn( 274 | sess, net, beam_state, beam_sample, forward_args) 275 | 276 | # Sample best_tokens from the probability distribution. 277 | # Sample from the scaled probability distribution beam_width choices 278 | # (but not more than the number of positive probabilities in scaled_prediction). 279 | count = min(beam_width, sum(1 if p > 0. 
else 0 for p in prediction)) 280 | best_tokens = np.random.choice(len(prediction), size=count, 281 | replace=False, p=prediction) 282 | for token in best_tokens: 283 | prob = prediction[token] * beam_prob 284 | if len(new_beam_indices) < beam_width: 285 | # If we don't have enough new_beam_indices, we automatically qualify. 286 | new_beam_indices.append(beam_index) 287 | new_beam_probs.append(prob) 288 | new_beam_samples.append(token) 289 | else: 290 | # Sample a low-probability beam to possibly replace. 291 | np_new_beam_probs = np.array(new_beam_probs) 292 | inverse_probs = -np_new_beam_probs + max(np_new_beam_probs) + min(np_new_beam_probs) 293 | inverse_probs = inverse_probs / sum(inverse_probs) 294 | sampled_beam_index = np.random.choice(beam_width, p=inverse_probs) 295 | if new_beam_probs[sampled_beam_index] <= prob: 296 | # Replace it. 297 | new_beam_indices[sampled_beam_index] = beam_index 298 | new_beam_probs[sampled_beam_index] = prob 299 | new_beam_samples[sampled_beam_index] = token 300 | # Replace the old states with the new states, first by referencing and then by copying. 301 | already_referenced = [False] * beam_width 302 | new_beam_states = [] 303 | new_beam_outputs = [] 304 | for i, new_index in enumerate(new_beam_indices): 305 | if already_referenced[new_index]: 306 | new_beam = copy.deepcopy(beam_states[new_index]) 307 | else: 308 | new_beam = beam_states[new_index] 309 | already_referenced[new_index] = True 310 | new_beam_states.append(new_beam) 311 | new_beam_outputs.append(beam_outputs[new_index] + [new_beam_samples[i]]) 312 | # Normalize the beam probabilities so they don't drop to zero 313 | beam_probs = new_beam_probs / sum(new_beam_probs) 314 | beam_states = new_beam_states 315 | beam_outputs = new_beam_outputs 316 | # Prune the agreed portions of the outputs 317 | # and yield the tokens on which the beam has reached consensus. 
318 | l, early_term = consensus_length(beam_outputs, early_term_token) 319 | if l > 0: 320 | for token in beam_outputs[0][:l]: yield token 321 | beam_outputs = [output[l:] for output in beam_outputs] 322 | if early_term: return 323 | 324 | if __name__ == '__main__': 325 | main() 326 | -------------------------------------------------------------------------------- /data/scotus/scotus.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pender/chatbot-rnn/82cbb710beaf15da11b60e03f00bdbfd521df754/data/scotus/scotus.bz2 -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops import rnn_cell 3 | from tensorflow.python.ops import nn_ops 4 | from tensorflow.python.ops import variable_scope as vs 5 | from tensorflow.python.framework import ops 6 | from tensorflow.contrib import rnn 7 | 8 | from tensorflow.python.util.nest import flatten 9 | 10 | import numpy as np 11 | 12 | class PartitionedMultiRNNCell(rnn_cell.RNNCell): 13 | """RNN cell composed sequentially of multiple simple cells.""" 14 | 15 | # Diagram of a PartitionedMultiRNNCell net with three layers and three partitions per layer. 16 | # Each brick shape is a partition, which comprises one RNNCell of size partition_size. 17 | # The two tilde (~) characters indicate wrapping (i.e. the two halves are a single partition). 18 | # Like laying bricks, each layer is offset by half a partition width so that influence spreads 19 | # horizontally through subsequent layers, while avoiding the quadratic resource scaling of fully 20 | # connected layers with respect to layer width. 
21 | 22 | # output 23 | # //////// \\\\\\\\ 24 | # ------------------- 25 | # | | | | 26 | # ------------------- 27 | # ~ | | | ~ 28 | # ------------------- 29 | # | | | | 30 | # ------------------- 31 | # \\\\\\\\ //////// 32 | # input 33 | 34 | 35 | def __init__(self, cell_fn, partition_size=128, partitions=1, layers=2): 36 | """Create a RNN cell composed sequentially of a number of RNNCells. 37 | Args: 38 | cell_fn: reference to RNNCell function to create each partition in each layer. 39 | partition_size: how many horizontal cells to include in each partition. 40 | partitions: how many horizontal partitions to include in each layer. 41 | layers: how many layers to include in the net. 42 | """ 43 | super(PartitionedMultiRNNCell, self).__init__() 44 | 45 | self._cells = [] 46 | for i in range(layers): 47 | self._cells.append([cell_fn(partition_size) for _ in range(partitions)]) 48 | self._partitions = partitions 49 | 50 | @property 51 | def state_size(self): 52 | # Return a 2D tuple where each row is the partition's cell size repeated `partitions` times, 53 | # and there are `layers` rows of that. 54 | return tuple(((layer[0].state_size,) * len(layer)) for layer in self._cells) 55 | 56 | @property 57 | def output_size(self): 58 | # Return the output size of each partition in the last layer times the number of partitions per layer. 59 | return self._cells[-1][0].output_size * len(self._cells[-1]) 60 | 61 | def zero_state(self, batch_size, dtype): 62 | # Return a 2D tuple of zero states matching the structure of state_size. 
63 | with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 64 | return tuple(tuple(cell.zero_state(batch_size, dtype) for cell in layer) for layer in self._cells) 65 | 66 | def call(self, inputs, state): 67 | layer_input = inputs 68 | new_states = [] 69 | for l, layer in enumerate(self._cells): 70 | # In between layers, offset the layer input by half of a partition width so that 71 | # activations can horizontally spread through subsequent layers. 72 | if l > 0: 73 | offset_width = layer[0].output_size // 2 74 | layer_input = tf.concat((layer_input[:, -offset_width:], layer_input[:, :-offset_width]), 75 | axis=1, name='concat_offset_%d' % l) 76 | # Create a tuple of inputs by splitting the lower layer output into partitions. 77 | p_inputs = tf.split(layer_input, len(layer), axis=1, name='split_%d' % l) 78 | p_outputs = [] 79 | p_states = [] 80 | for p, p_inp in enumerate(p_inputs): 81 | with vs.variable_scope("cell_%d_%d" % (l, p)): 82 | p_state = state[l][p] 83 | cell = layer[p] 84 | p_out, new_p_state = cell(p_inp, p_state) 85 | p_outputs.append(p_out) 86 | p_states.append(new_p_state) 87 | new_states.append(tuple(p_states)) 88 | layer_input = tf.concat(p_outputs, axis=1, name='concat_%d' % l) 89 | new_states = tuple(new_states) 90 | return layer_input, new_states 91 | 92 | def _rnn_state_placeholders(state): 93 | """Convert RNN state tensors to placeholders, reflecting the same nested tuple structure.""" 94 | # Adapted from @carlthome's comment: 95 | # https://github.com/tensorflow/tensorflow/issues/2838#issuecomment-302019188 96 | if isinstance(state, tf.contrib.rnn.LSTMStateTuple): 97 | c, h = state 98 | c = tf.placeholder(c.dtype, c.shape, c.op.name) 99 | h = tf.placeholder(h.dtype, h.shape, h.op.name) 100 | return tf.contrib.rnn.LSTMStateTuple(c, h) 101 | elif isinstance(state, tf.Tensor): 102 | h = state 103 | h = tf.placeholder(h.dtype, h.shape, h.op.name) 104 | return h 105 | else: 106 | structure = [_rnn_state_placeholders(x) for x 
in state] 107 | return tuple(structure) 108 | 109 | class Model(): 110 | def __init__(self, args, infer=False): # infer is set to true during sampling. 111 | self.args = args 112 | if infer: 113 | # Worry about one character at a time during sampling; no batching or BPTT. 114 | args.batch_size = 1 115 | args.seq_length = 1 116 | 117 | # Set cell_fn to the type of network cell we're creating -- RNN, GRU, LSTM or NAS. 118 | if args.model == 'rnn': 119 | cell_fn = rnn_cell.BasicRNNCell 120 | elif args.model == 'gru': 121 | cell_fn = rnn_cell.GRUCell 122 | elif args.model == 'lstm': 123 | cell_fn = rnn_cell.BasicLSTMCell 124 | elif args.model == 'nas': 125 | cell_fn = rnn.NASCell 126 | else: 127 | raise Exception("model type not supported: {}".format(args.model)) 128 | 129 | # Create variables to track training progress. 130 | self.lr = tf.Variable(args.learning_rate, name="learning_rate", trainable=False) 131 | self.global_epoch_fraction = tf.Variable(0.0, name="global_epoch_fraction", trainable=False) 132 | self.global_seconds_elapsed = tf.Variable(0.0, name="global_seconds_elapsed", trainable=False) 133 | 134 | # Call tensorflow library tensorflow-master/tensorflow/python/ops/rnn_cell 135 | # to create a layer of block_size cells of the specified basic type (RNN/GRU/LSTM). 136 | # Use the same rnn_cell library to create a stack of these cells 137 | # of num_layers layers. Pass in a python list of these cells. 138 | # cell = rnn_cell.MultiRNNCell([cell_fn(args.block_size) for _ in range(args.num_layers)]) 139 | # cell = MyMultiRNNCell([cell_fn(args.block_size) for _ in range(args.num_layers)]) 140 | cell = PartitionedMultiRNNCell(cell_fn, partitions=args.num_blocks, 141 | partition_size=args.block_size, layers=args.num_layers) 142 | 143 | # Create a TF placeholder node of 32-bit ints (NOT floats!), 144 | # of shape batch_size x seq_length. This shape matches the batches 145 | # (listed in x_batches and y_batches) constructed in create_batches in utils.py. 
146 | # input_data will receive input batches. 147 | self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length]) 148 | 149 | self.zero_state = cell.zero_state(args.batch_size, tf.float32) 150 | 151 | self.initial_state = _rnn_state_placeholders(self.zero_state) 152 | self._flattened_initial_state = flatten(self.initial_state) 153 | 154 | layer_size = args.block_size * args.num_blocks 155 | 156 | # Scope our new variables to the scope identifier string "rnnlm". 157 | with tf.variable_scope('rnnlm'): 158 | # Create new variables softmax_w and softmax_b for output. 159 | # softmax_w is a weights matrix from the top layer of the model (of size layer_size) 160 | # to the vocabulary output (of size vocab_size). 161 | softmax_w = tf.get_variable("softmax_w", [layer_size, args.vocab_size]) 162 | # softmax_b is a bias vector of the output characters (of size vocab_size). 163 | softmax_b = tf.get_variable("softmax_b", [args.vocab_size]) 164 | # Create new variable named 'embedding' to connect the character input to the base layer 165 | # of the RNN. Its role is the conceptual inverse of softmax_w. 166 | # It contains the trainable weights from the one-hot input vector to the lowest layer of RNN. 167 | embedding = tf.get_variable("embedding", [args.vocab_size, layer_size]) 168 | # Create an embedding tensor with tf.nn.embedding_lookup(embedding, self.input_data). 169 | # This tensor has dimensions batch_size x seq_length x layer_size. 170 | inputs = tf.nn.embedding_lookup(embedding, self.input_data) 171 | 172 | # TODO: Check arguments parallel_iterations (default uses more memory and less time) and 173 | # swap_memory (default uses more memory but "minimal (or no) performance penalty") 174 | outputs, self.final_state = tf.nn.dynamic_rnn(cell, inputs, 175 | initial_state=self.initial_state, scope='rnnlm') 176 | # outputs has shape [batch_size, max_time, cell.output_size] because time_major == false. 177 | # Do we need to transpose the first two dimensions? 
(Answer: no, this ruins everything.) 178 | # outputs = tf.transpose(outputs, perm=[1, 0, 2]) 179 | output = tf.reshape(outputs, [-1, layer_size]) 180 | # Obtain logits node by applying output weights and biases to the output tensor. 181 | # Logits is a tensor of shape [(batch_size * seq_length) x vocab_size]. 182 | # Recall that outputs is a 2D tensor of shape [(batch_size * seq_length) x layer_size], 183 | # and softmax_w is a 2D tensor of shape [layer_size x vocab_size]. 184 | # The matrix product is therefore a new 2D tensor of [(batch_size * seq_length) x vocab_size]. 185 | # In other words, that multiplication converts a loooong list of layer_size vectors 186 | # to a loooong list of vocab_size vectors. 187 | # Then add softmax_b (a single vocab-sized vector) to every row of that list. 188 | # That gives you the logits! 189 | self.logits = tf.matmul(output, softmax_w) + softmax_b 190 | if infer: 191 | # Convert logits to probabilities. Probs isn't used during training! That node is never calculated. 192 | # Like logits, probs is a tensor of shape [(batch_size * seq_length) x vocab_size]. 193 | # During sampling, this means it is of shape [1 x vocab_size]. 194 | self.probs = tf.nn.softmax(self.logits) 195 | else: 196 | # Create a targets placeholder of shape batch_size x seq_length. 197 | # Targets will be what output is compared against to calculate loss. 198 | self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length]) 199 | # seq2seq.sequence_loss_by_example returns 1D float Tensor containing the log-perplexity 200 | # for each sequence. (Size is batch_size * seq_length.) 201 | # Targets are reshaped from a [batch_size x seq_length] tensor to a 1D tensor, of the following layout: 202 | # target character (batch 0, seq 0) 203 | # target character (batch 0, seq 1) 204 | # ... 205 | # target character (batch 0, seq seq_len-1) 206 | # target character (batch 1, seq 0) 207 | # ... 208 | # These targets are compared to the logits to generate loss. 
209 | # Logits: instead of a list of character indices, it's a list of unnormalized per-character score vectors. 210 | # sparse_softmax_cross_entropy_with_logits (called below) does the work of generating losses by comparing the one-hot vectors 211 | # implicitly represented by the target characters against the probability distributions implied by the logits. 212 | # It returns a 1D float tensor (a vector) where item i is the cross-entropy (log-perplexity) of 213 | # the comparison of the ith logit distribution to the ith one-hot target vector. 214 | 215 | loss = nn_ops.sparse_softmax_cross_entropy_with_logits( 216 | labels=tf.reshape(self.targets, [-1]), logits=self.logits) 217 | 218 | # Cost is the arithmetic mean of the values of the loss tensor. 219 | # It is a single-element floating point tensor. This is what the optimizer seeks to minimize. 220 | self.cost = tf.reduce_mean(loss) 221 | # Create a tensorboard summary of our cost. 222 | tf.summary.scalar("cost", self.cost) 223 | 224 | tvars = tf.trainable_variables() # tvars is a python list of all trainable TF Variable objects. 225 | # tf.gradients returns a list of tensors of length len(tvars) where each tensor is sum(dy/dx). 226 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), 227 | args.grad_clip) 228 | optimizer = tf.train.AdamOptimizer(self.lr) # Use ADAM optimizer. 229 | # Zip creates an iterable of tuples, where each tuple is (gradient tensor, variable tensor). 230 | # Training op nudges the variables along the gradient, with the given learning rate, using the ADAM optimizer. 231 | # This is the op that a training session should be instructed to perform. 
232 | self.train_op = optimizer.apply_gradients(zip(grads, tvars)) 233 | #self.train_op = optimizer.minimize(self.cost) 234 | self.summary_op = tf.summary.merge_all() 235 | 236 | def add_state_to_feed_dict(self, feed_dict, state): 237 | for i, tensor in enumerate(flatten(state)): 238 | feed_dict[self._flattened_initial_state[i]] = tensor 239 | 240 | def save_variables_list(self): 241 | # Return a list of the trainable variables created within the rnnlm model. 242 | # This consists of the two projection softmax variables (softmax_w and softmax_b), 243 | # embedding, and all of the weights and biases in the MultiRNNCell model. 244 | # Save only the trainable variables and the placeholders needed to resume training; 245 | # discard the rest, including optimizer state. 246 | save_vars = set(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='rnnlm')) 247 | save_vars.update({self.lr, self.global_epoch_fraction, self.global_seconds_elapsed}) 248 | return list(save_vars) 249 | 250 | def forward_model(self, sess, state, input_sample): 251 | '''Run a forward pass. 
Return the updated hidden state and the output probabilities.''' 252 | shaped_input = np.array([[input_sample]], np.float32) 253 | inputs = {self.input_data: shaped_input} 254 | self.add_state_to_feed_dict(inputs, state) 255 | [probs, state] = sess.run([self.probs, self.final_state], feed_dict=inputs) 256 | return probs[0], state 257 | 258 | def trainable_parameter_count(self): 259 | total_parameters = 0 260 | for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='rnnlm'): 261 | shape = variable.get_shape() 262 | variable_parameters = 1 263 | for dim in shape: 264 | variable_parameters *= dim.value 265 | total_parameters += variable_parameters 266 | return total_parameters 267 | -------------------------------------------------------------------------------- /models/models-folder-readme.txt: -------------------------------------------------------------------------------- 1 | Place folders containing downloaded models in this directory. 2 | 3 | You can download my own pre-trained model (2.3 GB) here: https://drive.google.com/uc?id=1rRRY-y1KdVk4UB5qhu7BjQHtfadIOmMk&export=download -------------------------------------------------------------------------------- /reddit-parse/parser_config_standard.json: -------------------------------------------------------------------------------- 1 | { 2 | "subreddit_whitelist": [], 3 | "subreddit_blacklist": [ 4 | "sports", 5 | "hockey", 6 | "nba", 7 | "NBA2k", 8 | "nfl", 9 | "NFTL", 10 | "soccer", 11 | "CFB", 12 | "FIFA", 13 | "FifaCareers", 14 | "BostonBruins", 15 | "DetroitRedWings", 16 | "rangers", 17 | "NHLHUT", 18 | "WWE", 19 | "MaddenUltimateTeam", 20 | "NYKnicks", 21 | "CollegeBasketball", 22 | "NOLAPelicans", 23 | "BGASL", 24 | "sixers", 25 | 26 | "announcements", 27 | "blog", 28 | "gaming", 29 | 30 | "counting", 31 | 32 | "The_Donald", 33 | 34 | "pokemon", 35 | "CasualPokemonTrades", 36 | "Pokemongiveaway", 37 | "pokemontrades", 38 | "Pokemonexchange", 39 | "PokeMoonSun", 40 | "BankBallExchange", 41 | 
"pokemonduel", 42 | "SVExchange", 43 | "ClubNintendoTrade", 44 | "PokemonQRCodes", 45 | "ACTrade", 46 | "RocketLeagueExchange", 47 | "rocket_league_trading", 48 | 49 | "YamakuHighSchool", 50 | "XMenRP", 51 | "CampArcadia", 52 | "MonarchyOfEquestria", 53 | "rwbyRP", 54 | "TTPloreplaycentral", 55 | "TheDescendantsOfRome", 56 | "CampHalfBloodRP", 57 | "RWBY", 58 | "rwbyRP", 59 | "EroticRolePlay", 60 | "PercyJacksonRP", 61 | "PotterPlayRP", 62 | "HogwartsRP", 63 | "ALORP", 64 | "SupersRP", 65 | "dcrp", 66 | "BloodGulchRP", 67 | "IronThronePowers", 68 | "MassEffectPhoenix", 69 | "MigrantFleet", 70 | "TheInnBetween", 71 | "AntiHeroReborn", 72 | "AuraRP", 73 | "CrimsonShoresRP", 74 | "DarkPantheon", 75 | "darkestdungeonrp", 76 | "Devilrp", 77 | "Fairy_TailRP", 78 | "FTRP", 79 | "HeroesAcademyReborn", 80 | "HonorHillRP", 81 | "TheNarutoWorld", 82 | "randomsuperpowers", 83 | "RidersOfBerk", 84 | "SalvaticaRP", 85 | "SuperWorldRP", 86 | "TheKalenSeries", 87 | "GreekMythRP", 88 | "BullworthRP", 89 | "InfamousSecondRP", 90 | "vegasquadrantrp" 91 | ], 92 | "substring_blacklist": [ 93 | "[", 94 | "http://", 95 | "https://", 96 | " r/", 97 | " u/", 98 | "/r/", 99 | "/u/", 100 | "reddit", 101 | "Reddit", 102 | "upvot", 103 | "Upvot", 104 | "downvot", 105 | "Downvot", 106 | "username", 107 | "Username", 108 | "OOC:", 109 | ">" 110 | ] 111 | } -------------------------------------------------------------------------------- /reddit-parse/reddit_parse.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import argparse 3 | import os 4 | import json 5 | import re 6 | import sys 7 | 8 | FILE_SUFFIX = ".bz2" 9 | OUTPUT_FILE = "output.bz2" 10 | REPORT_FILE = "RC_report.txt" 11 | 12 | def main(): 13 | assert sys.version_info >= (3, 3), \ 14 | "Must be run in Python 3.3 or later. 
You are running {}".format(sys.version) 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--input_file', type=str, default='reddit_data', 18 | help='data file or directory containing bz2 archive of json reddit data') 19 | parser.add_argument('--logdir', type=str, default='output/', 20 | help='directory to save the output and report') 21 | parser.add_argument('--config_file', type=str, default='parser_config_standard.json', 22 | help='json parameters for parsing') 23 | parser.add_argument('--comment_cache_size', type=int, default=int(1e7), 24 | help='max number of comments to cache in memory before flushing') 25 | parser.add_argument('--output_file_size', type=int, default=int(2e8), 26 | help='max size of each output file (give or take one conversation)') 27 | parser.add_argument('--print_every', type=int, default=1000, 28 | help='print an update to the screen this often') 29 | parser.add_argument('--min_conversation_length', type=int, default=5, 30 | help='conversations must have at least this many comments for inclusion') 31 | parser.add_argument('--print_subreddit', type=str2bool, nargs='?', 32 | const=True, default=False, 33 | help='set to true to print the name of the subreddit before each conversation' 34 | + ' to facilitate more convenient blacklisting in the config json file.' 
35 | + ' (Remember to disable before constructing training data.)') 36 | args = parser.parse_args() 37 | parse_main(args) 38 | 39 | def str2bool(v): 40 | if v.lower() in ('yes', 'true', 't', 'y', '1'): return True 41 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): return False 42 | else: raise argparse.ArgumentTypeError('Boolean value expected.') 43 | 44 | class RedditComment(object): 45 | def __init__(self, json_object, record_subreddit=False): 46 | self.body = json_object['body'] 47 | if 'score' in json_object: 48 | self.score = json_object['score'] 49 | elif 'ups' in json_object and 'downs' in json_object: 50 | self.score = json_object['ups'] - json_object['downs'] 51 | else: raise ValueError("Reddit comment did not include a score attribute. " 52 | + "Comment was as follows: " + str(json_object)) 53 | self.author = json_object['author'] 54 | parent_id = json_object['parent_id'] 55 | # t1_ prefixes indicate comments. t3_ prefix would indicate a link submission. 56 | if parent_id.startswith('t1_'): self.parent_id = parent_id 57 | else: self.parent_id = None 58 | self.child_id = None 59 | if record_subreddit: self.subreddit = json_object['subreddit'] 60 | 61 | def parse_main(args): 62 | if not os.path.isfile(args.config_file): 63 | print("File not found: {}".format(args.config_file)) 64 | return 65 | with open(args.config_file, 'r') as f: 66 | config = json.load(f) 67 | subreddit_blacklist = set(config['subreddit_blacklist']) 68 | subreddit_whitelist = set(config['subreddit_whitelist']) 69 | substring_blacklist = set(config['substring_blacklist']) 70 | 71 | if not os.path.exists(args.input_file): 72 | print("File not found: {}".format(args.input_file)) 73 | return 74 | if os.path.isfile(args.logdir): 75 | print("File already exists at output directory location: {}".format(args.logdir)) 76 | return 77 | if not os.path.exists(args.logdir): 78 | os.makedirs(args.logdir) 79 | subreddit_dict = {} 80 | comment_dict = {} 81 | raw_data = raw_data_generator(args.input_file) 
82 | output_handler = OutputHandler(os.path.join(args.logdir, OUTPUT_FILE), args.output_file_size) 83 | done = False 84 | total_read = 0 85 | while not done: 86 | done, i = read_comments_into_cache(raw_data, comment_dict, args.print_every, args.print_subreddit, 87 | args.comment_cache_size, subreddit_dict, subreddit_blacklist, subreddit_whitelist, substring_blacklist) 88 | total_read += i 89 | process_comment_cache(comment_dict, args.print_every) 90 | write_comment_cache(comment_dict, output_handler, args.print_every, 91 | args.print_subreddit, args.min_conversation_length) 92 | write_report(os.path.join(args.logdir, REPORT_FILE), subreddit_dict) 93 | comment_dict.clear() 94 | print("\nRead all {:,d} lines from {}.".format(total_read, args.input_file)) 95 | 96 | def read_comments_into_cache(raw_data, comment_dict, print_every, print_subreddit, comment_cache_size, 97 | subreddit_dict, subreddit_blacklist, subreddit_whitelist, substring_blacklist): 98 | done = False 99 | cache_count = 0 100 | for i, line in enumerate(raw_data): 101 | # Ignore certain kinds of malformed JSON 102 | if len(line) > 1 and (line[-1] == '}' or line[-2] == '}'): 103 | comment = json.loads(line) 104 | if post_qualifies(comment, subreddit_blacklist, # Also preprocesses the post. 105 | subreddit_whitelist, substring_blacklist): 106 | sub = comment['subreddit'] 107 | if sub in subreddit_dict: 108 | subreddit_dict[sub] += 1 109 | else: subreddit_dict[sub] = 1 110 | comment_dict[comment['id']] = RedditComment(comment, print_subreddit) 111 | cache_count += 1 112 | if cache_count % print_every == 0: 113 | print("\rCached {:,d} comments".format(cache_count), end='') 114 | sys.stdout.flush() 115 | if cache_count > comment_cache_size: break 116 | else: # raw_data has been exhausted. 
117 | done = True 118 | print() 119 | return done, i 120 | 121 | def raw_data_generator(path): 122 | if os.path.isdir(path): 123 | for walk_root, walk_dir, walk_files in os.walk(path): 124 | for file_name in walk_files: 125 | file_path = os.path.join(walk_root, file_name) 126 | if file_path.endswith(FILE_SUFFIX): 127 | print("\nReading from {}".format(file_path)) 128 | with bz2.open(file_path, "rt") as raw_data: 129 | try: 130 | for line in raw_data: yield line 131 | except IOError: 132 | print("IOError from file {}".format(file_path)) 133 | continue 134 | else: print("Skipping file {} (doesn't end with {})".format(file_path, FILE_SUFFIX)) 135 | elif os.path.isfile(path): 136 | print("Reading from {}".format(path)) 137 | with bz2.open(path, "rt") as raw_data: 138 | for line in raw_data: yield line 139 | 140 | class OutputHandler(): 141 | def __init__(self, path, output_file_size): 142 | if path.endswith(FILE_SUFFIX): 143 | path = path[:-len(FILE_SUFFIX)] 144 | self.base_path = path 145 | self.output_file_size = output_file_size 146 | self.file_reference = None 147 | 148 | def write(self, data): 149 | if self.file_reference is None: 150 | self._get_current_path() 151 | self.file_reference.write(data) 152 | self.current_file_size += len(data) 153 | if self.current_file_size >= self.output_file_size: 154 | self.file_reference.close() 155 | self.file_reference = None 156 | 157 | def _get_current_path(self): 158 | i = 1 159 | while True: 160 | path = "{} {}{}".format(self.base_path, i, FILE_SUFFIX) 161 | if not os.path.exists(path): break 162 | i += 1 163 | self.current_path = path 164 | self.current_file_size = 0 165 | self.file_reference = bz2.open(self.current_path, mode="wt") 166 | 167 | def post_qualifies(json_object, subreddit_blacklist, 168 | subreddit_whitelist, substring_blacklist): 169 | body = json_object['body'] 170 | post_length = len(body) 171 | if post_length < 4 or post_length > 200: return False 172 | subreddit = json_object['subreddit'] 173 | if 
len(subreddit_whitelist) > 0 and subreddit not in subreddit_whitelist: return False 174 | if len(subreddit_blacklist) > 0 and subreddit in subreddit_blacklist: return False 175 | if len(substring_blacklist) > 0: 176 | for substring in substring_blacklist: 177 | if body.find(substring) >= 0: return False 178 | # Preprocess the comment text. 179 | body = re.sub('[ \t\n\r]+', ' ', body) # Replace runs of whitespace with a single space. 180 | body = re.sub('\^', '', body) # Strip out carets. 181 | body = re.sub('\\\\', '', body) # Strip out backslashes. 182 | body = re.sub('&lt;', '<', body) # Replace '&lt;' with '<' 183 | body = re.sub('&gt;', '>', body) # Replace '&gt;' with '>' 184 | body = re.sub('&amp;', '&', body) # Replace '&amp;' with '&' 185 | post_length = len(body) 186 | # Check the length again, now that we've preprocessed it. 187 | if post_length < 4 or post_length > 200: return False 188 | json_object['body'] = body # Save our changes 189 | # Make sure the ID has the 't1_' prefix because that is how child comments refer to their parents. 190 | if not json_object['id'].startswith('t1_'): json_object['id'] = 't1_' + json_object['id'] 191 | return True 192 | 193 | def process_comment_cache(comment_dict, print_every): 194 | i = 0 195 | for my_id, my_comment in comment_dict.items(): 196 | i += 1 197 | if i % print_every == 0: 198 | print("\rProcessed {:,d} comments".format(i), end='') 199 | sys.stdout.flush() 200 | if my_comment.parent_id is not None: # If we're not a top-level post... 201 | if my_comment.parent_id in comment_dict: # ...and the parent is in our data set... 202 | parent = comment_dict[my_comment.parent_id] 203 | if parent.child_id is None: # If my parent doesn't already have a child, adopt me! 204 | parent.child_id = my_id 205 | else: # My parent already has a child. 206 | parent_previous_child = comment_dict[parent.child_id] 207 | if parent.parent_id in comment_dict: # If my grandparent is in our data set... 
208 | grandparent = comment_dict[parent.parent_id] 209 | if my_comment.author == grandparent.author: 210 | # If I share an author with grandparent, adopt me! 211 | parent.child_id = my_id 212 | elif (parent_previous_child.author != grandparent.author 213 | and my_comment.score > parent_previous_child.score): 214 | # If the existing child doesn't share an author with grandparent, 215 | # higher score prevails. 216 | parent.child_id = my_id 217 | elif my_comment.score > parent_previous_child.score: 218 | # If there's no grandparent, the higher-score child prevails. 219 | parent.child_id = my_id 220 | else: 221 | # Parent IDs that aren't in the data set get de-referenced. 222 | my_comment.parent_id = None 223 | print() 224 | 225 | def write_comment_cache(comment_dict, output_file, print_every, 226 | record_subreddit=False, min_conversation_length=5): 227 | i = 0 228 | prev_print_count = 0 229 | for k, v in comment_dict.items(): 230 | if v.parent_id is None and v.child_id is not None: 231 | comment = v 232 | depth = 0 233 | if record_subreddit: output_string = "/r/" + comment.subreddit + '\n' 234 | else: output_string = "" 235 | while comment is not None: 236 | depth += 1 237 | output_string += '> ' + comment.body + '\n' 238 | if comment.child_id in comment_dict: 239 | comment = comment_dict[comment.child_id] 240 | else: 241 | comment = None 242 | if depth >= min_conversation_length: 243 | output_file.write(output_string + '\n') 244 | i += depth 245 | if i > prev_print_count + print_every: 246 | prev_print_count = i 247 | print("\rWrote {:,d} comments".format(i), end='') 248 | sys.stdout.flush() 249 | print() 250 | 251 | def write_report(report_file_path, subreddit_dict): 252 | print("Updating subreddit report file") 253 | subreddit_list = sorted(subreddit_dict.items(), key=lambda x: -x[1]) 254 | with open(report_file_path, "w") as f: 255 | for item in subreddit_list: 256 | f.write("{}: {}\n".format(*item)) 257 | 258 | if __name__ == '__main__': 259 | main() 260 | 
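The reply-selection rule in `process_comment_cache` above can be restated as a small standalone predicate. This is a minimal sketch rather than code from the repository: `Reply` and `prefer_new_reply` are hypothetical names introduced only for illustration, and the sketch assumes that a comment's author and score are all that matter when two replies compete for the same parent.

```python
from typing import NamedTuple, Optional


class Reply(NamedTuple):
    """Hypothetical stand-in for the fields of RedditComment used by the rule."""
    author: str
    score: int


def prefer_new_reply(new: Reply, current: Optional[Reply],
                     grandparent_author: Optional[str]) -> bool:
    """Return True if `new` should replace `current` as the parent's chosen child.

    `grandparent_author` is None when the grandparent is not in the cache.
    """
    if current is None:
        # The parent has no child yet, so the new reply is adopted.
        return True
    if grandparent_author is not None:
        if new.author == grandparent_author:
            # A reply by the grandparent's author keeps a two-person dialogue going.
            return True
        # Otherwise the new reply wins only if the current child is not by the
        # grandparent's author and the new reply has a strictly higher score.
        return current.author != grandparent_author and new.score > current.score
    # No grandparent available: the higher-scoring reply prevails.
    return new.score > current.score
```

The effect is that chains alternating between the same two authors are favored over higher-scoring replies from third parties, which is presumably what makes the extracted conversations read like dialogues.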
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | import argparse 5 | import time, datetime 6 | import os 7 | import pickle 8 | import sys 9 | 10 | from utils import TextLoader 11 | from model import Model 12 | 13 | def main(): 14 | assert sys.version_info >= (3, 3), \ 15 | "Must be run in Python 3.3 or later. You are running {}".format(sys.version) 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--data_dir', type=str, default='data/scotus', 19 | help='data directory containing input.txt') 20 | parser.add_argument('--save_dir', type=str, default='models/new_save', 21 | help='directory for checkpointed models (load from here if one is already present)') 22 | parser.add_argument('--block_size', type=int, default=2048, 23 | help='number of cells per block') 24 | parser.add_argument('--num_blocks', type=int, default=3, 25 | help='number of blocks per layer') 26 | parser.add_argument('--num_layers', type=int, default=3, 27 | help='number of layers') 28 | parser.add_argument('--model', type=str, default='gru', 29 | help='rnn, gru, lstm or nas') 30 | parser.add_argument('--batch_size', type=int, default=40, 31 | help='minibatch size') 32 | parser.add_argument('--seq_length', type=int, default=40, 33 | help='RNN sequence length') 34 | parser.add_argument('--num_epochs', type=int, default=50, 35 | help='number of epochs') 36 | parser.add_argument('--save_every', type=int, default=5000, 37 | help='save frequency') 38 | parser.add_argument('--grad_clip', type=float, default=5., 39 | help='clip gradients at this value') 40 | parser.add_argument('--learning_rate', type=float, default=1e-5, 41 | help='learning rate') 42 | parser.add_argument('--decay_rate', type=float, default=0.975, 43 | help='how much to decay the learning rate') 44 | parser.add_argument('--decay_steps', 
type=int, default=100000, 45 | help='how often to decay the learning rate') 46 | parser.add_argument('--set_learning_rate', type=float, default=-1, 47 | help='reset learning rate to this value (if greater than zero)') 48 | args = parser.parse_args() 49 | train(args) 50 | 51 | def train(args): 52 | # Create the data_loader object, which loads up all of our batches, vocab dictionary, etc. 53 | # from utils.py (and creates them if they don't already exist). 54 | # These files go in the data directory. 55 | data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length) 56 | args.vocab_size = data_loader.vocab_size 57 | 58 | load_model = False 59 | if not os.path.exists(args.save_dir): 60 | print("Creating directory %s" % args.save_dir) 61 | os.mkdir(args.save_dir) 62 | elif (os.path.exists(os.path.join(args.save_dir, 'config.pkl'))): 63 | # Trained model already exists 64 | ckpt = tf.train.get_checkpoint_state(args.save_dir) 65 | if ckpt and ckpt.model_checkpoint_path: 66 | with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f: 67 | saved_args = pickle.load(f) 68 | args.block_size = saved_args.block_size 69 | args.num_blocks = saved_args.num_blocks 70 | args.num_layers = saved_args.num_layers 71 | args.model = saved_args.model 72 | print("Found a previous checkpoint. Overwriting model description arguments to:") 73 | print(" model: {}, block_size: {}, num_blocks: {}, num_layers: {}".format( 74 | saved_args.model, saved_args.block_size, saved_args.num_blocks, saved_args.num_layers)) 75 | load_model = True 76 | 77 | # Save all arguments to config.pkl in the save directory -- NOT the data directory. 78 | with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f: 79 | pickle.dump(args, f) 80 | # Save a tuple of the characters list and the vocab dictionary to chars_vocab.pkl in 81 | # the save directory -- NOT the data directory. 
82 | with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f: 83 | pickle.dump((data_loader.chars, data_loader.vocab), f) 84 | 85 | # Create the model! 86 | print("Building the model") 87 | model = Model(args) 88 | print("Total trainable parameters: {:,d}".format(model.trainable_parameter_count())) 89 | 90 | # Make tensorflow less verbose; filter out info (1+) and warnings (2+) but not errors (3). 91 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 92 | 93 | config = tf.ConfigProto(log_device_placement=False) 94 | #config.gpu_options.allow_growth = True 95 | with tf.Session(config=config) as sess: 96 | tf.global_variables_initializer().run() 97 | saver = tf.train.Saver(model.save_variables_list(), max_to_keep=3) 98 | if (load_model): 99 | print("Loading saved parameters") 100 | saver.restore(sess, ckpt.model_checkpoint_path) 101 | global_epoch_fraction = sess.run(model.global_epoch_fraction) 102 | global_seconds_elapsed = sess.run(model.global_seconds_elapsed) 103 | if load_model: print("Resuming from global epoch fraction {:.3f}," 104 | " total trained time: {}, learning rate: {}".format( 105 | global_epoch_fraction, 106 | datetime.timedelta(seconds=float(global_seconds_elapsed)), 107 | sess.run(model.lr))) 108 | if (args.set_learning_rate > 0): 109 | sess.run(tf.assign(model.lr, args.set_learning_rate)) 110 | print("Reset learning rate to {}".format(args.set_learning_rate)) 111 | data_loader.cue_batch_pointer_to_epoch_fraction(global_epoch_fraction) 112 | initial_batch_step = int((global_epoch_fraction 113 | - int(global_epoch_fraction)) * data_loader.total_batch_count) 114 | epoch_range = (int(global_epoch_fraction), 115 | args.num_epochs + int(global_epoch_fraction)) 116 | writer = tf.summary.FileWriter(args.save_dir, graph=tf.get_default_graph()) 117 | outputs = [model.cost, model.final_state, model.train_op, model.summary_op] 118 | global_step = epoch_range[0] * data_loader.total_batch_count + initial_batch_step 119 | avg_loss = 0 120 | avg_steps = 0 
121 | try: 122 | for e in range(*epoch_range): 123 | # e iterates through the training epochs. 124 | # Reset the model state, so it does not carry over from the end of the previous epoch. 125 | state = sess.run(model.zero_state) 126 | batch_range = (initial_batch_step, data_loader.total_batch_count) 127 | initial_batch_step = 0 128 | for b in range(*batch_range): 129 | global_step += 1 130 | if global_step % args.decay_steps == 0: 131 | # Set the model.lr element of the model to track 132 | # the appropriately decayed learning rate. 133 | current_learning_rate = sess.run(model.lr) 134 | current_learning_rate *= args.decay_rate 135 | sess.run(tf.assign(model.lr, current_learning_rate)) 136 | print("Decayed learning rate to {}".format(current_learning_rate)) 137 | start = time.time() 138 | # Pull the next batch inputs (x) and targets (y) from the data loader. 139 | x, y = data_loader.next_batch() 140 | 141 | # feed is a dictionary of variable references and respective values for initialization. 142 | # Initialize the model's input data and target data from the batch, 143 | # and initialize the model state to the final state from the previous batch, so that 144 | # model state is accumulated and carried over between batches. 145 | feed = {model.input_data: x, model.targets: y} 146 | model.add_state_to_feed_dict(feed, state) 147 | 148 | # Run the session! Specifically, tell TensorFlow to compute the graph to calculate 149 | # the values of cost, final state, and the training op. 150 | # Cost is used to monitor progress. 151 | # Final state is used to carry over the state into the next batch. 152 | # Training op is not used, but we want it to be calculated, since that calculation 153 | # is what updates parameter states (i.e. that is where the training happens). 
154 | train_loss, state, _, summary = sess.run(outputs, feed) 155 | elapsed = time.time() - start 156 | global_seconds_elapsed += elapsed 157 | writer.add_summary(summary, e * batch_range[1] + b + 1) 158 | if avg_steps < 100: avg_steps += 1 159 | avg_loss = 1 / avg_steps * train_loss + (1 - 1 / avg_steps) * avg_loss 160 | print("{:,d} / {:,d} (epoch {:.3f} / {}), loss {:.3f} (avg {:.3f}), {:.3f}s" \ 161 | .format(b, batch_range[1], e + b / batch_range[1], epoch_range[1], 162 | train_loss, avg_loss, elapsed)) 163 | # Every save_every batches, save the model to disk. 164 | # The saver was created with max_to_keep=3, so only the three most recent checkpoint files are kept. 165 | if (e * batch_range[1] + b + 1) % args.save_every == 0 \ 166 | or (e == epoch_range[1] - 1 and b == batch_range[1] - 1): 167 | save_model(sess, saver, model, args.save_dir, global_step, 168 | data_loader.total_batch_count, global_seconds_elapsed) 169 | except KeyboardInterrupt: 170 | # Introduce a line break after ^C is displayed so save message 171 | # is on its own line. 172 | print() 173 | finally: 174 | writer.flush() 175 | global_step = e * data_loader.total_batch_count + b 176 | save_model(sess, saver, model, args.save_dir, global_step, 177 | data_loader.total_batch_count, global_seconds_elapsed) 178 | 179 | def save_model(sess, saver, model, save_dir, global_step, steps_per_epoch, global_seconds_elapsed): 180 | global_epoch_fraction = float(global_step) / float(steps_per_epoch) 181 | checkpoint_path = os.path.join(save_dir, 'model.ckpt') 182 | print("Saving model to {} (epoch fraction {:.3f})...".format(checkpoint_path, global_epoch_fraction), 183 | end='', flush=True) 184 | sess.run(tf.assign(model.global_epoch_fraction, global_epoch_fraction)) 185 | sess.run(tf.assign(model.global_seconds_elapsed, global_seconds_elapsed)) 186 | saver.save(sess, checkpoint_path, global_step = global_step) 187 | print("\rSaved model to {} (epoch fraction {:.3f}). 
".format(checkpoint_path, global_epoch_fraction)) 188 | 189 | if __name__ == '__main__': 190 | main() 191 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import pickle 4 | import time 5 | import bz2 6 | import numpy as np 7 | 8 | class TextLoader(): 9 | # Call this class to load text from a file. 10 | def __init__(self, data_dir, batch_size, seq_length): 11 | # TextLoader remembers its initialization arguments. 12 | self.data_dir = data_dir 13 | self.batch_size = batch_size 14 | self.seq_length = seq_length 15 | self.tensor_sizes = [] 16 | 17 | self.tensor_file_template = os.path.join(data_dir, "data{}.npz") 18 | vocab_file = os.path.join(data_dir, "vocab.pkl") 19 | sizes_file = os.path.join(data_dir, "sizes.pkl") 20 | 21 | self.input_files = self._get_input_file_list(data_dir) 22 | self.input_file_count = len(self.input_files) 23 | 24 | if self.input_file_count < 1: 25 | raise ValueError("Input files not found. File names must end in '.txt' or '.bz2'.") 26 | 27 | if self._preprocess_required(vocab_file, sizes_file, self.tensor_file_template, self.input_file_count): 28 | # If either the vocab file or the tensor file doesn't already exist, create them. 29 | t0 = time.time() 30 | print("Preprocessing the following files:") 31 | for i, filename in enumerate(self.input_files): print(" {}.\t{}".format(i+1, filename)) 32 | print("Saving vocab file") 33 | self._save_vocab(vocab_file) 34 | 35 | for i, filename in enumerate(self.input_files): 36 | t1 = time.time() 37 | print("Preprocessing file {}/{} ({})... 
".format(i+1, len(self.input_files), filename), 38 | end='', flush=True) 39 | self._preprocess(self.input_files[i], self.tensor_file_template.format(i)) 40 | self.tensor_sizes.append(self.tensor.size) 41 | print("done ({:.1f} seconds)".format(time.time() - t1), flush=True) 42 | 43 | with open(sizes_file, 'wb') as f: 44 | pickle.dump(self.tensor_sizes, f) 45 | 46 | print("Processed input data: {:,d} characters loaded ({:.1f} seconds)".format( 47 | self.tensor.size, time.time() - t0)) 48 | else: 49 | # If the vocab file and sizes file already exist, load them. 50 | print("Loading vocab file...") 51 | self._load_vocab(vocab_file) 52 | print("Loading sizes file...") 53 | with open(sizes_file, 'rb') as f: 54 | self.tensor_sizes = pickle.load(f) 55 | self.tensor_batch_counts = [n // (self.batch_size * self.seq_length) for n in self.tensor_sizes] 56 | self.total_batch_count = sum(self.tensor_batch_counts) 57 | print("Total batch count: {:,d}".format(self.total_batch_count)) 58 | 59 | self.tensor_index = -1 60 | 61 | def _preprocess_required(self, vocab_file, sizes_file, tensor_file_template, input_file_count): 62 | if not os.path.exists(vocab_file): 63 | print("No vocab file found. Preprocessing...") 64 | return True 65 | if not os.path.exists(sizes_file): 66 | print("No sizes file found. Preprocessing...") 67 | return True 68 | for i in range(input_file_count): 69 | if not os.path.exists(tensor_file_template.format(i)): 70 | print("Couldn't find {}. 
Preprocessing...".format(tensor_file_template.format(i))) 71 | return True 72 | return False 73 | 74 | def _get_input_file_list(self, data_dir): 75 | suffixes = ['.txt', '.bz2'] 76 | input_file_list = [] 77 | if os.path.isdir(data_dir): 78 | for walk_root, walk_dir, walk_files in os.walk(data_dir): 79 | for file_name in walk_files: 80 | if file_name.startswith("."): continue 81 | file_path = os.path.join(walk_root, file_name) 82 | if file_path.endswith(suffixes[0]) or file_path.endswith(suffixes[1]): 83 | input_file_list.append(file_path) 84 | else: raise ValueError("Not a directory: {}".format(data_dir)) 85 | return sorted(input_file_list) 86 | 87 | def _save_vocab(self, vocab_file): 88 | self.chars = [chr(i) for i in range(128)] 89 | self.vocab_size = len(self.chars) 90 | self.vocab = dict(zip(self.chars, range(len(self.chars)))) 91 | with open(vocab_file, 'wb') as f: 92 | pickle.dump(self.chars, f) 93 | print("Saved vocab (vocab size: {:,d})".format(self.vocab_size)) 94 | 95 | def _load_vocab(self, vocab_file): 96 | # Load the character tuple (vocab.pkl) to self.chars. 97 | # Remember that it is in descending order of character frequency in the data. 98 | with open(vocab_file, 'rb') as f: 99 | self.chars = pickle.load(f) 100 | # Use the character tuple to regenerate vocab_size and the vocab dictionary. 101 | self.vocab_size = len(self.chars) 102 | self.vocab = dict(zip(self.chars, range(len(self.chars)))) 103 | 104 | def _preprocess(self, input_file, tensor_file): 105 | if input_file.endswith(".bz2"): file_reference = bz2.open(input_file, mode='rt') 106 | elif input_file.endswith(".txt"): file_reference = io.open(input_file, mode='rt') 107 | data = file_reference.read() 108 | file_reference.close() 109 | # Convert the entirety of the data file from characters to indices via the vocab dictionary. 110 | # How? map(function, iterable) returns a list of the output of the function 111 | # executed on each member of the iterable. 
E.g.: 112 | # [14, 2, 9, 2, 0, 6, 7, 0, ...] 113 | # np.array converts the list into a numpy array. 114 | self.tensor = np.array(list(map(self.vocab.get, data))) 115 | self.tensor = self.tensor[self.tensor != np.array(None)].astype(int) # Filter out None 116 | # Compress and save the numpy tensor array to data.npz. 117 | np.savez_compressed(tensor_file, tensor_data=self.tensor) 118 | 119 | def _load_preprocessed(self, tensor_index): 120 | self.reset_batch_pointer() 121 | if tensor_index == self.tensor_index: 122 | return 123 | print("loading tensor data file {}".format(tensor_index)) 124 | tensor_file = self.tensor_file_template.format(tensor_index) 125 | # Load the data tensor file to self.tensor. 126 | with np.load(tensor_file) as loaded: 127 | self.tensor = loaded['tensor_data'] 128 | self.tensor_index = tensor_index 129 | # Calculate the number of batches in the data. Each batch is batch_size x seq_length, 130 | # so this is just the input data size divided by that product, rounded down. 131 | self.num_batches = self.tensor.size // (self.batch_size * self.seq_length) 132 | if self.tensor_batch_counts[tensor_index] != self.num_batches: 133 | print("Error in batch size! Expected {:,d}; found {:,d}".format(self.tensor_batch_counts[tensor_index], 134 | self.num_batches)) 135 | # Chop off the end of the data tensor so that the length of the data is a whole 136 | # multiple of the (batch_size x seq_length) product. 137 | # Do this with the slice operator on the numpy array. 138 | self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length] 139 | # Construct two numpy arrays to represent input characters (xdata) 140 | # and target characters (ydata). 141 | # In training, we will feed in input characters one at a time, and optimize along 142 | # a loss function computed against the target characters. 143 | # (We do this with batch_size characters at a time, in parallel.) 
        # Since this is a sequence prediction net, the target is just the input shifted left
        # by 1: the target for each input character is the character that follows it.
        xdata = self.tensor
        ydata = np.copy(self.tensor) # Y-data starts as a copy of x-data.
        ydata[:-1] = xdata[1:] # Shift y-data left by 1 using the numpy array slice syntax.
        # Replace the very last character of y-data with the first character of the input data.
        ydata[-1] = xdata[0]
        # Split our unidimensional data array into distinct batches.
        # How? xdata.reshape(self.batch_size, -1) returns a 2D numpy tensor view
        # in which the first dimension is the row index (from 0 to batch_size),
        # and the second dimension is the index of the character within the row
        # (from 0 to (num_batches x seq_length)).
        # Within each row, characters follow the same sequence as in the input data.
        # Then, np.split(that 2D numpy tensor, num_batches, 1) gives a list of numpy arrays.
        # Say batch_size = 4, seq_length = 5, and data is the following string:
        # "Here is a new string named data. It is a new string named data. It is named data."
        # We truncate the string to lop off the last period (so there are now 80 characters,
        # which is evenly divisible by 4 x 5). After xdata.reshape, we have:
        #
        # [[Here is a new string],
        #  [ named data. It is a],
        #  [ new string named da],
        #  [ta. It is named data]]
        #
        # After np.split, we have:
        # <[[Here ],   <[[is a ],   <[[new s],   <[[tring],
        #   [ name],     [d dat],     [a. It],     [ is a],
        #   [ new ],     [strin],     [g nam],     [ed da],
        #   [ta. I]]>,   [t is ]]>,   [named]]>,   [ data]]>
        #
        # where the first item of the list is the numpy array on the left.
        # Thus x_batches is a list of num_batches numpy arrays. The first dimension of each
        # numpy array is the row within the batch (from 0 to batch_size), and the second
        # dimension is the character index (from 0 to seq_length).
        #
        # These will be fed to the model one at a time sequentially.
        # State is preserved between sequential batches.
        self.x_batches = np.split(xdata.reshape(self.batch_size, -1), self.num_batches, 1)
        self.y_batches = np.split(ydata.reshape(self.batch_size, -1), self.num_batches, 1)

    def next_batch(self):
        if self.tensor_index < 0:
            self._load_preprocessed(0)
        if self.pointer >= self.num_batches:
            self._load_preprocessed((self.tensor_index + 1) % self.input_file_count)
        x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
        self.pointer += 1
        return x, y

    def reset_batch_pointer(self):
        self.pointer = 0

    def cue_batch_pointer_to_epoch_fraction(self, epoch_fraction):
        step_target = (epoch_fraction - int(epoch_fraction)) * self.total_batch_count
        self._cue_batch_pointer_to_step_count(step_target)

    def _cue_batch_pointer_to_step_count(self, step_target):
        # Find the tensor file that contains the target step, subtracting the
        # batch counts of the files that precede it.
        for i, n in enumerate(self.tensor_batch_counts):
            if step_target < n:
                break
            step_target -= n
        self.current_tensor_index = i
        # Load the file before setting the pointer, because _load_preprocessed resets
        # the pointer to 0. The pointer is the remaining offset within the file,
        # truncated to an int so that it can index the batch lists.
        self._load_preprocessed(i)
        self.pointer = int(step_target)
--------------------------------------------------------------------------------
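
The batching scheme described in the `_load_preprocessed` comments can be tried in isolation. The following is a minimal sketch, not part of the repository: `make_batches` and `decode` are hypothetical helper names, characters are encoded with `ord` (which matches the ASCII vocab built by `_save_vocab`), and `np.roll` is swapped in for the copy-and-slice target shift used in `utils.py`, on the assumption that the two are equivalent for a 1-D array.

```python
import numpy as np

def make_batches(tensor, batch_size, seq_length):
    """Hypothetical helper: split a 1-D index array into lists of x and y batches."""
    num_batches = tensor.size // (batch_size * seq_length)
    if num_batches == 0:
        raise ValueError("Not enough data for a single batch.")
    # Drop the tail so the length is a whole multiple of batch_size * seq_length.
    xdata = tensor[:num_batches * batch_size * seq_length]
    # Each target is the character after its input; the last target wraps
    # around to the first input character.
    ydata = np.roll(xdata, -1)
    # Reshape to batch_size rows, then cut the rows into num_batches column blocks.
    x_batches = np.split(xdata.reshape(batch_size, -1), num_batches, 1)
    y_batches = np.split(ydata.reshape(batch_size, -1), num_batches, 1)
    return x_batches, y_batches

def decode(batch):
    """Hypothetical helper: turn a 2-D batch of character codes back into strings."""
    return ["".join(chr(i) for i in row) for row in batch]

text = "Here is a new string named data. It is a new string named data. It is named data."
codes = np.array([ord(c) for c in text])
x_batches, y_batches = make_batches(codes, batch_size=4, seq_length=5)

print(decode(x_batches[0]))  # rows of the first input batch
print(decode(y_batches[0]))  # the same rows, one character ahead
```

Because row r of batch k+1 continues directly from row r of batch k, the recurrent state for each row can be carried from one batch to the next, which is what the training loop relies on.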