├── LICENSE
├── README.md
├── img
│   └── compgraph.png
├── main.py
├── runall.sh
├── test
│   ├── sample1
│   │   ├── checkpoint
│   │   ├── cost_list_lstm.pickle
│   │   ├── figure_lstm.png
│   │   ├── model-0
│   │   ├── model-0.meta
│   │   └── variable_dict.pickle
│   └── sample2
│       ├── checkpoint
│       ├── cost_list_lstm.pickle
│       ├── figure_lstm.png
│       ├── model-0
│       ├── model-0.meta
│       └── variable_dict.pickle
└── trainOptimizer.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2016 Yutaro Yamada

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Learning-To-Learn: RNN-based Optimization

TensorFlow implementation of [Learning to learn by gradient descent by gradient descent](https://arxiv.org/pdf/1606.04474v1.pdf).

Since the computational graph of the architecture can be huge on MNIST and CIFAR-10, the current implementation only deals with the task
on quadratic functions as described in Section 3.1 of the paper. The image below is from the paper (Figure 2 on page 4).

![compgraph](./img/compgraph.png)

See the writeup explaining the paper [here](http://runopti.github.io/blog/2016/10/17/learningtolearn-1/).

## Requirements

- Python 2.7 or higher
- NumPy
- TensorFlow 0.9


## Usage

First, train an RNN optimizer:
```
$ python trainOptimizer.py
```
which outputs variable_dict.pickle, model-0, and model-0.meta. Note that graph compilation at the beginning might take a while depending on the parameters.

variable_dict.pickle contains the parameters of the RNN optimizer, which are used
to construct the training graph in main.py.

Note that unless the loss at the end is around 2.0 ~ 4.0, the RNN will not correctly update the parameters in main.py, so
if you get a loss above 4.0, you might want to re-train the RNN optimizer.
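
If training went well, you can sanity-check the saved optimizer by inspecting the pickle directly (a minimal sketch; the exact variable names depend on TensorFlow's internal scoping):
```
import pickle

# variable_dict maps each trainable variable's name to its numpy value,
# exactly as extracted at the end of trainOptimizer.py.
with open("variable_dict.pickle", "rb") as f:
    variable_dict = pickle.load(f)

for name, value in variable_dict.items():
    print(name, value.shape)
```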

Then, run
```
$ python main.py
```

The parameters in this model are:

| params        | description                           | default | in the paper |
|:-------------:|---------------------------------------|:-------:|:------------:|
| n_samplings   | number of random function samplings   | 10      | 100          |
| n_unroll      | number of steps to unroll in BPTT     | 20      | 20           |
| n_dimension   | the dimension of the input data space | 3       | 10           |
| n_hidden      | number of hidden units in the RNN     | 5       | 20           |
| n_layers      | number of RNN layers                  | 2       | 2            |
| max_epoch     | number of training epochs (only for main.py) | 20 | NA         |
| optim_method  | 'lstm' or 'SGD' (only for main.py)    | 'lstm'  | 'lstm'       |

You can change these values by passing the arguments explicitly. For example:
```
$ python trainOptimizer.py --n_samplings 10 --n_unroll 4 --n_dimension 5 --n_hidden 3 --n_layers 2
```

Make sure you use the same values for both trainOptimizer.py and main.py.

Instead of manually running the two Python scripts, you can also run runall.sh to do the same. One run of runall.sh should finish in around 5 minutes with the default settings.

## Reference

[Learning to learn by gradient descent by gradient descent](https://arxiv.org/pdf/1606.04474v1.pdf)


## LICENSE

MIT License

--------------------------------------------------------------------------------
/img/compgraph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/runopti/Learning-To-Learn/168e7d7a01775d139353633d0d5cea520909752f/img/compgraph.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np

import argparse
parser = argparse.ArgumentParser(description='Define parameters to learn optimizer.')
parser.add_argument('--n_samplings', type=int, default=10)
parser.add_argument('--n_unroll', type=int, default=20)
parser.add_argument('--n_dimension', type=int, default=3)
parser.add_argument('--n_hidden', type=int, default=5)
parser.add_argument('--n_layers', type=int, default=2)
parser.add_argument('--max_epoch', type=int, default=20)
parser.add_argument('--optim_method', type=str, default="lstm")

args = parser.parse_args()

T = args.n_samplings  # number of random function samplings; the loss L is averaged over T.
n_unroll_in_m = args.n_unroll  # the optimizer is unrolled for this many steps (20 in the paper).
num_of_coordinates = n_dimension = args.n_dimension  # the dimension of the input data space
hidden_size = args.n_hidden  # number of hidden units in the RNN that generates the search direction at each step.
num_layers = args.n_layers  # number of LSTM layers (2 in the paper)
max_epoch = args.max_epoch

def get_inputs(n_dimension, n):
    # Sample a ground-truth parameter vector and n random quadratic problems W theta = y.
    theta_param = np.random.randn(n_dimension, 1)

    W_inputs = np.random.randn(n, n_dimension, n_dimension)
    y_inputs = np.zeros([n, n_dimension, 1])
    for i in range(n):
        y_inputs[i] = np.dot(W_inputs[i], theta_param)

    return W_inputs, y_inputs

def restore_trained_optimizer_variables():
    # Restore the graph saved by trainOptimizer.py and extract all of the model's
    # trainable parameters as numpy arrays, keyed by variable name, so they can
    # be assigned by name into the graph constructed in this script.
    variable_dict = {}
    g1 = tf.Graph()
    with g1.as_default():
        with tf.Session() as sess:
            saver = tf.train.import_meta_graph("./model-0.meta")
            saver.restore(sess, "./model-0")
            for var in tf.trainable_variables():
                print(var.name)   # the parameters inside the RNN
                print(var.eval())
                variable_dict[var.name] = var.eval()  # store the value as a numpy array
    return variable_dict


def build_optimizer_graph():
    ### BEGIN: GRAPH CONSTRUCTION ###
    grad_f = tf.placeholder(tf.float32, [n_dimension, 1])

    cell_list = []
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
    for i in range(num_of_coordinates):
        cell_list.append(tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * num_layers))  # num_layers = 2 according to the paper.
    batch_size = 1
    state_list = [cell_list[i].zero_state(batch_size, tf.float32) for i in range(num_of_coordinates)]
    sum_f = 0
    g_new_list = []
    for i in range(num_of_coordinates):
        cell = cell_list[i]; state = state_list[i]
        grad_h_t = tf.slice(grad_f, begin=[i, 0], size=[1, 1])

        for k in range(n_unroll_in_m):
            if k > 0: tf.get_variable_scope().reuse_variables()
            cell_output, state = cell(grad_h_t, state)
            softmax_w = tf.get_variable("softmax_w", [hidden_size, 1])
            softmax_b = tf.get_variable("softmax_b", [1])
            g_new_i = tf.matmul(cell_output, softmax_w) + softmax_b  # a scalar, because grad_h_t is a scalar

        g_new_list.append(g_new_i)
        # state_list[i] = state  # would carry the state forward, but this graph only generates a single update
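    # NOTE: all coordinates share the same LSTM parameters: every entry of
    # cell_list wraps the same BasicLSTMCell instance, and softmax_w/softmax_b
    # are fetched from one shared variable scope. This is the coordinatewise
    # architecture from the paper.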

    # Reshape g_new into a column vector
    g_new = tf.reshape(tf.squeeze(tf.pack(g_new_list)), [n_dimension, 1])  # an [n_dimension, 1] tensor


    return g_new, grad_f


def build_training_graph(method):
    n = T
    W = tf.placeholder(tf.float32, shape=[n, n_dimension, n_dimension])
    y = tf.placeholder(tf.float32, shape=[n, n_dimension, 1])
    theta = tf.Variable(tf.truncated_normal([n_dimension, 1]))
    if method == "lstm":
        g_new = tf.placeholder(tf.float32, shape=[n_dimension, 1])

    loss = 0
    for i in range(n):
        W_i = tf.reshape(tf.slice(W, begin=[i, 0, 0], size=[1, n_dimension, n_dimension]), [n_dimension, n_dimension])
        y_i = tf.reshape(tf.slice(y, begin=[i, 0, 0], size=[1, n_dimension, 1]), [n_dimension, 1])
        f = tf.reduce_sum(tf.square(tf.matmul(W_i, theta) - y_i))  # TODO: make this faster with batched tensor ops
        loss += f
    loss /= (n * n_dimension)

    f_grad = tf.gradients(loss, theta)[0]

    if method == "SGD":
        train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
        # train_op = tf.train.AdamOptimizer().minimize(loss)
        return loss, train_op, W, y

    if method == "lstm":
        new_value = tf.add(theta, g_new)
        train_op = tf.assign(theta, new_value)  # an assign op, to keep the interface compatible with the method == "SGD" case

        return loss, f_grad, train_op, g_new, W, y

def main():
    # variable name convention
    #   **_ph : placeholders
    #   **_op : nodes other than placeholders
    #   **_val: actual numpy values
    g = tf.Graph()
    with g.as_default():
        with tf.Session() as sess:
            if args.optim_method == "lstm":
                loss_op, f_grad_op, train_op, g_new_ph, W_ph, y_ph = build_training_graph(method=args.optim_method)
                g_op, f_grad_ph = build_optimizer_graph()
            elif args.optim_method == "SGD":
                loss_op, train_op, W_ph, y_ph = build_training_graph(method=args.optim_method)
            else:
                print("Unknown optim_method: choose 'lstm' or 'SGD'.")
                exit()

            sess.run(tf.initialize_all_variables())

            # Restore the trained optimizer in order to get its parameter values.
            # variable_dict = restore_trained_optimizer_variables()
            import pickle
            with open("variable_dict.pickle", "rb") as f:
                variable_dict = pickle.load(f)

            ## INITIALIZATION BEGIN ##
            for var in tf.trainable_variables():
                # Assign the saved numpy values to the matching variables in the current graph.
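                # (The variable names line up because build_optimizer_graph builds
                # the same RNN structure, in the same default scope, that
                # trainOptimizer.py used when it saved variable_dict.pickle.)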
                if var.name in variable_dict:
                    assign_op = var.assign(variable_dict[var.name])  # assign() takes a numpy array, e.g. var.assign(np.ones(12))
                    sess.run(assign_op)

            W_val, y_val = get_inputs(n_dimension, T)
            ## INITIALIZATION DONE ##

            cost_list = []
            if args.optim_method == "lstm":
                g_new_val = np.zeros([n_dimension, 1])
                for epoch in range(max_epoch):
                    loss_val, f_grad_val, _ = sess.run([loss_op, f_grad_op, train_op], feed_dict={g_new_ph: g_new_val, W_ph: W_val, y_ph: y_val})
                    g_new_val = sess.run(g_op, feed_dict={f_grad_ph: f_grad_val})
                    print(loss_val)
                    cost_list.append(loss_val)
            if args.optim_method == "SGD":
                for epoch in range(max_epoch):
                    loss_val, _ = sess.run([loss_op, train_op], feed_dict={W_ph: W_val, y_ph: y_val})
                    print(loss_val)
                    cost_list.append(loss_val)


            with open("cost_list_" + args.optim_method + ".pickle", "wb") as f:
                pickle.dump(cost_list, f)

            import matplotlib.pyplot as plt
            plt.plot(range(len(cost_list)), cost_list)
            imagename = "figure_" + args.optim_method + ".png"
            plt.savefig(imagename)

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/runall.sh:
--------------------------------------------------------------------------------
#!/bin/sh

n_samplings=10
n_unroll=20
n_dimension=3
n_hidden=5
n_layers=2

max_epoch=20

python3 trainOptimizer.py --n_samplings $n_samplings --n_unroll $n_unroll --n_dimension $n_dimension --n_hidden $n_hidden --n_layers $n_layers

python3 main.py --n_samplings $n_samplings --n_unroll $n_unroll --n_dimension $n_dimension --n_hidden $n_hidden --n_layers $n_layers --max_epoch $max_epoch
--------------------------------------------------------------------------------
/test/sample1/checkpoint:
--------------------------------------------------------------------------------
model_checkpoint_path: "model-0"
all_model_checkpoint_paths: "model-0"
--------------------------------------------------------------------------------
/test/sample1/cost_list_lstm.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/runopti/Learning-To-Learn/168e7d7a01775d139353633d0d5cea520909752f/test/sample1/cost_list_lstm.pickle
--------------------------------------------------------------------------------
/test/sample1/figure_lstm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/runopti/Learning-To-Learn/168e7d7a01775d139353633d0d5cea520909752f/test/sample1/figure_lstm.png
--------------------------------------------------------------------------------
/test/sample1/model-0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/runopti/Learning-To-Learn/168e7d7a01775d139353633d0d5cea520909752f/test/sample1/model-0
--------------------------------------------------------------------------------
/test/sample1/model-0.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/runopti/Learning-To-Learn/168e7d7a01775d139353633d0d5cea520909752f/test/sample1/model-0.meta
--------------------------------------------------------------------------------
/test/sample1/variable_dict.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/runopti/Learning-To-Learn/168e7d7a01775d139353633d0d5cea520909752f/test/sample1/variable_dict.pickle
--------------------------------------------------------------------------------
/test/sample2/checkpoint:
--------------------------------------------------------------------------------
model_checkpoint_path: "model-0"
all_model_checkpoint_paths: "model-0"
--------------------------------------------------------------------------------
/test/sample2/cost_list_lstm.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/runopti/Learning-To-Learn/168e7d7a01775d139353633d0d5cea520909752f/test/sample2/cost_list_lstm.pickle
--------------------------------------------------------------------------------
/test/sample2/figure_lstm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/runopti/Learning-To-Learn/168e7d7a01775d139353633d0d5cea520909752f/test/sample2/figure_lstm.png
--------------------------------------------------------------------------------
/test/sample2/model-0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/runopti/Learning-To-Learn/168e7d7a01775d139353633d0d5cea520909752f/test/sample2/model-0
--------------------------------------------------------------------------------
/test/sample2/model-0.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/runopti/Learning-To-Learn/168e7d7a01775d139353633d0d5cea520909752f/test/sample2/model-0.meta
--------------------------------------------------------------------------------
/test/sample2/variable_dict.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/runopti/Learning-To-Learn/168e7d7a01775d139353633d0d5cea520909752f/test/sample2/variable_dict.pickle
--------------------------------------------------------------------------------
/trainOptimizer.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf

import argparse
parser = argparse.ArgumentParser(description='Define parameters to learn optimizer.')
parser.add_argument('--n_samplings', type=int, default=10)
parser.add_argument('--n_unroll', type=int, default=20)
parser.add_argument('--n_dimension', type=int, default=3)
parser.add_argument('--n_hidden', type=int, default=5)
parser.add_argument('--n_layers', type=int, default=2)

args = parser.parse_args()

T = args.n_samplings  # number of random function samplings; the loss L is averaged over T.
n_unroll_in_m = args.n_unroll  # the optimizer is unrolled for this many steps (20 in the paper).
num_of_coordinates = n_dimension = args.n_dimension  # the dimension of the input data space
hidden_size = args.n_hidden  # number of hidden units in the RNN that generates the search direction at each step.
num_layers = args.n_layers  # number of LSTM layers (2 in the paper)


def trainOptimizer():
    g = tf.Graph()
    ### BEGIN: GRAPH CONSTRUCTION ###
    with g.as_default():
        cell_list = []
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
        for i in range(num_of_coordinates):
            cell_list.append(tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * num_layers))  # num_layers = 2 according to the paper.

        loss = 0
        for t in range(T):
            # random sampling of one instance of the quadratic function
            W = tf.truncated_normal([n_dimension, n_dimension]); y = tf.truncated_normal([n_dimension, 1])
            theta = tf.truncated_normal([n_dimension, 1])
            f = tf.reduce_sum(tf.square(tf.matmul(W, theta) - y))
            batch_size = 1
            state_list = [cell_list[i].zero_state(batch_size, tf.float32) for i in range(num_of_coordinates)]
            sum_f = 0
            g_new_list = []
            grad_f = tf.gradients(f, theta)[0]
            for i in range(num_of_coordinates):
                cell = cell_list[i]; state = state_list[i]
                grad_h_t = tf.slice(grad_f, begin=[i, 0], size=[1, 1])
                for k in range(n_unroll_in_m):
                    if k > 0: tf.get_variable_scope().reuse_variables()
                    cell_output, state = cell(grad_h_t, state)
                    softmax_w = tf.get_variable("softmax_w", [hidden_size, 1])
                    softmax_b = tf.get_variable("softmax_b", [1])
                    g_new_i = tf.matmul(cell_output, softmax_w) + softmax_b  # a scalar, because grad_h_t is a scalar

                g_new_list.append(g_new_i)
                state_list[i] = state  # for the next t

            # update the parameter
            g_new = tf.reshape(tf.squeeze(tf.pack(g_new_list)), [n_dimension, 1])  # an [n_dimension, 1] tensor
            theta = tf.add(theta, g_new)

            f_at_theta_t = tf.reduce_sum(tf.square(tf.matmul(W, theta) - y))
            sum_f = sum_f + f_at_theta_t

            loss += sum_f

        loss = loss / T

        tvars = tf.trainable_variables()  # the RNN parameters plus softmax_w and softmax_b
        grads = tf.gradients(loss, tvars)
        lr = 0.001  # ideally this would be tuned by random search
        optimizer = tf.train.AdamOptimizer(lr)
        train_op = optimizer.apply_gradients(zip(grads, tvars))

        ### END: GRAPH CONSTRUCTION ###

        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())
            max_epoch = 100
            for epoch in range(max_epoch):
                cost, _ = sess.run([loss, train_op])
                print("Epoch %d : loss %f" % (epoch, cost))

            print("Saving the trained model...")
            saver = tf.train.Saver()
            saver.save(sess, "model", global_step=0)

            import pickle
            import time
            print("Extracting variables...")
            now = time.time()
            variable_dict = {}
            for var in tf.trainable_variables():
                print(var.name)
                print(var.eval())
                variable_dict[var.name] = var.eval()
            print("elapsed time: {0}".format(time.time() - now))
            with open("variable_dict.pickle", "wb") as f:
                pickle.dump(variable_dict, f)


if __name__ == "__main__":
    trainOptimizer()
--------------------------------------------------------------------------------