├── .gitignore ├── LICENSE ├── README.md ├── README.md~ ├── SIN_problem.pdf ├── S_LBFGS.py ├── S_LSR1.py ├── _saved_log_files └── S_LSR1.pkl ├── data_generation.py ├── main.py ├── network.py ├── parameters.py ├── sampleSY.py └── util_func.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 
| 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Optimization and Machine Learning Group @ Lehigh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SQN: Sampled Quasi-Newton Methods for Deep Learning 2 | 3 | Authors: [Albert S. Berahas](https://sites.google.com/a/u.northwestern.edu/albertsberahas/home), [Majid Jahani](http://coral.ise.lehigh.edu/maj316/) and [Martin Takáč](http://mtakac.com/) 4 | 5 | Please contact us if you have any questions, suggestions, requests or bug-reports. 
6 | 7 | ## Introduction 8 | This is a Python software package for solving a toy classification problem using neural networks. More specifically, the user can select one of two methods: 9 | - **sampled LBFGS (S-LBFGS)**, 10 | - **sampled LSR1 (S-LSR1)**, 11 | 12 | to solve the problem described below. See [paper](https://arxiv.org/abs/1901.09997) for details. 13 | 14 | Note that the code can be extended to solve other deep learning problems (see comments below). 15 | 16 | ## Problem 17 | Consider the following simple classification problem, illustrated in the figure below, consisting of two classes each with 50 data points. We call this the **sin classification problem**. We trained a small fully connected neural network with sigmoid activation functions and 4 hidden layers with 2 nodes in each layer. 18 | 19 | 20 | 21 | ## Citation 22 | If you use SQN for your research, please cite: 23 | 24 | ``` 25 | @article{berahas2019quasi, 26 | title={Quasi-Newton Methods for Deep Learning: Forget the Past, Just Sample}, 27 | author={Berahas, Albert S and Jahani, Majid and Tak{\'a}{\v{c}}, Martin}, 28 | journal={arXiv preprint arXiv:1901.09997}, 29 | year={2019} 30 | } 31 | ``` 32 | 33 | ## Usage Guide 34 | The algorithms can be run using the syntax: ``` python main.py method``` 35 | where ```method = SLBFGS``` or ```method = SLSR1```. 36 | 37 | By default, the code runs on a single GPU. 38 | 39 | ### Dependencies 40 | * NumPy 41 | * [TensorFlow](https://www.tensorflow.org/)>=1.2 42 | 43 | ### Parameters 44 | In this section we describe all the parameters needed to run the methods on the **sin classification problem**. The list of all parameters is available in the ``parameters.py`` file.
45 | 46 | The parameters for the **problem** are: 47 | - ```num_pts```: Number of data points (per class) (default ```num_pts = 50```) 48 | - ```freq```: Frequency (default ```freq = 8```) 49 | - ```offset```: Offset (default ```offset = 0.8```) 50 | - ```activation```: Activation (default ```activation = "sigmoid"```) 51 | - ```FC1```, ```FC2```,..., ```FC6```: Network size (number of nodes in each layer) (default: all equal to ```2```) 52 | 53 | All these parameters can be changed in ```parameters.py```. Note that ```FC1``` and ```FC6``` should both be equal to ```2```, since they are the input and output sizes. 54 | 55 | The hyperparameters for the **methods** are: 56 | - ```seed```: Random seed (default ```seed = 67```) 57 | - ```numIter```: Maximum number of iterations (default ```numIter = 1000```) 58 | - ```mmr```: Memory length (default ```mmr = 10```) 59 | - ```radius```: Sampling radius (default ```radius = 1```) 60 | - ```eps```: Tolerance for updating QN matrices (default ```eps = 1e-8```) 61 | - ```eta```: TR step-acceptance threshold; a step is accepted when the ratio of actual to predicted reduction exceeds ```eta``` (default ```eta = 1e-6```) 62 | - ```delta_init```: Initial trust region radius (default ```delta_init = 1```) 63 | - ```alpha_init```: Initial step length (default ```alpha_init = 1```) 64 | - ```epsTR```: Tolerance of CG Steinhaug (default ```epsTR = 1e-10```) 65 | - ```cArmijo```: Armijo sufficient decrease parameter (default ```cArmijo = 1e-4```) 66 | - ```rhoArmijo```: Armijo backtracking factor (default ```rhoArmijo = 0.5```) 67 | - ```init_sampling_SLBFGS```: Whether **S-LBFGS** samples curvature pairs at the first iteration; if ```"off"```, the first step is a gradient step with step length ```min(1, 1/||g||_1)``` (default ```init_sampling_SLBFGS = "on"```) 68 | 69 | All these parameters can be changed in ```parameters.py```. 70 | 71 | ### Functions 72 | In this section we describe all the functions needed to run the code. For both methods: 73 | - ```main.py```: This is the main file that runs the code for both methods.
 For each method, it (1) gets the required input parameters (```parameters.py```), (2) gets the data for the **sin classification problem** (```data_generation.py```), (3) constructs the neural network (```network.py```), and (4) runs the method (```S_LBFGS.py``` or ```S_LSR1.py```). 74 | - ```parameters.py```: Sets all the parameters. 75 | - ```data_generation.py```: Generates the data. 76 | - ```network.py```: Constructs the neural network (function, gradient, Hessian and Hessian-vector products). 77 | - ```S_LBFGS.py```, ```S_LSR1.py```: Runs the **S-LBFGS** and **S-LSR1** methods, respectively. 78 | 79 | Each method has several method-specific functions. For **S-LBFGS**: 80 | - ```L_BFGS_two_loop_recursion```: LBFGS two-loop recursion for computing the search direction. 81 | - ```sample_pairs_SY_SLBFGS```: Function for computing ```S```, ```Y``` curvature pairs. 82 | 83 | For **S-LSR1**: 84 | - ```CG_Steinhaug_matFree```: CG Steinhaug method for solving the TR subproblem and computing the search direction. 85 | - ```rootFinder```: Root finder subroutine used in the CG Steinhaug method. 86 | - ```sample_pairs_SY_SLSR1```: Function for computing ```S```, ```Y``` curvature pairs. 87 | 88 | The ```sample_pairs_SY_SLBFGS``` and ```sample_pairs_SY_SLSR1``` functions are in the ```sampleSY.py``` file. The rest of the functions are in ```util_func.py```. 89 | 90 | ### Logs & Printing 91 | All logs are stored in ``.pkl`` files in the ``./_saved_log_files`` directory. The following quantities are logged and printed at every iteration: 92 | - Iteration counter, 93 | - Function value, 94 | - Accuracy, 95 | - Norm of the gradient, 96 | - Number of function evaluations, 97 | - Number of gradient evaluations, 98 | - Number of Hessian-vector products, 99 | - Total cost (# function evaluations + # gradient evaluations + # Hessian-vector products) 100 | - Elapsed time, 101 | - Step length (**S-LBFGS**) and TR radius (**S-LSR1**). For **S-LSR1**, the counter ```counterSucc``` returned by ```sample_pairs_SY_SLSR1``` is also logged, between the total cost and the elapsed time.
102 | 103 | ## Example 104 | 105 | Here, we provide the commands for running the two methods, and the output for the first 10 iterations of both methods. We then describe how one could use our code to solve different problems. 106 | 107 | ### Sampled LBFGS (S-LBFGS) 108 | 109 | To run the **S-LBFGS** method the syntax is: ```python main.py SLBFGS``` 110 | 111 | The output of the first 10 iterations is: 112 | ``` 113 | [0, 0.7568820772024124, 0.5, 0.3164452574498637, 1, 1, 0, 2, 0.03657197952270508, 2] 114 | [1, 0.6956258550774328, 0.5, 0.0674561314351428, 2, 2, 0, 4, 0.2941138744354248, 1] 115 | [2, 0.6928303728452202, 0.5, 0.009572266947966944, 4, 3, 1, 8, 0.7679529190063477, 1.0] 116 | [3, 0.692176991735801, 0.5, 0.013195938306183187, 6, 4, 2, 12, 1.0290610790252686, 1.0] 117 | [4, 0.6910937151406233, 0.5, 0.06540600186566126, 8, 5, 3, 16, 1.241429090499878, 1.0] 118 | [5, 0.6869002623690355, 0.5, 0.02262611533759857, 12, 6, 4, 22, 1.6189680099487305, 0.25] 119 | [6, 0.6778780169034967, 0.5, 0.07594391958980883, 23, 7, 5, 35, 2.000309944152832, 0.00048828125] 120 | [7, 0.677836543566537, 0.5, 0.07582481736644729, 24, 8, 6, 38, 2.510093927383423, 0.0009765625] 121 | [8, 0.6770420397015698, 0.5, 0.07593946324726104, 25, 9, 7, 41, 2.8690669536590576, 0.001953125] 122 | [9, 0.6769673131763045, 0.5, 0.07570366951301735, 26, 10, 8, 44, 3.1529359817504883, 0.00390625] 123 | [10, 0.6762432711651389, 0.5, 0.07603773166566419, 27, 11, 9, 47, 3.436460018157959, 0.0078125] 124 | ``` 125 | 126 | ### Sampled LSR1 (S-LSR1) 127 | 128 | To run the **S-LSR1** method the syntax is: ```python main.py SLSR1``` 129 | 130 | The output of the first 10 iterations is: 131 | ``` 132 | [0, 0.7568820772024124, 0.5, 0.3164452574498637, 1, 1, 1, 3, 0.48604512214660645, 1] 133 | [1, 0.6952455337961233, 0.5, 0.07691431196154719, 2, 2, 2, 6, 0.7635290622711182, 2] 134 | [2, 0.6795686674317041, 0.5, 0.09084038468631689, 3, 3, 3, 9, 1.0299029350280762, 2] 135 | [3, 0.6305025883514083, 0.6, 
0.20005533664635822, 4, 4, 4, 12, 1.2226409912109375, 4] 136 | [4, 0.6305025883514083, 0.6, 0.20005533664635822, 5, 5, 5, 15, 1.6073269844055176, 2.0] 137 | [5, 0.6305025883514083, 0.6, 0.20005533664635822, 6, 6, 6, 18, 1.7895889282226562, 1.0] 138 | [6, 0.577640123436474, 0.82, 0.21616021189282156, 7, 7, 7, 21, 1.9722239971160889, 1.0] 139 | [7, 0.49736796460679933, 0.78, 0.14993963430571358, 8, 8, 8, 24, 2.1734681129455566, 1.0] 140 | [8, 0.42248074404077923, 0.87, 0.11864276500865208, 9, 9, 9, 27, 2.385266065597534, 2.0] 141 | [9, 0.33875392059040343, 0.83, 0.07470337441644294, 10, 10, 10, 30, 2.5926640033721924, 4.0] 142 | [10, 0.33875392059040343, 0.83, 0.07470337441644294, 11, 11, 11, 33, 2.7668380737304688, 2.0] 143 | ``` 144 | 145 | ### Other problems 146 | 147 | To run the **S-LBFGS** and **S-LSR1** methods on different problems, a user must modify a few things: (1) the parameters of the neural network (network size in ```parameters.py```), (2) the data (in ```data_generation.py```), and (3) the network (in ```network.py```). More specifically, a user needs to construct or load their own data in ```main.py``` and then define a neural network in the ```DNN``` class in ```network.py```. That class also computes the Hessian (```H```), the gradient (```G```), and the Hessian-matrix product (```Hvs```) of the new function; it also contains the operators for updating (```updateOp```), assigning (```ASSIGN_OP```) and adding to (```assign_add```) the parameters of the new network, which adapt automatically to the new setting. 148 | 149 | 150 | 151 | If users have any issues, please contact us. 152 | 153 | ## Paper 154 | [Quasi-Newton Methods for Deep Learning: Forget the Past, Just Sample](https://arxiv.org/abs/1901.09997).
155 | 156 | -------------------------------------------------------------------------------- /README.md~: -------------------------------------------------------------------------------- 1 | # SQN 2 | Sampled Quasi-Newton Methods for Deep Learning 3 | 4 | ## Paper 5 | Implementation of our paper: [Quasi-Newton Methods for Deep Learning: Forget the Past, Just Sample](https://arxiv.org/abs/1901.09997). 6 | 7 | 8 | ## Dependencies 9 | * NumPy 10 | * [TensorFlow](https://www.tensorflow.org/)>=1.2 11 | 12 | ## How to Run 13 | ### Train 14 | By default, the code runs in training mode on a single GPU. To run the code, use the following command, where method is SLBFGS or SLSR1: 15 | ```bash 16 | python main.py method 17 | ``` 18 | 19 | 20 | The list of specific parameters is available in the ``parameters.py`` file. 21 | 22 | 23 | ### Logs 24 | All logs are stored in ``.pkl`` files in the ``./_saved_log_files`` directory. 25 | -------------------------------------------------------------------------------- /SIN_problem.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OptMLGroup/SQN/1fd9745b8221282f44288d55cdeaf777129381cd/SIN_problem.pdf -------------------------------------------------------------------------------- /S_LBFGS.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright (C) 2019 Albert Berahas, Majid Jahani, Martin Takáč 4 | # 5 | # All Rights Reserved. 6 | # 7 | # Authors: Albert Berahas, Majid Jahani, Martin Takáč 8 | # 9 | # Please cite: 10 | # 11 | # A. S. Berahas, M. Jahani, and M. Takáč, "Quasi-Newton Methods for 12 | # Deep Learning: Forget the Past, Just Sample." (2019). Lehigh University.
13 | # http://arxiv.org/abs/1901.09997 14 | # ========================================================================== 15 | 16 | import numpy as np 17 | import pickle 18 | import os.path 19 | import os 20 | import sys 21 | import tensorflow as tf 22 | import time 23 | from util_func import * 24 | from network import * 25 | from data_generation import * 26 | from sampleSY import * 27 | 28 | # ========================================================================== 29 | def S_LBFGS(w_init,X,y,seed,numIter,mmr,radius,eps,alpha_init,cArmijo,rhoArmijo,num_weights,init_sampling_SLBFGS,dnn,sess): 30 | """Sampled LBFGS (S-LBFGS) method. Each iteration either takes a gradient step (only at k == 0 when init_sampling_SLBFGS is 'off', with step length min(1, 1/||g||_1)) or samples curvature pairs (S, Y) with sample_pairs_SY_SLBFGS and builds the search direction pk with L_BFGS_two_loop_recursion, starting the line search from twice the previous step length. The weights are updated as w - alpha*pk with Armijo backtracking (sufficient-decrease constant cArmijo, backtracking factor rhoArmijo). Stops once k > numIter or the training accuracy equals 1, then pickles the per-iteration HISTORY to ./_saved_log_files/S_LBFGS.pkl. Returns None; the final weights are left in dnn.params. w_init: initial weight vector assigned to dnn.params; X, y: data fed to dnn.x / dnn.y; dnn, sess: network object and TensorFlow session; the remaining arguments are the hyperparameters described in parameters.py.""" 31 | 32 | w = w_init 33 | sess.run(dnn.params.assign(w)) # Assign initial weights to parameters of the network 34 | np.random.seed(seed) # Set random seed 35 | 36 | print(seed) # NOTE(review): appears to be a leftover debug print of the seed 37 | numFunEval = 0 # Initialize counters (function values, gradients and Hessians) 38 | numGradEval = 0 39 | numHessEval = 0 40 | 41 | gamma_k = 1 # NOTE(review): never used; gamma_k is always overwritten by sample_pairs_SY_SLBFGS before L_BFGS_two_loop_recursion reads it 42 | 43 | g_kTemp, objFunOldTemp = sess.run( [dnn.G,[dnn.cross_entropy,dnn.accuracy]] , feed_dict={dnn.x: X, dnn.y:y}) # Initial gradient, loss and accuracy on (X, y) 44 | numFunEval += 1 45 | numGradEval += 1 46 | objFunOld = objFunOldTemp[0] 47 | acc = objFunOldTemp[1] 48 | g_k = g_kTemp[0] # dnn.G presumably evaluates to a list whose first entry is the flat gradient — confirm in network.py 49 | norm_g = LA.norm( g_k ) # NOTE(review): LA is not imported explicitly here; presumably numpy.linalg via one of the star imports — confirm 50 | 51 | HISTORY = [] # Per-iteration log rows 52 | weights_SLBFGS = [] # Weights at every iteration (only saved if the commented-out dump at the end is enabled) 53 | 54 | k=0 # Iteration counter 55 | st=time.time() # Start the timer 56 | 57 | alpha = alpha_init 58 | 59 | while 1: 60 | 61 | weights_SLBFGS.append(sess.run(dnn.params)) 62 | # Row: [k, loss, accuracy, ||g||, #function evals, #gradient evals, #Hessian-vector products, total cost, elapsed time, alpha] 63 | HISTORY.append([k, objFunOld,acc,norm_g, numFunEval,numGradEval,numHessEval, numFunEval+numGradEval+numHessEval, 64 | time.time()-st,alpha]) 65 | 66 | print HISTORY[k] # Print History array 67 | 68 | if k > numIter or acc ==1: # Terminate if number of iterations > numIter or Accuracy = 1 (NOTE(review): because the test is ">", numIter+1 steps are taken) 69 | break 70 | 71 | if init_sampling_SLBFGS == "off" and k == 0: # Without initial sampling, the first step is a plain gradient step 72 | alpha = min(1,1.0/(np.linalg.norm(g_k, ord=1))) # alpha = min(1, 1/||g_k||_1) 73 | pk = g_k 74 | else: 75 | S,Y,counterSucc,numHessEval,gamma_k = sample_pairs_SY_SLBFGS(X,y,num_weights,mmr,radius,eps,dnn,numHessEval,sess) # Sample curvature pairs (counterSucc is not logged by this method) 76 | pk = 
L_BFGS_two_loop_recursion(g_k,S,Y,k,mmr,gamma_k,num_weights) 77 | alpha = 2*alpha # Start the line search from twice the previous step length 78 | 79 | mArmijo = -(pk.T.dot(g_k)) # g_k^T d for the step direction d = -pk applied below 80 | 81 | x0 = sess.run(dnn.params) # Current weights, restored whenever a trial step is rejected 82 | while 1: 83 | # params is the updated variable by adding -alpha* pk to the previous one 84 | sess.run(dnn.updateOp, feed_dict={dnn.updateVal: -alpha* pk }) 85 | 86 | objFunNew = sess.run(dnn.cross_entropy, feed_dict={dnn.x: X, dnn.y:y}) 87 | numFunEval += 1 88 | if objFunOld + alpha*cArmijo* mArmijo < objFunNew : # Armijo sufficient-decrease test failed: restore x0 and backtrack 89 | sess.run(dnn.ASSIGN_OP, feed_dict={dnn.updateVal: x0}) 90 | alpha = alpha * rhoArmijo 91 | if alpha < 1e-25: 92 | print "issue with Armijo" 93 | break # NOTE(review): on this exit the weights are back at x0, but objFunNew still holds the rejected trial value and is copied into objFunOld below 94 | else: 95 | break # Sufficient decrease achieved: keep the trial point 96 | objFunOld = objFunNew 97 | 98 | xNew, acc, g_k_newTemp = sess.run( [dnn.params,dnn.accuracy, dnn.G] , feed_dict={dnn.x: X, dnn.y:y}) 99 | numGradEval += 1 100 | g_k = g_k_newTemp[0] 101 | norm_g = LA.norm( g_k ) 102 | k += 1 103 | 104 | sess.run(dnn.ASSIGN_OP, feed_dict={dnn.updateVal: xNew}) # NOTE(review): xNew was just read from dnn.params, so this re-assignment looks redundant 105 | 106 | pickle.dump( HISTORY, open( "./_saved_log_files/S_LBFGS.pkl", "wb" ) ) # Save History in .pkl file (NOTE(review): the file handle is never closed explicitly; a with-block would be safer) 107 | # pickle.dump( weights_SLBFGS, open( "./_saved_log_files/S_LBFGS_weights.pkl", "wb" ) ) # Save Weights in .pkl file 108 | -------------------------------------------------------------------------------- /S_LSR1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright (C) 2019 Albert Berahas, Majid Jahani, Martin Takáč 4 | # 5 | # All Rights Reserved. 6 | # 7 | # Authors: Albert Berahas, Majid Jahani, Martin Takáč 8 | # 9 | # Please cite: 10 | # 11 | # A. S. Berahas, M. Jahani, and M. Takáč, "Quasi-Newton Methods for 12 | # Deep Learning: Forget the Past, Just Sample." (2019). Lehigh University.
13 | # http://arxiv.org/abs/1901.09997 14 | # ========================================================================== 15 | 16 | import numpy as np 17 | import pickle 18 | import os.path 19 | import os 20 | import sys 21 | import tensorflow as tf 22 | import time 23 | from util_func import * 24 | from network import * 25 | from data_generation import * 26 | from sampleSY import * 27 | 28 | # ========================================================================== 29 | def S_LSR1(w_init,X,y,seed,numIter,mmr,radius,eps,eta,delta_init,epsTR,num_weights,dnn,sess): 30 | """Sampled LSR1 (S-LSR1) trust-region method. Each iteration samples curvature pairs (S, Y) with sample_pairs_SY_SLSR1 and computes a trial step with CG_Steinhaug_matFree for the trust-region radius deltak. The step is accepted when the ratio of actual to predicted reduction exceeds eta. The radius is doubled when that ratio is > 0.75, kept when it lies in [0.1, 0.75], and halved when it is < 0.1. Stops once k > numIter or the training accuracy equals 1, then pickles the per-iteration HISTORY to ./_saved_log_files/S_LSR1.pkl. Returns None; the final weights are left in dnn.params. w_init: initial weight vector assigned to dnn.params; X, y: data fed to dnn.x / dnn.y; dnn, sess: network object and TensorFlow session; the remaining arguments are the hyperparameters described in parameters.py.""" 31 | 32 | w = w_init 33 | sess.run(dnn.params.assign(w)) # Assign initial weights to parameters of the network 34 | np.random.seed(seed) # Set random seed 35 | 36 | numFunEval = 0 # Initialize counters (function values, gradients and Hessians) 37 | numGradEval = 0 38 | numHessEval = 0 39 | 40 | deltak = delta_init # Initialize trust region radius 41 | 42 | HISTORY = [] # Initialize array for storage 43 | weights_SLSR1 = [] # Initialize array for storing weights 44 | 45 | k=0 # Initialize iteration counter 46 | st = time.time() # Start the timer 47 | 48 | objFunOld = sess.run(dnn.cross_entropy,feed_dict={dnn.x: X, dnn.y:y}) # Compute function value at current iterate 49 | numFunEval += 1 50 | 51 | print objFunOld # Print the initial function value 52 | 53 | # Method while loop (terminate after numIter or Accuracy 1 achieved) 54 | while 1: 55 | gradTemp, acc, xOld = sess.run([dnn.G,dnn.accuracy,dnn.params], 56 | feed_dict={dnn.x: X, dnn.y:y}) # Compute gradient and accuracy 57 | gard_k = gradTemp[0] # Gradient at xOld (NOTE(review): 'gard_k' is presumably a typo for 'grad_k'; dnn.G presumably evaluates to a list whose first entry is the flat gradient — confirm in network.py) 58 | numGradEval += 1 59 | norm_g = LA.norm(gard_k) # NOTE(review): LA is not imported explicitly here; presumably numpy.linalg via one of the star imports — confirm 60 | 61 | # Sample S, Y pairs 62 | S,Y,counterSucc,numHessEval = sample_pairs_SY_SLSR1(X,y,num_weights,mmr,radius,eps,dnn,numHessEval,sess) 63 | 64 | # Append to History array: [k, loss, accuracy, ||g||, #function evals, #gradient evals, #Hessian-vector products, total cost, counterSucc, elapsed time, TR radius] 65 | HISTORY.append([k, objFunOld,acc,norm_g,numFunEval,numGradEval,numHessEval,numFunEval+numGradEval+numHessEval, 66 | counterSucc,time.time()-st,deltak]) 67 | print HISTORY[k] # Print
History array 68 | 69 | if k > numIter or acc ==1: # Terminate if number of iterations > numIter or Accuracy = 1 (NOTE(review): because the test is ">", numIter+1 steps are taken) 70 | break 71 | 72 | weights_SLSR1.append(sess.run(dnn.params)) # Append weights 73 | 74 | 75 | sk_TR = CG_Steinhaug_matFree(epsTR, gard_k , deltak,S,Y,num_weights) # Compute step using CG Steinhaug 76 | sess.run(dnn.ASSIGN_OP, feed_dict={dnn.updateVal: xOld + sk_TR }) # Move to the trial point xOld + sk_TR to evaluate its loss 77 | 78 | objFunNew = sess.run(dnn.cross_entropy, feed_dict={dnn.x: X, dnn.y:y}) # Compute new function value 79 | numFunEval += 1 80 | 81 | ared = objFunOld - objFunNew # Compute actual reduction 82 | # Predicted reduction uses B_k s = Y M^{-1} Y^T s with M = D + L + L^T (D = diag(S^T Y), L = strictly lower triangle of S^T Y), i.e. the compact limited-memory SR1 form with initial matrix B0 = 0 83 | Lp = np.zeros((Y.shape[1],Y.shape[1])) 84 | for ii in xrange(Y.shape[1]): # NOTE(review): xrange is Python 2 only; the loop builds np.tril(S.T.dot(Y), -1) 85 | for jj in range(0,ii): 86 | Lp[ii,jj] = S[:,ii].dot(Y[:,jj]) 87 | tmpp = np.sum((S * Y),axis=0) 88 | Dp = np.diag(tmpp) 89 | Mp = (Dp + Lp + Lp.T) 90 | Minvp = np.linalg.inv(Mp) # NOTE(review): raises LinAlgError if Mp is singular; np.linalg.solve(Mp, tmpp1) would avoid forming the inverse 91 | tmpp1 = np.matmul(Y.T,sk_TR) 92 | tmpp2 = np.matmul(Minvp,tmpp1) 93 | Bk_skTR = np.matmul(Y,tmpp2) 94 | pred = -(gard_k.T.dot(sk_TR) + 0.5* sk_TR.T.dot(Bk_skTR)) # Compute predicted reduction 95 | 96 | # Take step 97 | if ared/pred > eta: # NOTE(review): pred is never checked for being positive; a NaN ratio rejects the step and leaves deltak unchanged 98 | xNew = xOld + sk_TR 99 | objFunOld = objFunNew 100 | else: 101 | xNew = xOld 102 | 103 | # Update trust region radius 104 | if ared/pred > 0.75: 105 | deltak = 2*deltak 106 | elif ared/pred>=0.1 and ared/pred <=0.75: 107 | pass # no need to change deltak 108 | elif ared/pred<0.1: 109 | deltak = deltak*0.5 110 | 111 | k += 1 # Increment iteration counter 112 | sess.run(dnn.ASSIGN_OP, feed_dict={dnn.updateVal: xNew}) # Commit the accepted trial point, or roll back to xOld if the step was rejected 113 | 114 | pickle.dump( HISTORY, open( "./_saved_log_files/S_LSR1.pkl", "wb" ) ) # Save History in .pkl file (NOTE(review): the file handle is never closed explicitly; a with-block would be safer) 115 | # pickle.dump( weights_SLSR1, open( "./_saved_log_files/S_LSR1_weights.pkl", "wb" ) ) # Save Weights in .pkl file 116 | -------------------------------------------------------------------------------- /_saved_log_files/S_LSR1.pkl: -------------------------------------------------------------------------------- 1 | (lp0 2 
| (lp1 3 | I0 4 | acnumpy.core.multiarray 5 | scalar 6 | p2 7 | (cnumpy 8 | dtype 9 | p3 10 | (S'f8' 11 | p4 12 | I0 13 | I1 14 | tp5 15 | Rp6 16 | (I3 17 | S'<' 18 | p7 19 | NNNI-1 20 | I-1 21 | I0 22 | tp8 23 | bS'\xbfi\x10\xc3`8\xe8?' 24 | p9 25 | tp10 26 | Rp11 27 | ag2 28 | (g3 29 | (S'f4' 30 | p12 31 | I0 32 | I1 33 | tp13 34 | Rp14 35 | (I3 36 | S'<' 37 | p15 38 | NNNI-1 39 | I-1 40 | I0 41 | tp16 42 | bS'\x00\x00\x00?' 43 | p17 44 | tp18 45 | Rp19 46 | ag2 47 | (g6 48 | S'{,\xee\x9b\xa3@\xd4?' 49 | p20 50 | tp21 51 | Rp22 52 | aI1 53 | aI1 54 | aI1 55 | aI3 56 | aI10 57 | aF0.5641310214996338 58 | aI1 59 | aa(lp23 60 | I1 61 | ag2 62 | (g6 63 | S'm\x05\xcb\x8fs?\xe6?' 64 | p24 65 | tp25 66 | Rp26 67 | ag2 68 | (g14 69 | S'\x00\x00\x00?' 70 | p27 71 | tp28 72 | Rp29 73 | ag2 74 | (g6 75 | S'\xa0\x1cx\x06\xa8\xb0\xb3?' 76 | p30 77 | tp31 78 | Rp32 79 | aI2 80 | aI2 81 | aI2 82 | aI6 83 | aI10 84 | aF0.8752388954162598 85 | aI2 86 | aa(lp33 87 | I2 88 | ag2 89 | (g6 90 | S'\xce,@\xca\x06\xbf\xe5?' 91 | p34 92 | tp35 93 | Rp36 94 | ag2 95 | (g14 96 | S'\x00\x00\x00?' 97 | p37 98 | tp38 99 | Rp39 100 | ag2 101 | (g6 102 | S'\x14@b\xc1PA\xb7?' 103 | p40 104 | tp41 105 | Rp42 106 | aI3 107 | aI3 108 | aI3 109 | aI9 110 | aI10 111 | aF1.0731980800628662 112 | aI2 113 | aa(lp43 114 | I3 115 | ag2 116 | (g6 117 | S'\x9fg\xa0\xc3\x13-\xe4?' 118 | p44 119 | tp45 120 | Rp46 121 | ag2 122 | (g14 123 | S'\x9a\x99\x19?' 124 | p47 125 | tp48 126 | Rp49 127 | ag2 128 | (g6 129 | S'\x10\xa8$\xcci\x9b\xc9?' 130 | p50 131 | tp51 132 | Rp52 133 | aI4 134 | aI4 135 | aI4 136 | aI12 137 | aI10 138 | aF1.307805061340332 139 | aI4 140 | aa(lp53 141 | I4 142 | ag46 143 | ag2 144 | (g14 145 | S'\x9a\x99\x19?' 146 | p54 147 | tp55 148 | Rp56 149 | ag2 150 | (g6 151 | S'\x10\xa8$\xcci\x9b\xc9?' 
152 | p57 153 | tp58 154 | Rp59 155 | aI5 156 | aI5 157 | aI5 158 | aI15 159 | aI10 160 | aF1.4928619861602783 161 | aF2.0 162 | aa(lp60 163 | I5 164 | ag46 165 | ag2 166 | (g14 167 | S'\x9a\x99\x19?' 168 | p61 169 | tp62 170 | Rp63 171 | ag2 172 | (g6 173 | S'\x10\xa8$\xcci\x9b\xc9?' 174 | p64 175 | tp65 176 | Rp66 177 | aI6 178 | aI6 179 | aI6 180 | aI18 181 | aI10 182 | aF1.6996049880981445 183 | aF1.0 184 | aa(lp67 185 | I6 186 | ag2 187 | (g6 188 | S'\xbf\x8b\xe0#\x07|\xe2?' 189 | p68 190 | tp69 191 | Rp70 192 | ag2 193 | (g14 194 | S'\x85\xebQ?' 195 | p71 196 | tp72 197 | Rp73 198 | ag2 199 | (g6 200 | S'5WcH#\xab\xcb?' 201 | p74 202 | tp75 203 | Rp76 204 | aI7 205 | aI7 206 | aI7 207 | aI21 208 | aI10 209 | aF1.8986470699310303 210 | aF1.0 211 | aa(lp77 212 | I7 213 | ag2 214 | (g6 215 | S'N\x1d\x84q\xe0\xd4\xdf?' 216 | p78 217 | tp79 218 | Rp80 219 | ag2 220 | (g14 221 | S'\x14\xaeG?' 222 | p81 223 | tp82 224 | Rp83 225 | ag2 226 | (g6 227 | S'\x81\xce\xdb\xd081\xc3?' 228 | p84 229 | tp85 230 | Rp86 231 | aI8 232 | aI8 233 | aI8 234 | aI24 235 | aI10 236 | aF2.1121370792388916 237 | aF1.0 238 | aa(lp87 239 | I8 240 | ag2 241 | (g6 242 | S'\xbd\x12\xb6\xac\xec\t\xdb?' 243 | p88 244 | tp89 245 | Rp90 246 | ag2 247 | (g14 248 | S'R\xb8^?' 249 | p91 250 | tp92 251 | Rp93 252 | ag2 253 | (g6 254 | S'-\x82\x9eK__\xbe?' 255 | p94 256 | tp95 257 | Rp96 258 | aI9 259 | aI9 260 | aI9 261 | aI27 262 | aI10 263 | aF2.3221309185028076 264 | aF2.0 265 | aa(lp97 266 | I9 267 | ag2 268 | (g6 269 | S'\xce\xf6\x94\xec$\xae\xd5?' 270 | p98 271 | tp99 272 | Rp100 273 | ag2 274 | (g14 275 | S'\xe1zT?' 276 | p101 277 | tp102 278 | Rp103 279 | ag2 280 | (g6 281 | S'\xb1\xfb\x04\xa6\xc2\x1f\xb3?' 282 | p104 283 | tp105 284 | Rp106 285 | aI10 286 | aI10 287 | aI10 288 | aI30 289 | aI10 290 | aF2.564138889312744 291 | aF4.0 292 | aa(lp107 293 | I10 294 | ag100 295 | ag2 296 | (g14 297 | S'\xe1zT?' 
298 | p108 299 | tp109 300 | Rp110 301 | ag2 302 | (g6 303 | S'\xb1\xfb\x04\xa6\xc2\x1f\xb3?' 304 | p111 305 | tp112 306 | Rp113 307 | aI11 308 | aI11 309 | aI11 310 | aI33 311 | aI10 312 | aF2.79461407661438 313 | aF2.0 314 | aa(lp114 315 | I11 316 | ag100 317 | ag2 318 | (g14 319 | S'\xe1zT?' 320 | p115 321 | tp116 322 | Rp117 323 | ag2 324 | (g6 325 | S'\xb1\xfb\x04\xa6\xc2\x1f\xb3?' 326 | p118 327 | tp119 328 | Rp120 329 | aI12 330 | aI12 331 | aI12 332 | aI36 333 | aI10 334 | aF2.983212947845459 335 | aF1.0 336 | aa(lp121 337 | I12 338 | ag100 339 | ag2 340 | (g14 341 | S'\xe1zT?' 342 | p122 343 | tp123 344 | Rp124 345 | ag2 346 | (g6 347 | S'\xb1\xfb\x04\xa6\xc2\x1f\xb3?' 348 | p125 349 | tp126 350 | Rp127 351 | aI13 352 | aI13 353 | aI13 354 | aI39 355 | aI10 356 | aF3.18391489982605 357 | aF0.5 358 | aa(lp128 359 | I13 360 | ag2 361 | (g6 362 | S'\x1c\xa8|\xfb\xd4\xdb\xd4?' 363 | p129 364 | tp130 365 | Rp131 366 | ag2 367 | (g14 368 | S'\xaeGa?' 369 | p132 370 | tp133 371 | Rp134 372 | ag2 373 | (g6 374 | S'\xd5\x80az\xea@\xa2?' 375 | p135 376 | tp136 377 | Rp137 378 | aI14 379 | aI14 380 | aI14 381 | aI42 382 | aI10 383 | aF3.408785104751587 384 | aF0.5 385 | aa(lp138 386 | I14 387 | ag2 388 | (g6 389 | S'\xe5V\xd8\xe3\x8e\x84\xd4?' 390 | p139 391 | tp140 392 | Rp141 393 | ag2 394 | (g14 395 | S'=\nW?' 396 | p142 397 | tp143 398 | Rp144 399 | ag2 400 | (g6 401 | S'\x85\xe3\x1bA\xd4C\xab?' 402 | p145 403 | tp146 404 | Rp147 405 | aI15 406 | aI15 407 | aI15 408 | aI45 409 | aI10 410 | aF3.648300886154175 411 | aF0.5 412 | aa(lp148 413 | I15 414 | ag2 415 | (g6 416 | S'\x0e\x17\x08X/\x02\xd4?' 417 | p149 418 | tp150 419 | Rp151 420 | ag2 421 | (g14 422 | S'\xf6(\\?' 423 | p152 424 | tp153 425 | Rp154 426 | ag2 427 | (g6 428 | S'\xa98\x8d,\x04S\x95?' 429 | p155 430 | tp156 431 | Rp157 432 | aI16 433 | aI16 434 | aI16 435 | aI48 436 | aI10 437 | aF3.837692975997925 438 | aF1.0 439 | aa(lp158 440 | I16 441 | ag2 442 | (g6 443 | S'# <\xc5\xa1j\xd3?' 
444 | p159 445 | tp160 446 | Rp161 447 | ag2 448 | (g14 449 | S'\xf6(\\?' 450 | p162 451 | tp163 452 | Rp164 453 | ag2 454 | (g6 455 | S'x0\x0b*>\x0e\xa1?' 456 | p165 457 | tp166 458 | Rp167 459 | aI17 460 | aI17 461 | aI17 462 | aI51 463 | aI10 464 | aF4.054593086242676 465 | aF2.0 466 | aa(lp168 467 | I17 468 | ag161 469 | ag2 470 | (g14 471 | S'\xf6(\\?' 472 | p169 473 | tp170 474 | Rp171 475 | ag2 476 | (g6 477 | S'x0\x0b*>\x0e\xa1?' 478 | p172 479 | tp173 480 | Rp174 481 | aI18 482 | aI18 483 | aI18 484 | aI54 485 | aI10 486 | aF4.220589876174927 487 | aF1.0 488 | aa(lp175 489 | I18 490 | ag2 491 | (g6 492 | S'=\xfe\xf0r(2\xd3?' 493 | p176 494 | tp177 495 | Rp178 496 | ag2 497 | (g14 498 | S'\x85\xebQ?' 499 | p179 500 | tp180 501 | Rp181 502 | ag2 503 | (g6 504 | S'\xdf\x03\x82\xe1\xe1l\xc5?' 505 | p182 506 | tp183 507 | Rp184 508 | aI19 509 | aI19 510 | aI19 511 | aI57 512 | aI10 513 | aF4.408884048461914 514 | aF1.0 515 | aa(lp185 516 | I19 517 | ag2 518 | (g6 519 | S']\xa9\xe6-@\x95\xd2?' 520 | p186 521 | tp187 522 | Rp188 523 | ag2 524 | (g14 525 | S'\xf6(\\?' 526 | p189 527 | tp190 528 | Rp191 529 | ag2 530 | (g6 531 | S'0\x16\xfdEp\xef\xb1?' 532 | p192 533 | tp193 534 | Rp194 535 | aI20 536 | aI20 537 | aI20 538 | aI60 539 | aI10 540 | aF4.623905897140503 541 | aF1.0 542 | aa(lp195 543 | I20 544 | ag2 545 | (g6 546 | S'\x1f\xc6~=!9\xd2?' 547 | p196 548 | tp197 549 | Rp198 550 | ag2 551 | (g14 552 | S'\xe1zT?' 553 | p199 554 | tp200 555 | Rp201 556 | ag2 557 | (g6 558 | S'L7CYvy\xa9?' 559 | p202 560 | tp203 561 | Rp204 562 | aI21 563 | aI21 564 | aI21 565 | aI63 566 | aI10 567 | aF4.827118873596191 568 | aF1.0 569 | aa(lp205 570 | I21 571 | ag2 572 | (g6 573 | S'S6\x1ew\x8d\xa3\xd1?' 574 | p206 575 | tp207 576 | Rp208 577 | ag2 578 | (g14 579 | S'\x9a\x99Y?' 580 | p209 581 | tp210 582 | Rp211 583 | ag2 584 | (g6 585 | S'\x90qx\x10\xa2\xc8\xa3?' 
586 | p212 587 | tp213 588 | Rp214 589 | aI22 590 | aI22 591 | aI22 592 | aI66 593 | aI10 594 | aF5.051961898803711 595 | aF2.0 596 | aa(lp215 597 | I22 598 | ag208 599 | ag2 600 | (g14 601 | S'\x9a\x99Y?' 602 | p216 603 | tp217 604 | Rp218 605 | ag2 606 | (g6 607 | S'\x90qx\x10\xa2\xc8\xa3?' 608 | p219 609 | tp220 610 | Rp221 611 | aI23 612 | aI23 613 | aI23 614 | aI69 615 | aI10 616 | aF5.240300893783569 617 | aF1.0 618 | aa(lp222 619 | I23 620 | ag2 621 | (g6 622 | S'\xfd\xe0,\xaa/&\xd1?' 623 | p223 624 | tp224 625 | Rp225 626 | ag2 627 | (g14 628 | S'\xe1zT?' 629 | p226 630 | tp227 631 | Rp228 632 | ag2 633 | (g6 634 | S'\xae\xc5\xc3u\x19:\xaa?' 635 | p229 636 | tp230 637 | Rp231 638 | aI24 639 | aI24 640 | aI24 641 | aI72 642 | aI10 643 | aF5.4428019523620605 644 | aF1.0 645 | aa(lp232 646 | I24 647 | ag2 648 | (g6 649 | S'\x91n\xff\xddc\x0c\xd0?' 650 | p233 651 | tp234 652 | Rp235 653 | ag2 654 | (g14 655 | S'\xf6(\\?' 656 | p236 657 | tp237 658 | Rp238 659 | ag2 660 | (g6 661 | S'\xf7u`\t\x03\xdb\x97?' 662 | p239 663 | tp240 664 | Rp241 665 | aI25 666 | aI25 667 | aI25 668 | aI75 669 | aI10 670 | aF5.624152898788452 671 | aF2.0 672 | aa(lp242 673 | I25 674 | ag2 675 | (g6 676 | S'M`\x01)4\xe6\xcf?' 677 | p243 678 | tp244 679 | Rp245 680 | ag2 681 | (g14 682 | S'\xf6(\\?' 683 | p246 684 | tp247 685 | Rp248 686 | ag2 687 | (g6 688 | S'\xe9\x90\xaa\xdb\x93\xb8\xc6?' 689 | p249 690 | tp250 691 | Rp251 692 | aI26 693 | aI26 694 | aI26 695 | aI78 696 | aI10 697 | aF5.819897890090942 698 | aF1.0 699 | aa(lp252 700 | I26 701 | ag2 702 | (g6 703 | S'sLC]\xa3:\xcc?' 704 | p253 705 | tp254 706 | Rp255 707 | ag2 708 | (g14 709 | S'\xf6(\\?' 710 | p256 711 | tp257 712 | Rp258 713 | ag2 714 | (g6 715 | S'\xf9\xa5\x9b{\xdfG\xc0?' 716 | p259 717 | tp260 718 | Rp261 719 | aI27 720 | aI27 721 | aI27 722 | aI81 723 | aI10 724 | aF6.013213872909546 725 | aF1.0 726 | aa(lp262 727 | I27 728 | ag2 729 | (g6 730 | S'\x7fF\xc3J\x8b\x8b\xca?' 
731 | p263 732 | tp264 733 | Rp265 734 | ag2 735 | (g14 736 | S'\xaeGa?' 737 | p266 738 | tp267 739 | Rp268 740 | ag2 741 | (g6 742 | S'\x1c\xcf\x7f\x15`\xff\xae?' 743 | p269 744 | tp270 745 | Rp271 746 | aI28 747 | aI28 748 | aI28 749 | aI84 750 | aI10 751 | aF6.210946083068848 752 | aF1.0 753 | aa(lp272 754 | I28 755 | ag2 756 | (g6 757 | S'\x83b\xc7\xc1\xbc\x0b\xca?' 758 | p273 759 | tp274 760 | Rp275 761 | ag2 762 | (g14 763 | S'R\xb8^?' 764 | p276 765 | tp277 766 | Rp278 767 | ag2 768 | (g6 769 | S'\x92\x1d\xa7\xecNj\xd5?' 770 | p279 771 | tp280 772 | Rp281 773 | aI29 774 | aI29 775 | aI29 776 | aI87 777 | aI10 778 | aF6.4020280838012695 779 | aF1.0 780 | aa(lp282 781 | I29 782 | ag2 783 | (g6 784 | S'\x93\xa8\x0c\r\xbf#\xc9?' 785 | p283 786 | tp284 787 | Rp285 788 | ag2 789 | (g14 790 | S'fff?' 791 | p286 792 | tp287 793 | Rp288 794 | ag2 795 | (g6 796 | S'\x0bP\xc3\xc67\xcb\xc5?' 797 | p289 798 | tp290 799 | Rp291 800 | aI30 801 | aI30 802 | aI30 803 | aI90 804 | aI10 805 | aF6.63359808921814 806 | aF1.0 807 | aa(lp292 808 | I30 809 | ag2 810 | (g6 811 | S'\x1b@X\xd3\x1a#\xc8?' 812 | p293 813 | tp294 814 | Rp295 815 | ag2 816 | (g14 817 | S'R\xb8^?' 818 | p296 819 | tp297 820 | Rp298 821 | ag2 822 | (g6 823 | S'\xae\xbb\xad\xf1\xa0G\xd7?' 824 | p299 825 | tp300 826 | Rp301 827 | aI31 828 | aI31 829 | aI31 830 | aI93 831 | aI10 832 | aF6.8434789180755615 833 | aF1.0 834 | aa(lp302 835 | I31 836 | ag2 837 | (g6 838 | S'\xfb\x9b\x83\x93s\x8d\xc4?' 839 | p303 840 | tp304 841 | Rp305 842 | ag2 843 | (g14 844 | S'\xd7\xa3p?' 845 | p306 846 | tp307 847 | Rp308 848 | ag2 849 | (g6 850 | S' 2\xa3[\xa8!\xe8?' 851 | p309 852 | tp310 853 | Rp311 854 | aI32 855 | aI32 856 | aI32 857 | aI96 858 | aI10 859 | aF7.019865036010742 860 | aF1.0 861 | aa(lp312 862 | I32 863 | ag2 864 | (g6 865 | S'\x03\x92\xa3\xe5d"\xc2?' 866 | p313 867 | tp314 868 | Rp315 869 | ag2 870 | (g14 871 | S'fff?' 872 | p316 873 | tp317 874 | Rp318 875 | ag2 876 | (g6 877 | S'\xf0\x8cZ!Q\xa6\xc6?' 
878 | p319 879 | tp320 880 | Rp321 881 | aI33 882 | aI33 883 | aI33 884 | aI99 885 | aI10 886 | aF7.212146997451782 887 | aF1.0 888 | aa(lp322 889 | I33 890 | ag2 891 | (g6 892 | S'\x8cU]\x15\x9b\xaf\xc0?' 893 | p323 894 | tp324 895 | Rp325 896 | ag2 897 | (g14 898 | S'\x8f\xc2u?' 899 | p326 900 | tp327 901 | Rp328 902 | ag2 903 | (g6 904 | S'\xd8\x91\x16\x17\x18\xb9\xd5?' 905 | p329 906 | tp330 907 | Rp331 908 | aI34 909 | aI34 910 | aI34 911 | aI102 912 | aI10 913 | aF7.434514045715332 914 | aF1.0 915 | aa(lp332 916 | I34 917 | ag325 918 | ag2 919 | (g14 920 | S'\x8f\xc2u?' 921 | p333 922 | tp334 923 | Rp335 924 | ag2 925 | (g6 926 | S'\xd8\x91\x16\x17\x18\xb9\xd5?' 927 | p336 928 | tp337 929 | Rp338 930 | aI35 931 | aI35 932 | aI35 933 | aI105 934 | aI10 935 | aF7.6767919063568115 936 | aF0.5 937 | aa(lp339 938 | I35 939 | ag2 940 | (g6 941 | S'V\x0cnq\xa0=\xb6?' 942 | p340 943 | tp341 944 | Rp342 945 | ag2 946 | (g14 947 | S'\xa4p}?' 948 | p343 949 | tp344 950 | Rp345 951 | ag2 952 | (g6 953 | S'O\x06\\\xde*!\xc4?' 954 | p346 955 | tp347 956 | Rp348 957 | aI36 958 | aI36 959 | aI36 960 | aI108 961 | aI10 962 | aF7.8857550621032715 963 | aF1.0 964 | aa(lp349 965 | I36 966 | ag2 967 | (g6 968 | S'Wa\x90W\x93\xb8\xab?' 969 | p350 970 | tp351 971 | Rp352 972 | ag2 973 | (g14 974 | S'\x00\x00\x80?' 975 | p353 976 | tp354 977 | Rp355 978 | ag2 979 | (g6 980 | S'`g\xd5\xd1\xf9\xec\xdb?' 981 | p356 982 | tp357 983 | Rp358 984 | aI37 985 | aI37 986 | aI37 987 | aI111 988 | aI10 989 | aF8.124665021896362 990 | aF2.0 991 | aa. -------------------------------------------------------------------------------- /data_generation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright (C) 2019 Albert Berahas, Majid Jahani, Martin Takáč 4 | # 5 | # All Rights Reserved. 6 | # 7 | # Authors: Albert Berahas, Majid Jahani, Martin Takáč 8 | # 9 | # Please cite: 10 | # 11 | # A. S. 
# Berahas, M. Jahani, and M. Takáč, "Quasi-Newton Methods for
# Deep Learning: Forget the Past, Just Sample." (2019). Lehigh University.
# http://arxiv.org/abs/1901.09997
# ==========================================================================

import numpy as np

# ==========================================================================
def getData(num_pts = 50, freq = 8.0, offset = 0.8):
    """Generate the two-class SIN toy classification data set.

    Samples ``num_pts`` equally spaced abscissae ``xx`` in [0, 1).
    The positive class lies on ``sin(freq*xx) + offset`` and the
    negative class on ``sin(freq*xx) - offset``.

    Parameters
    ----------
    num_pts : int
        Number of points per class.
    freq : float
        Frequency of the sine curve.
    offset : float
        Vertical shift that separates the two classes.

    Returns
    -------
    X : ndarray of shape (2*num_pts, 2)
        Feature rows ``[xx_i, curve_i]``. The first ``num_pts`` rows are
        the positive class and the remaining rows the negative class.
    y : ndarray of shape (2*num_pts, 2)
        One-hot labels: positive rows are ``[0, 1]``, negative rows ``[1, 0]``.
    """
    # Equally spaced points 0, 1/num_pts, 2/num_pts, ... (explicit float
    # division so the result is the same under Python 2 and Python 3).
    xx = np.arange(num_pts) * 1.0 / (num_pts + .0)

    # Positive (xp) and negative (xn) classes
    xp = np.sin(freq * xx) + offset
    xn = np.sin(freq * xx) - offset

    # Column 0: abscissa (repeated for both classes); column 1: curve value.
    X = np.column_stack((np.concatenate((xx, xx)),
                         np.concatenate((xp, xn))))

    # Class index per row (1 = positive block, 0 = negative block), then
    # one-hot encode it so that the hot column index equals the class index.
    labels = np.concatenate((np.ones(num_pts, dtype=int),
                             np.zeros(num_pts, dtype=int)))
    y = np.zeros([labels.size, 2])
    y[np.arange(labels.size), labels] = 1

    return X, y


# -------------------------------------------------------------------------------- /main.py: --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2019 Albert Berahas, Majid Jahani, Martin Takáč
#
# All Rights Reserved.
#
# Authors: Albert Berahas, Majid Jahani, Martin Takáč
#
# Please cite:
#
# A. S. Berahas, M. Jahani, and M. Takáč, "Quasi-Newton Methods for
# Deep Learning: Forget the Past, Just Sample." (2019). Lehigh University.
# http://arxiv.org/abs/1901.09997
# ==========================================================================

import numpy as np
import matplotlib.pyplot as plt
import pickle
import tensorflow as tf
from S_LSR1 import *
from S_LBFGS import *
from parameters import *
from network import *
from data_generation import *
import os.path
import sys

# Solvers selectable from the command line, e.g. ``python main.py SLSR1``.
SOLVERS = ("SLSR1", "SLBFGS")
# First command-line argument, or None when none was given. This lets the
# module be imported without an IndexError; main() then reports the problem.
input1 = sys.argv[1] if len(sys.argv) > 1 else None


# ==========================================================================
def main(opt=input1):
    """Call the selected solver with the selected parameters.

    Parameters
    ----------
    opt : str
        ``"SLSR1"`` runs S_LSR1 with the trust-region settings
        (eta, delta_init, epsTR). ``"SLBFGS"`` runs S_LBFGS with the
        Armijo line-search settings (alpha_init, cArmijo, rhoArmijo).
        Defaults to the first command-line argument.

    Raises
    ------
    ValueError
        If ``opt`` is not one of ``SOLVERS``. Previously an unknown or
        missing option silently did nothing.
    """
    if opt == "SLSR1":
        S_LSR1(w_init,X,y,cp.seed,cp.numIter,cp.mmr,cp.radius,cp.eps,cp.eta,
               cp.delta_init,cp.epsTR,cp.num_weights,dnn,sess)
    elif opt == "SLBFGS":
        S_LBFGS(w_init,X,y,cp.seed,cp.numIter,cp.mmr,
                cp.radius,cp.eps,cp.alpha_init,cp.cArmijo,cp.rhoArmijo,
                cp.num_weights,cp.init_sampling_SLBFGS,dnn,sess)
    else:
        raise ValueError("Unknown solver %r; usage: python main.py {%s}"
                         % (opt, "|".join(SOLVERS)))

# Get the parameters
cp = parameters()

# Create the data
X,y = getData(cp.num_pts,cp.freq,cp.offset)

# Create network; visible GPUs are restricted before the session is created
os.environ["CUDA_VISIBLE_DEVICES"] = cp.GPUnumber
sess = tf.InteractiveSession()
dnn = DNN(cp.sizeNet,cp.activation,cp.mmr)

# Set the initial point (reproducible through cp.seed)
np.random.seed(cp.seed)
w_init = np.random.randn(cp.num_weights,1)

# ==========================================================================
if __name__ == '__main__':
    # Run the selected solver.
    main()
# -------------------------------------------------------------------------------- /network.py: --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2019 Albert Berahas, Majid Jahani, Martin Takáč
#
# All Rights Reserved.
#
# Authors: Albert Berahas, Majid Jahani, Martin Takáč
#
# Please cite:
#
# A. S. Berahas, M. Jahani, and M.
# Takáč, "Quasi-Newton Methods for
# Deep Learning: Forget the Past, Just Sample." (2019). Lehigh University.
# http://arxiv.org/abs/1901.09997
# ==========================================================================

import tensorflow as tf
import numpy as np
import time

# ==========================================================================
def weight_variable(shape, std=0.1):
    """Return a float64 ``tf.Variable`` of the given shape, initialised
    from a truncated normal distribution with standard deviation ``std``."""
    initial = tf.truncated_normal(shape, stddev=std, dtype=tf.float64)
    return tf.Variable(initial,dtype=tf.float64)


# ==========================================================================
class DNN:
    """Six-layer fully connected classifier stored as one flat weight vector.

    Inputs:
      (1) hiddenSizes: a sequence whose first five entries are the widths
          of the five hidden layers. The output layer always has 2 units,
          so any further entries are ignored.
      (2) activation: one of "sigmoid", "ReLU" or "Softplus".
      (3) mmr: number of columns of the ``vecs`` placeholder used for the
          batched Hessian-matrix product ``Hvs``.

    All weight matrices and biases are concatenated into a single (n, 1)
    variable ``params``. This makes the quasi-Newton methods easier to
    implement, because gradients, Hessian products and updates all act on
    that one vector.
    """
    def __init__(self,hiddenSizes,activation="sigmoid",mmr=10):

        # Resolve the activation first, so an invalid name fails before any
        # graph ops are created. Previously it left ``acf`` undefined.
        activations = {"sigmoid": tf.nn.sigmoid,
                       "ReLU": tf.nn.relu,
                       "Softplus": tf.nn.softplus}
        if activation not in activations:
            raise ValueError("Unknown activation %r; expected one of %s"
                             % (activation, sorted(activations)))
        acf = activations[activation]

        x = tf.placeholder(tf.float64, shape=[None, 2])    # 2 input features
        y_ = tf.placeholder(tf.float64, shape=[None, 2])   # one-hot labels, 2 classes

        # Layer widths: 2 inputs -> 5 hidden layers -> 2 outputs.
        FC1 = hiddenSizes[0]
        FC2 = hiddenSizes[1]
        FC3 = hiddenSizes[2]
        FC4 = hiddenSizes[3]
        FC5 = hiddenSizes[4]
        FC6 = 2
        widths = [2, FC1, FC2, FC3, FC4, FC5, FC6]

        # Entry counts inside the flat vector, in storage order
        # W1, b1, W2, b2, ..., W6, b6.
        sizes = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            sizes += [fan_in*fan_out, fan_out]

        n = np.sum(sizes)
        params = weight_variable([n, 1],1.0/(n))
        uparam = tf.unstack(params,axis = 0)

        # Slice the flat vector into per-layer weight matrices and biases.
        Ws = []
        bs = []
        start = 0
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stop = start + fan_in*fan_out
            Ws.append(tf.reshape(tf.stack(uparam[start:stop]), shape=[fan_in, fan_out]))
            start, stop = stop, stop + fan_out
            bs.append(tf.reshape(tf.stack(uparam[start:stop]), shape=[fan_out]))
            start = stop

        # Forward pass: activation on the five hidden layers, raw logits out.
        a = x
        for W, b in zip(Ws[:-1], bs[:-1]):
            a = acf(tf.matmul(a, W) + b)
        a6 = (tf.matmul(a, Ws[-1]) + bs[-1])


        #-----------------------------------------------
        #----------- Function, Gradient, Hessian, Accuracy and Other Operators --------------
        #-----------------------------------------------
        output = a6 # Output of network (logits)
        probdist = tf.nn.softmax(output) # Softmax of output layer
        cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=output)) # Cross entropy loss
        correct_prediction = tf.equal(tf.argmax(a6, 1), tf.argmax(y_,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # Accuracy computation

        self.output= output
        self.probdist=probdist
        self.x = x
        self.y = y_
        self.Ws = Ws
        self.bs = bs
        self.cross_entropy = cross_entropy
        self.accuracy = accuracy
        self.correct_prediction = correct_prediction
        self.params = params

        self.updateVal = tf.placeholder(tf.float64, shape=[int(params.shape[0]),1]) # Placeholder for updating parameters
        self.updateOp = tf.assign_add(params, self.updateVal).op # Operator for updating parameters (params += updateVal)
        self.G = tf.gradients(cross_entropy,params) # Gradient computation
        self.H = tf.hessians(cross_entropy,params) # Hessian computation
        self.ASSIGN_OP = tf.assign(self.params, self.updateVal).op # Operator for assigning parameters (params = updateVal)
        Gradient = self.G[0]
        self.vecs = tf.placeholder(dtype=tf.float64, shape=[int(self.params.shape[0]), mmr]) #Placeholder for the matrix
        self.Gv = tf.reshape(Gradient,shape=(1,-1))
        self.grad_vs =(tf.matmul(self.Gv,self.vecs))
        # Differentiating each entry of g^T V gives one Hessian-vector product
        # per column of vecs (vecs treated as constant).
        self.Hvs = tf.stack([tf.gradients(tm[0], params, stop_gradients=self.vecs) for tm in tf.unstack(self.grad_vs, axis=1) ] ) # Operator for Hessian-matrix product
# -------------------------------------------------------------------------------- /parameters.py: --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2019 Albert Berahas, Majid Jahani, Martin Takáč
#
# All Rights Reserved.
#
# Authors: Albert Berahas, Majid Jahani, Martin Takáč
#
# Please cite:
#
# A. S. Berahas, M. Jahani, and M. Takáč, "Quasi-Newton Methods for
# Deep Learning: Forget the Past, Just Sample." (2019). Lehigh University.
# http://arxiv.org/abs/1901.09997
# ==========================================================================

import numpy as np

# ==========================================================================

class parameters :
    def __init__(self):
        """Store every experiment setting as an attribute of ``self``.

        Groups: SIN-problem data (num_pts, freq, offset); network
        architecture (activation, FC1..FC6, sizeNet, num_weights);
        algorithm settings (seed through init_sampling_SLBFGS); and the
        GPU to run on (GPUnumber).
        """

        # ------------------ SIN problem --------------------
        # classes lie on sin(freq*xx)+offset and sin(freq*xx)-offset
        self.num_pts = 50
        self.freq = 8
        self.offset = 0.8

        # -------------- activation function ----------------
        # one of "sigmoid", "ReLU" or "Softplus"
        self.activation = "sigmoid"

        # ------------------ network size -------------------
        # fully connected network, e.g. [2,2,2,2,2,2] is 6 layers with
        # 2 nodes in every layer
        self.FC1 = 2
        self.FC2 = 2
        self.FC3 = 2
        self.FC4 = 2
        self.FC5 = 2
        self.FC6 = 2
        self.sizeNet = [self.FC1, self.FC2, self.FC3,
                        self.FC4, self.FC5, self.FC6]

        # Entries per layer in the flat weight vector: W (fanIn*fanOut)
        # followed by b (fanOut). The network input has 2 features.
        fanIns = [2] + self.sizeNet[:-1]
        dimensionSet = []
        for fanIn, fanOut in zip(fanIns, self.sizeNet):
            dimensionSet.extend((fanIn * fanOut, fanOut))
        self.num_weights = np.sum(dimensionSet)   # dimension of the problem

        # --------------- algorithm settings ----------------
        self.seed = 67                     # random seed
        self.numIter = 1000                # maximum number of iterations
        self.mmr = 10                      # memory length for S-LSR1, S-LBFGS
        self.radius = 1                    # sampling radius for S-LSR1, S-LBFGS
        self.eps = 1e-8                    # tolerance for updating quasi-Newton matrices
        self.eta = 1e-6                    # tolerance for ared/pred reduction in TR
        self.delta_init = 1                # initial TR radius
        self.alpha_init = 1                # initial step length
        self.epsTR = 1e-10                 # tolerance for CG-Steihaug
        self.cArmijo = 1e-4                # Armijo sufficient decrease parameter
        self.rhoArmijo = .5                # Armijo backtracking factor
        self.init_sampling_SLBFGS = "off"  # S-LBFGS sampling from first iteration

        # ---------------- other settings -------------------
        self.GPUnumber = "0"               # GPU ID
# -------------------------------------------------------------------------------- /sampleSY.py: --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2019 Albert Berahas, Majid Jahani, Martin Takáč
#
# All Rights Reserved.
#
# Authors: Albert Berahas, Majid Jahani, Martin Takáč
#
# Please cite:
#
# A. S. Berahas, M. Jahani, and M. Takáč, "Quasi-Newton Methods for
# Deep Learning: Forget the Past, Just Sample." (2019). Lehigh University.
# http://arxiv.org/abs/1901.09997
# ==========================================================================

import numpy as np
from numpy import linalg as LA
import tensorflow as tf


def sample_pairs_SY_SLSR1(X,y,num_weights,mmr,radius,eps,dnn,numHessEval,sess):
    """ Function that computes SY pairs for S-LSR1 method

    Draws ``mmr`` random directions (columns of Stemp, scaled by
    ``radius``). It obtains Ytemp = H Stemp from one batched
    Hessian-matrix product (``dnn.Hvs``). A pair (s, y) is kept only if it
    passes the SR1 safeguard
        |(y - B s)^T s| > eps * ||y - B s|| * ||s||,
    where B is the compact SR1 matrix built from the pairs accepted so far.

    Returns S, Y (num_weights x k, with k the number of accepted pairs),
    the number of accepted pairs, and numHessEval incremented by one.
    """
    Stemp = radius*np.random.randn(num_weights,mmr)
    # Reshape rather than squeeze, so that mmr == 1 still yields a
    # (num_weights, 1) matrix instead of a 1-D vector.
    Ytemp = np.reshape(sess.run([dnn.Hvs], feed_dict={dnn.x: X, dnn.y:y, dnn.vecs: Stemp}),
                       (mmr, num_weights)).T
    numHessEval += 1
    S = np.zeros((num_weights,0))
    Y = np.zeros((num_weights,0))

    counterSucc = 0
    for idx in range(mmr):
        # Compact SR1 form with B0 = 0: B = Y M^{-1} Y^T, where
        # M = D + L + L^T, D = diag(s_i^T y_i), L = strict lower part of S^T Y.
        L = np.tril(S.T.dot(Y), -1)
        D = np.diag(np.sum(S * Y, axis=0))
        Minv = np.linalg.inv(D + L + L.T)

        s_k = Stemp[:,idx]
        y_k = Ytemp[:,idx]
        Bksk = np.squeeze(np.matmul(Y, np.matmul(Minv, np.matmul(Y.T, s_k))))
        residual = y_k - Bksk
        yk_BkskDotsk = residual.T.dot(s_k)
        if np.abs(np.squeeze(yk_BkskDotsk)) > (
            eps *(LA.norm(residual) * LA.norm(s_k)) ):
            counterSucc += 1

            S = np.append(S,s_k.reshape(num_weights,1),axis = 1)
            Y = np.append(Y,y_k.reshape(num_weights,1),axis=1)

    return S,Y,counterSucc,numHessEval

def sample_pairs_SY_SLBFGS(X,y,num_weights,mmr,radius,eps,dnn,numHessEval,sess):
    """ Function that computes SY pairs for S-LBFGS method

    Draws ``mmr`` random directions (scaled by ``radius``) and obtains
    their Hessian products with one batched evaluation. A pair (s, y) is
    kept only if it has sufficiently positive curvature:
        s^T y > eps * ||s|| * ||y||.

    Returns S, Y (num_weights x k), the number of accepted pairs,
    numHessEval incremented by one, and gamma_k = s^T y / y^T y of the
    last accepted pair. gamma_k is 1.0 when no pair is accepted;
    previously that case raised UnboundLocalError.
    """
    Stemp = radius*np.random.randn(num_weights,mmr)
    # Reshape rather than squeeze, so that mmr == 1 is handled too.
    Ytemp = np.reshape(sess.run([dnn.Hvs], feed_dict={dnn.x: X, dnn.y:y, dnn.vecs: Stemp}),
                       (mmr, num_weights)).T
    numHessEval += 1

    S = np.zeros((num_weights,0))
    Y = np.zeros((num_weights,0))

    gamma_k = 1.0   # initial-matrix scaling used if no pair is accepted
    counterSucc = 0
    for idx in range(mmr):
        s_k = Stemp[:,idx]
        y_k = Ytemp[:,idx]
        sTy = y_k.T.dot(s_k)
        if sTy > eps *(LA.norm(s_k)*LA.norm(y_k)):
            gamma_k = np.squeeze(s_k.T.dot(y_k)/(y_k.T.dot(y_k)))
            S = np.append(S,s_k.reshape(num_weights,1),axis = 1)
            Y = np.append(Y,y_k.reshape(num_weights,1),axis=1)
            counterSucc += 1
    return S,Y,counterSucc,numHessEval,gamma_k
# -------------------------------------------------------------------------------- /util_func.py: --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2019 Albert Berahas, Majid Jahani, Martin Takáč
#
# All Rights Reserved.
#
# Authors: Albert Berahas, Majid Jahani, Martin Takáč
#
# Please cite:
#
# A. S. Berahas, M. Jahani, and M. Takáč, "Quasi-Newton Methods for
# Deep Learning: Forget the Past, Just Sample." (2019). Lehigh University.
# http://arxiv.org/abs/1901.09997
# ==========================================================================

import numpy as np
import random
from numpy import linalg as LA
import math

# ==========================================================================
def CG_Steinhaug_matFree(epsTR, g , deltak, S,Y,nv):
    """
    Approximately solve the trust-region subproblem
        min_p  g^T p + 0.5 p^T B p   subject to  ||p|| <= deltak
    with the Steihaug conjugate-gradient method (Nocedal, J., & Wright,
    S. J. (2006). Numerical Optimization, 2nd ed., Ch. 7, Algorithm 7.2).

    B is never formed. It is applied through the compact limited-memory SR1
    representation of Byrd, R. H., Nocedal, J., and Schnabel, R. B.,
    "Representations of quasi-Newton matrices and their use in limited
    memory methods", Mathematical Programming 63.1-3 (1994): 129-156,
    with B0 = 0, i.e. B = Y M^{-1} Y^T and M = D + L + L^T.

    epsTR  : residual-norm tolerance that stops CG
    g      : gradient, shape (nv, 1)
    deltak : trust-region radius
    S, Y   : curvature pairs, shape (nv, k)
    nv     : problem dimension
    Returns the step, shape (nv, 1).
    """
    zOld = np.zeros((nv,1))
    rOld = g
    dOld = -g
    trsLoop = 1e-12   # d^T B d below this is treated as non-positive curvature
    if LA.norm(rOld) < epsTR:
        return zOld

    # Middle matrix of the compact form, built once per call.
    L = np.tril(S.T.dot(Y), -1)
    D = np.diag(np.sum(S * Y, axis=0))
    Minv = np.linalg.inv(D + L + L.T)

    while True:
        # Bk_d = B d, computed without forming B
        Bk_d = np.matmul(Y, np.matmul(Minv, np.matmul(Y.T, dOld)))
        dBd = dOld.T.dot(Bk_d)

        # Non-positive curvature: go to the trust-region boundary along d
        if dBd < trsLoop:
            tau = rootFinder(LA.norm(dOld)**2, 2*zOld.T.dot(dOld), (LA.norm(zOld)**2 - deltak**2))
            return zOld + tau*dOld

        alphaj = rOld.T.dot(rOld) / dBd
        zNew = zOld + alphaj*dOld

        # Full CG step leaves the region: stop on the boundary instead
        if LA.norm(zNew) >= deltak:
            tau = rootFinder(LA.norm(dOld)**2, 2*zOld.T.dot(dOld), (LA.norm(zOld)**2 - deltak**2))
            return zOld + tau*dOld

        rNew = rOld + alphaj*Bk_d
        if LA.norm(rNew) < epsTR:
            return zNew

        betajplus1 = rNew.T.dot(rNew) /(rOld.T.dot(rOld))
        dOld = -rNew + betajplus1*dOld
        zOld = zNew
        rOld = rNew


# ==========================================================================
def rootFinder(a,b,c):
    """Return the larger root of a*x^2 + b*x + c = 0 if it is non-negative.

    Used to find the step length tau >= 0 that reaches the trust-region
    boundary. Prints a message and returns None when the roots are complex
    or the larger real root is negative.
    """
    r = b**2 - 4*a*c

    if r > 0:
        x1 = ((-b) + np.sqrt(r))/(2*a+0.0)
        x2 = ((-b) - np.sqrt(r))/(2*a+0.0)
        x = max(x1,x2)
        if x>=0:
            return x
        else:
            print("no positive root!")
    elif r == 0:
        x = (-b) / (2*a+0.0)
        if x>=0:
            return x
        else:
            print("no positive root!")
    else:
        print("No roots")

def L_BFGS_two_loop_recursion(g_k,S,Y,k,mmr,gamma_k,nv):
    """
    Return r = H g_k, where H is the L-BFGS inverse-Hessian approximation,
    computed with the two-loop recursion (Nocedal, J., & Wright, S. J.
    (2006). Numerical Optimization, 2nd ed., Ch. 7, Algorithm 7.4). The
    initial matrix is H0 = gamma_k * I.

    g_k   : gradient, shape (nv, 1)
    S, Y  : curvature pairs, shape (nv, m); the first loop runs from the
            last column to the first
    k     : unused, kept for interface compatibility
    mmr   : maximum number of pairs used
    """
    idx = min(S.shape[1],mmr)
    rho = np.zeros((idx,1))
    theta = np.zeros((idx,1))

    # First loop: last pair to first.
    q = g_k
    for i in reversed(range(idx)):
        s_i = S[:,i].reshape(nv,1)
        y_i = Y[:,i].reshape(nv,1)
        rho[i] = 1/ s_i.T.dot(y_i)
        theta[i] = (rho[i])*(s_i.T.dot(q))
        q = q - theta[i]*y_i

    r = gamma_k*q
    # Second loop: first pair to last.
    for j in range(idx):
        s_j = S[:,j].reshape(nv,1)
        y_j = Y[:,j].reshape(nv,1)
        beta = (rho[j])*(y_j.T.dot(r))
        r = r + s_j*(theta[j] - beta)

    return r
# --------------------------------------------------------------------------------