├── LICENSE
├── README.md
├── bptt.py
└── example_bptt.py

/LICENSE:
--------------------------------------------------------------------------------
Apache License, Version 2.0, January 2004
(standard Apache-2.0 license text; see http://www.apache.org/licenses/LICENSE-2.0)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
A Simple Design Pattern for TensorFlow Recurrency
===============


### Introduction
This repo includes the bptt.py library for recurrent graph bookkeeping (in TensorFlow) and a sample client that learns to predict a simple palindromic sequence using a double-layer LSTM (Graves 2013).

See https://medium.com/@devnag/a-simple-design-pattern-for-recurrent-deep-learning-in-tensorflow-37aba4e2fd6b for the accompanying blog post.


### Running
Run the sample code by typing:


```
./example_bptt.py
```

...and you'll train a 2-layer LSTM on a palindromic sequence prediction task, then test it on sequential inference. The loss should drop below 1e-3 fairly quickly, and then you'll see the last hundred predictions vs. the expected output.
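
### The core pattern
For orientation, here is a minimal sketch of the client contract (TensorFlow 1.x graph mode). The vanilla tanh-RNN frame below is illustrative only -- it is not part of this repo -- but it threads state through the same get_past_variable()/name_variable() calls that example_bptt.py uses for its LSTM:

```
import numpy as np
import tensorflow as tf
import bptt

def build_frame(bp, depth_type):
    # I/O placeholders for one time step; returned first, by convention.
    x = tf.placeholder(tf.float32, shape=(1, 1))
    y = tf.placeholder(tf.float32, shape=(1, 1))
    # Last step's hidden state; the numpy constant seeds the very first step.
    h_past = bp.get_past_variable("h", np.zeros((4, 1), dtype=np.float32))
    W_xh = tf.get_variable("W_xh", [4, 1])
    W_hh = tf.get_variable("W_hh", [4, 4])
    h = bp.name_variable("h", tf.tanh(tf.matmul(W_xh, x) + tf.matmul(W_hh, h_past)))
    W_hy = tf.get_variable("W_hy", [1, 4])
    return [x, y, tf.matmul(W_hy, h)]

bp = bptt.BPTT()
graphs = bp.generate_graphs(build_frame, num_loops=4)  # graphs[bp.DEEP]: 4 stitched frames
```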
--------------------------------------------------------------------------------
/bptt.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

import tensorflow as tf
import numpy as np

# See https://medium.com/@devnag/


class BPTT(object):
    """
    Convenience design pattern for handling simple recurrent graphs, implementing backpropagation through time.
    See https://medium.com/@devnag/

    Typical usage:

    - Graph building
      - Define a function that takes a BPTT object and the depth flag (BPTT.DEEP or BPTT.SHALLOW)
        and builds your computational graph; it should return any I/O placeholders in an array.
      - Use get_past_variable() to define a name (string) and pass in a constant starting value (numpy).
      - Use name_variable() to name (string) the same value for the current loop, for the future.

    - Unrolling
      - generate_graphs() will take the function above and the desired BPTT depth and provide the
        sequence of stitched DAGs.

    - Training
      - Call generate_feed_dict() on the relevant depth (BPTT.DEEP) with the array of data to be fed into the
        I/O placeholders that your custom graph function returned. This will also include the working
        state for the recurrent variables (whether the starting constants or the state from the last loop).
        You must also pass a count of the number of I/O slots.
      - generate_output_definitions() will provide an array of variables that must be fetched to extract state.
      - save_output_state() will take the results and save them for the next loop.

    - Inference
      - Same three functions as in training, but use BPTT.SHALLOW instead.
      - Optionally call copy_state_forward() before inference if you want to start with the final training state.
    """

    DEEP = "deep"
    SHALLOW = "shallow"
    MODEL_NAME = "unrolled_model"
    LOOP_SCOPE = "unroll"

    def __init__(self):
        """
        Initialize the name dictionaries (state, placeholders, constants, etc.)
        """
        self.graph_dict = {}

        # Name -> Constants: starting values (typically np.arrays). Shared between shallow/deep, used at run-time
        self.starting_constants = {}
        # Name -> State: np.arrays reflecting state between run-times (seeded from the constants)
        self.state = {self.DEEP: {}, self.SHALLOW: {}}
        # Name -> Variables: Python variables passed through during build-time
        self.vars = {self.DEEP: {}, self.SHALLOW: {}}
        # Name -> Placeholders: placeholders to inject state, set during build-time
        self.placeholders = {self.DEEP: {}, self.SHALLOW: {}}

        self.current_depth = self.DEEP

    def get_past_variable(self, variable_name, starting_value):
        """
        Get-or-set a recurrent variable from the past (time t-1)

        :param variable_name: A unique (to this object) string representing this variable.
        :param starting_value: A constant that can eventually be fed into a placeholder
        :return: A variable (representing the value at t-1) that can be computed on to generate the current value (at t)
        """
        if variable_name not in self.placeholders[self.current_depth]:
            # First time being called
            self.starting_constants[variable_name] = starting_value

            # The first initial state is the constant np.array sent in
            self.state[self.current_depth][variable_name] = starting_value

            # Define a mirror placeholder with the same type/shape
            self.placeholders[self.current_depth][variable_name] = tf.placeholder(starting_value.dtype,
                                                                                  shape=starting_value.shape)
            # Set the current (starting) variable as that placeholder, to be filled in later
            self.vars[self.current_depth][variable_name] = self.placeholders[self.current_depth][variable_name]

        # Return the Python variable: the placeholder the first time, its descendant on later calls
        return self.vars[self.current_depth][variable_name]

    def name_variable(self, variable_name, v):
        """
        Set/assign a recurrent variable for the current time (time t)

        :param variable_name: A unique (to this object) string; must have been used in a get_past_variable() call
        :param v: A TensorFlow tensor representing the current value of this variable (at t)
        :return: v, unchanged, for easy in-line usage
        """
        assert variable_name in self.vars[self.current_depth], \
            "Tried to set a variable name that was never defined with get_past_variable()"
        self.vars[self.current_depth][variable_name] = v
        return v
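
    # The get_past_variable()/name_variable() pairing is what stitches frames together: unroll()
    # below calls the frame function once per loop, so on frame t the "past" lookup returns the
    # tensor that frame t-1 registered under the same name (or the seed placeholder on frame 0).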

    def generate_graphs(self, func, num_loops=10):
        """
        Generate the two graphs -- the deep (unrolled) connected graph and the shallow/simple graph.

        :param func: A function which takes the BPTT object and the depth_type (BPTT.{DEEP,SHALLOW}) and returns
                     an array of I/O placeholders.
        :param num_loops: The desired number of loops to unroll
        :return: A dictionary of the two graphs (deep + shallow).
        """
        # Scoping -- generate the deep/unrolled graph (training)
        self.current_depth = self.DEEP
        with tf.variable_scope(self.MODEL_NAME, reuse=False):
            self.graph_dict[self.DEEP] = self.unroll(func, self.DEEP, num_loops)

        # Now generate the shallow graph (inference)
        self.current_depth = self.SHALLOW
        with tf.variable_scope(self.MODEL_NAME, reuse=True):
            # Shallow is depth 1, but shares all variables with the deep graph above
            self.graph_dict[self.SHALLOW] = self.unroll(func, self.SHALLOW, 1)

        return self.graph_dict

    def unroll(self, func, depth_type, num_loops):
        """
        Given the graph-generating function, unroll to the desired depth.

        :param func: A function which takes the BPTT object and the depth_type (BPTT.{DEEP,SHALLOW}) and returns
                     an array of I/O placeholders.
        :param depth_type: The depth_type (BPTT.{DEEP,SHALLOW})
        :param num_loops: The desired number of loops to unroll
        :return: A list of the frame graphs, connected by the recurrent variables.
        """
        frames = []
        for loop in range(num_loops):
            # Scoping on top of each depth:
            # we need reuse=False the first time and reuse=True for all others
            with tf.variable_scope(self.LOOP_SCOPE, reuse=(loop != 0)):
                frames.append(func(self, depth_type))

        return frames
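
    # Note on layout: generate_feed_dict() expects data_array to be variables-major, frames-minor,
    # i.e. data_array[var_index][frame_index]. For a depth-4 unroll with two settable placeholders
    # (input and target), pass [in_data, out_data] where each inner array has length 4.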

    def generate_feed_dict(self, depth_type, data_array, num_settable):
        """
        Generate a feed dictionary; takes in an array of the data that will be inserted into the unrolled
        placeholders.

        :param depth_type: The depth_type (BPTT.{DEEP,SHALLOW})
        :param data_array: An array of arrays of data to insert into the unrolled placeholders
        :param num_settable: How many elements of the data_array to use.
        :return: A dictionary to feed into tf.Session().run()
        """
        frames = self.graph_dict[depth_type]
        d = {}

        # Recurrent: auto-defined placeholders / current state values
        for variable_name in self.placeholders[depth_type]:
            d[self.placeholders[depth_type][variable_name]] = self.state[depth_type][variable_name]

        # User-provided data to unroll/insert into the placeholders
        for frame_index in range(len(frames)):     # Unroll index
            for var_index in range(num_settable):  # Variable index
                frame_var = frames[frame_index][var_index]
                d[frame_var] = np.reshape(data_array[var_index][frame_index],
                                          frame_var.get_shape())
        return d

    def copy_state_forward(self):
        """
        Copy the working state from the DEEP pipeline to the SHALLOW pipeline
        """
        for key in self.state[self.DEEP]:
            self.state[self.SHALLOW][key] = np.copy(self.state[self.DEEP][key])

    def generate_output_definitions(self, depth_type):
        """
        Generate the desired output variables to fetch from the graph run

        :param depth_type: The depth_type (BPTT.{DEEP,SHALLOW})
        :return: An array of variables to add to the fetch list
        """
        d = self.vars[depth_type]
        # Define a consistent sort order by the variable names
        return [d[k] for k in sorted(d.keys())]

    def save_output_state(self, depth_type, arr):
        """
        Save the working state for the next run (it will be available to generate_feed_dict() in the next loop)

        :param depth_type: The depth_type (BPTT.{DEEP,SHALLOW})
        :param arr: An array of values (returned by tf.Session.run()) which map to generate_output_definitions()
        """
        d = self.state[depth_type]
        sorted_names = sorted(d.keys())
        assert len(sorted_names) == len(arr), \
            "Sent in the wrong number of values (%s) to update state (%s)" % (len(arr), len(sorted_names))
        for variable_index in range(len(sorted_names)):
            variable_name = sorted_names[variable_index]
            # Saved for next time.
            self.state[depth_type][variable_name] = arr[variable_index]
--------------------------------------------------------------------------------
/example_bptt.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

import numpy as np
import sys
import tensorflow as tf
import bptt

# See https://medium.com/@devnag/


# Data parameters: simple one-number-at-a-time for now
input_dimensions = 1
output_dimensions = 1
batch_size = 1

# Model parameters
lstm_width = 5
m = 0.0
s = 0.5
init = tf.random_normal_initializer(m, s)
noise_m = 0.0
noise_s = 0.03

# Optimization parameters
learning_rate = 0.05
beta1 = 0.95
beta2 = 0.999
epsilon = 1e-3
gradient_clipping = 2.0
unroll_depth = 4
max_reset_loops = 20

# Training parameters
num_training_loops = 3000
num_inference_loops = 100
num_inference_warmup_loops = 1900
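

# For reference, the layer below implements the Graves (2013) LSTM with peephole
# connections (the W_c* terms), one column vector per time step:
#   i_t = sigmoid(W_xi x_t + W_hi h_{t-1} + W_ci c_{t-1} + b_i)
#   f_t = sigmoid(W_xf x_t + W_hf h_{t-1} + W_cf c_{t-1} + b_f)
#   c_t = f_t * c_{t-1} + i_t * tanh(W_xc x_t + W_hc h_{t-1} + b_c)
#   o_t = sigmoid(W_xo x_t + W_ho h_{t-1} + W_co c_t + b_o)
#   h_t = o_t * tanh(c_t)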
def build_lstm_layer(bp, depth_type, layer_index, raw_x, width):
    """
    Build a single LSTM layer (Graves 2013); layers can be stacked, but send in sequential layer_indexes
    so they scope properly.
    """
    global init, noise_m, noise_s
    # Define the recurrent variable names
    h_name = "hidden-%s" % layer_index  # Really the 'output' of the LSTM layer
    c_name = "cell-%s" % layer_index
    # raw_x is [input_size, 1]
    input_size = raw_x.get_shape()[0].value
    # Why so serious? Introduce a little anarchy. Upset the established order...
    x = raw_x + tf.random_normal(raw_x.get_shape(), noise_m, noise_s)

    with tf.variable_scope("lstm_layer_%s" % layer_index):

        # Define shapes for all the weights/biases, limited to just this layer (not shared with other layers)
        # Sizes are 'input_size' when mapping x and 'width' otherwise
        W_xi = tf.get_variable("W_xi", [width, input_size], initializer=init)
        W_hi = tf.get_variable("W_hi", [width, width], initializer=init)
        W_ci = tf.get_variable("W_ci", [width, width], initializer=init)
        b_i = tf.get_variable("b_i", [width, 1], initializer=init)
        W_xf = tf.get_variable("W_xf", [width, input_size], initializer=init)
        W_hf = tf.get_variable("W_hf", [width, width], initializer=init)
        W_cf = tf.get_variable("W_cf", [width, width], initializer=init)
        b_f = tf.get_variable("b_f", [width, 1], initializer=init)
        W_xc = tf.get_variable("W_xc", [width, input_size], initializer=init)
        W_hc = tf.get_variable("W_hc", [width, width], initializer=init)
        b_c = tf.get_variable("b_c", [width, 1], initializer=init)
        W_xo = tf.get_variable("W_xo", [width, input_size], initializer=init)
        W_ho = tf.get_variable("W_ho", [width, width], initializer=init)
        W_co = tf.get_variable("W_co", [width, width], initializer=init)
        b_o = tf.get_variable("b_o", [width, 1], initializer=init)

        # Retrieve the previous roll-depth's data, with random starting data if this is the first roll-depth.
        h_past = bp.get_past_variable(h_name, np.float32(np.random.normal(m, s, [width, 1])))
        c_past = bp.get_past_variable(c_name, np.float32(np.random.normal(m, s, [width, 1])))

        # Build the graph - looks almost like Alex Graves wrote it!
        i = tf.sigmoid(tf.matmul(W_xi, x) + tf.matmul(W_hi, h_past) + tf.matmul(W_ci, c_past) + b_i)
        f = tf.sigmoid(tf.matmul(W_xf, x) + tf.matmul(W_hf, h_past) + tf.matmul(W_cf, c_past) + b_f)
        c = bp.name_variable(c_name, tf.multiply(f, c_past) +
                             tf.multiply(i, tf.tanh(tf.matmul(W_xc, x) + tf.matmul(W_hc, h_past) + b_c)))
        o = tf.sigmoid(tf.matmul(W_xo, x) + tf.matmul(W_ho, h_past) + tf.matmul(W_co, c) + b_o)
        h = bp.name_variable(h_name, tf.multiply(o, tf.tanh(c)))

    return [c, h]


def build_dual_lstm_frame(bp, depth_type):
    """
    Build a dual-layer LSTM followed by a standard sigmoid/linear output mapping
    """
    global init, input_dimensions, output_dimensions, batch_size, lstm_width

    # I/O DATA
    input_placeholder = tf.placeholder(tf.float32, shape=(input_dimensions, batch_size))
    output_placeholder = tf.placeholder(tf.float32, shape=(output_dimensions, batch_size))

    last_output = input_placeholder
    for layer_index in range(2):
        [_, h] = build_lstm_layer(bp, depth_type, layer_index, last_output, lstm_width)
        last_output = h

    W = tf.get_variable("W", [1, lstm_width], initializer=init)
    b = tf.get_variable("b", [1, 1], initializer=init)
    output_result = tf.sigmoid(tf.matmul(W, last_output) + b)

    # Return an array of whatever you want, but I/O placeholders FIRST.
    return [input_placeholder, output_placeholder, output_result]


def palindrome(step):
    """
    Turn sequential integers into a palindromic sequence (so the look-ahead mapping is not a function
    of the current value alone, but requires state)
    """
    return (5.0 - abs(float(step % 10) - 5.0)) / 10.0
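
# Worked example: palindrome(0..11) yields
#   0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0, 0.1, ...
# i.e. a triangle wave. Next-step prediction is genuinely stateful: an input of 0.4 is followed
# by 0.5 on the way up but 0.3 on the way down, so the model must track the phase.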

bp = None
sess = None
graphs = None

# Loop until you get out of a local minimum or you hit max reset loops
for reset_loop_index in range(max_reset_loops):

    # Clean up any previous loops
    if reset_loop_index > 0:
        tf.reset_default_graph()

    # Generate the unrolled + shallow graphs
    bp = bptt.BPTT()
    graphs = bp.generate_graphs(build_dual_lstm_frame, unroll_depth)

    # Define the loss and clip gradients
    error_vec = [[o - p] for [_, p, o] in graphs[bp.DEEP]]
    loss = tf.reduce_mean(tf.square(error_vec))
    optimizer = tf.train.AdamOptimizer(learning_rate, beta1, beta2, epsilon)
    grads = optimizer.compute_gradients(loss)
    clipped_grads = [(tf.clip_by_value(grad, -gradient_clipping, gradient_clipping), var) for grad, var in grads]
    # Train on the clipped gradients; a separate minimize() call here would silently recompute
    # (and apply) the unclipped gradients instead.
    train = optimizer.apply_gradients(clipped_grads)

    # Boilerplate initialization
    init_op = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init_op)
    reset = False

    print("=== Training the unrolled model (reset loop %s) ===" % (reset_loop_index))

    for step in range(num_training_loops):
        # 1.) Generate the dictionary of I/O placeholder data
        start_index = step * unroll_depth
        in_data = np.array([palindrome(x) for x in range(start_index, start_index + unroll_depth)], dtype=np.float32)
        out_data = np.array([palindrome(x + 1) for x in range(start_index, start_index + unroll_depth)], dtype=np.float32)

        # 2a.) Generate the working state to send in, along with the data to insert into the unrolled placeholders
        frame_dict = bp.generate_feed_dict(bp.DEEP, [in_data, out_data], 2)

        # 2b.) Define the output (training/loss) that we'd like to see (optional)
        session_out = [train, loss] + [o for [_, p, o] in graphs[bp.DEEP]]  # calculated outputs

        # 3.) Define the state variables to pull out as well.
        state_vars = bp.generate_output_definitions(bp.DEEP)
        session_out.extend(state_vars)

        # 4.) Execute the graph
        results = sess.run(session_out, feed_dict=frame_dict)

        # 5.) Extract the state for the next training loop; make sure we slice the right part of the result array
        bp.save_output_state(bp.DEEP, results[-len(state_vars):])

        # 6.) Show training progress; reset the graph if the loss is stagnant.
        if (step % 100) == 0:
            print("Step %s: loss %s (output: %s)" % (step, results[1], [str(x) for x in results[2:-len(state_vars)]]))
            sys.stdout.flush()

        if step >= 1000 and (results[1] > 0.01):
            print("\nResetting; loss (%s) is stagnating after 1k rounds...\n" % (results[1]))
            reset = True
            break  # To the next reset loop

    if not reset:
        break
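
# At this point the weights are trained, and bp.state[bp.DEEP] holds the recurrent state (h and c
# for both layers) from the final training step. The shallow graph below shares those trained
# weights via variable scoping; copy_state_forward() only seeds its recurrent state.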

print("=== Evaluating on the shallow model ===")

# Copy the final deep state from the training loop above to the shallow state.
bp.copy_state_forward()
[in_ph, out_ph, out_out] = graphs[bp.SHALLOW][0]

# Evaluate one step at a time, burning in first.
for step in range(num_inference_loops + num_inference_warmup_loops):
    # 1.) Convert the step to the palindromic sequence (current and look-ahead-by-one)
    in_value = palindrome(step)
    expected_out_value = palindrome(step + 1)

    # 2.) Generate the feed dictionary to send in, with both the I/O data and the recurrent variables
    frame_dict = bp.generate_feed_dict(bp.SHALLOW, np.array([[in_value]], np.float32), 1)

    # 3.) Define the state variables to pull out
    session_out = [out_out]
    state_vars = bp.generate_output_definitions(bp.SHALLOW)
    session_out.extend(state_vars)

    # 4.) Execute the graph
    results = sess.run(session_out, feed_dict=frame_dict)

    # 5.) Extract/save the state variables for the next loop
    bp.save_output_state(bp.SHALLOW, results[-len(state_vars):])

    # 6.) How are we doing?
    if step >= num_inference_warmup_loops:
        print("%s: %s => %s actual vs %s expected (diff: %s)" %
              (step, in_value, results[0][0][0], expected_out_value, expected_out_value - results[0][0][0]))
        sys.stdout.flush()
--------------------------------------------------------------------------------