├── .gitignore ├── README.md ├── __init__.py ├── generate_figures ├── README.md ├── __init__.py ├── adagrad.py ├── convergence_utils.py ├── dropout │ ├── .DS_Store │ ├── .gitignore │ ├── LICENSE │ ├── README │ ├── __init__.py │ ├── deepae.py │ ├── load_data.py │ ├── logistic_sgd.py │ ├── mlp.py │ └── plot_results.sh ├── figure_cartoon.py ├── figure_compare_optimizers.py ├── figure_compare_sfo_L.py ├── figure_compare_sfo_N.py ├── figure_compare_sfo_variations.py ├── figure_overhead.py ├── figures_cae.py ├── models.py ├── nnet │ ├── __init__.py │ ├── conv.yaml │ ├── mnist.yaml │ └── model_gradient.py ├── optimization_wrapper.py ├── sag.py └── utils.py ├── sfo.m ├── sfo.py ├── sfo_demo.m └── sfo_demo.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | 3 | # C extensions 4 | *.so 5 | 6 | # Packages 7 | *.egg 8 | *.egg-info 9 | dist 10 | build 11 | eggs 12 | parts 13 | bin 14 | var 15 | sdist 16 | develop-eggs 17 | .installed.cfg 18 | lib 19 | lib64 20 | __pycache__ 21 | 22 | # Installer logs 23 | pip-log.txt 24 | 25 | # Unit test / coverage reports 26 | .coverage 27 | .tox 28 | nosetests.xml 29 | 30 | # Translations 31 | *.mo 32 | 33 | # Mr Developer 34 | .mr.developer.cfg 35 | .project 36 | .pydevproject 37 | 38 | iterate.dat 39 | *.npz 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Sum of Functions Optimizer (SFO) 2 | ================================ 3 | 4 | SFO is a function optimizer for the case where the target function breaks into a sum over minibatches, or a sum over contributing functions. It combines the benefits of both quasi-Newton and stochastic gradient descent techniques, and will likely converge faster and to a better function value than either. It does not require tuning of hyperparameters. It is described in more detail in the paper: 5 | > Jascha Sohl-Dickstein, Ben Poole, and Surya Ganguli
6 | > Fast large-scale optimization by unifying stochastic gradient and quasi-Newton methods
7 | > International Conference on Machine Learning (2014)
8 | > arXiv preprint arXiv:1311.2115 (2013)
9 | > http://arxiv.org/abs/1311.2115 10 | 11 | This repository provides easy to use Python and MATLAB implementations of SFO, as well as functions to exactly reproduce the figures in the paper.
12 | 13 | ## Use SFO 14 | 15 | Simple example code which trains an autoencoder is in **sfo_demo.py** and **sfo_demo.m**, and is reproduced at the end of this README. 16 | 17 | ### Python package 18 | 19 | To use SFO, you should first import SFO, 20 | `from sfo import SFO` 21 | then initialize it, 22 | `optimizer = SFO(f_df, theta_init, subfunction_references)` 23 | then call the optimizer, specifying the number of optimization passes to perform, 24 | `theta = optimizer.optimize(num_passes=1)`. 25 | 26 | The three required initialization parameters are: 27 | - *f_df* - Returns the function value and gradient for a single subfunction 28 | call. Should have the form 29 | `f, dfdtheta = f_df(theta, subfunction_references[idx])`, 30 | where *idx* is the index of a single subfunction. 31 | - *theta_init* - The initial parameters to be used for optimization. *theta_init* can 32 | be either a NumPy array, an array of NumPy arrays, a dictionary of NumPy 33 | arrays, or a nested combination thereof. The gradient returned by *f_df* 34 | should have the same form as *theta_init*. 35 | - *subfunction_references* - A list containing an identifying element for 36 | each subfunction. The elements in this list could be, eg, numpy 37 | matrices containing minibatches, or indices identifying the 38 | subfunction, or filenames from which target data should be read. 39 | **If each subfunction corresponds to a minibatch, then the number of 40 | subfunctions should be approximately \[number subfunctions\] = sqrt(\[dataset size\])/10**. 41 | 42 | More detailed documentation, and additional options, can be found in **sfo.py**. 43 | 44 | ### MATLAB package 45 | 46 | To use SFO you must first initialize the optimizer, 47 | `optimizer = sfo(@f_df, theta_init, subfunction_references, [varargin]);` 48 | then call the optimizer, specifying the number of optimization passes to perform, 49 | `theta = optimizer.optimize(20);`. 50 | 51 | The initialization parameters are: 52 | - *f_df* - Returns the function value and gradient for a single subfunction 53 | call. Should have the form 54 | `[f, dfdtheta] = f_df(theta, subfunction_references{idx}, varargin{:})`, 55 | where *idx* is the index of a single subfunction. 56 | - *theta_init* - The initial parameters to be used for optimization. *theta_init* can 57 | be either a vector, a matrix, or a cell array with a vector or 58 | matrix in every cell. The gradient returned by *f_df* should have the 59 | same form as *theta_init*. 60 | - *subfunction_references* - A cell array containing an identifying element 61 | for each subfunction. The elements in this list could be, eg, 62 | matrices containing minibatches, or indices identifying the 63 | subfunction, or filenames from which target data should be read. 64 | **If each subfunction corresponds to a minibatch, then the number of 65 | subfunctions should be approximately \[number subfunctions\] = sqrt(\[dataset size\])/10**. 66 | - *[varargin]* - Any additional parameters, which will be passed through to *f_df* each time 67 | it is called. 68 | 69 | Slightly more documentation can be found in **sfo.m**. 70 | 71 | ### Special situations 72 | 73 | Email jaschasd@google.com with questions if you don't find your answer here. 74 | 75 | #### Reducing overhead 76 | 77 | If too much time is spent inside SFO relative to inside the objective function, then reduce the number of subfunctions by increasing the minibatch size or merging subfunctions. 78 | 79 | #### Replacing minibatches / using SFO in the infinite data limit 80 | 81 | In the Python version of the code, a subfunction or minibatch can be replaced by calling the function `replace_subfunction`. See the documentation in ***sfo.py*** for more details. Note that replacing a subfunction without calling `replace_subfunction` will cause the optimizer to fail, since SFO relies on subfunctions returning consistent gradients. 82 | 83 | #### Using with dropout 84 | 85 | Stochastic gradients will break SFO, because it uses the change in the minibatch/subfunction gradient to estimate the Hessian matrix. The benefits of noise regularization can be achieved without making the gradients stochastic by using frozen noise. That is, in the case of dropout, assign a random dropout mask to each datapoint. Every time that datapoint is evaluated however, use the same dropout mask. This makes the gradients consistent across multiple evaluations of the minibatch. 86 | 87 | ## Reproduce figures from paper 88 | To reproduce the figures from the paper, run **generate\_figures/figure\_\*.py**. Several of the figures rely on a subdirectory **figure_data/** with training data. This can be downloaded from https://www.dropbox.com/sh/h9z4djlgl2tagmu/GlVAJyErf8 . 89 | 90 | ## Example code 91 | 92 | The following code blocks train an autoencoder using SFO in Python and MATLAB respectively. Identical code is in **sfo_demo.py** and **sfo_demo.m**. 93 | 94 | ### Python example code 95 | 96 | ```python 97 | import matplotlib.pyplot as plt 98 | import numpy as np 99 | from numpy.random import randn 100 | from sfo import SFO 101 | 102 | # define an objective function and gradient 103 | def f_df(theta, v): 104 | """ 105 | Calculate reconstruction error and gradient for an autoencoder with sigmoid 106 | nonlinearity. 107 | v contains the training data, and will be different for each subfunction. 108 | """ 109 | h = 1./(1. + np.exp(-(np.dot(theta['W'], v) + theta['b_h']))) 110 | v_hat = np.dot(theta['W'].T, h) + theta['b_v'] 111 | f = np.sum((v_hat - v)**2) / v.shape[1] 112 | dv_hat = 2.*(v_hat - v) / v.shape[1] 113 | db_v = np.sum(dv_hat, axis=1).reshape((-1,1)) 114 | dW = np.dot(h, dv_hat.T) 115 | dh = np.dot(theta['W'], dv_hat) 116 | db_h = np.sum(dh*h*(1.-h), axis=1).reshape((-1,1)) 117 | dW += np.dot(dh*h*(1.-h), v.T) 118 | dfdtheta = {'W':dW, 'b_h':db_h, 'b_v':db_v} 119 | return f, dfdtheta 120 | 121 | # set model and training data parameters 122 | M = 20 # number visible units 123 | J = 10 # number hidden units 124 | D = 100000 # full data batch size 125 | N = int(np.sqrt(D)/10.) # number minibatches 126 | # generate random training data 127 | v = randn(M,D) 128 | 129 | # create the array of subfunction specific arguments 130 | sub_refs = [] 131 | for i in range(N): 132 | # extract a single minibatch of training data. 133 | sub_refs.append(v[:,i::N]) 134 | 135 | # initialize parameters 136 | theta_init = {'W':randn(J,M), 'b_h':randn(J,1), 'b_v':randn(M,1)} 137 | # initialize the optimizer 138 | optimizer = SFO(f_df, theta_init, sub_refs) 139 | # run the optimizer for 1 pass through the data 140 | theta = optimizer.optimize(num_passes=1) 141 | # continue running the optimizer for another 20 passes through the data 142 | theta = optimizer.optimize(num_passes=20) 143 | 144 | # plot the convergence trace 145 | plt.plot(np.array(optimizer.hist_f_flat)) 146 | plt.xlabel('Iteration') 147 | plt.ylabel('Minibatch Function Value') 148 | plt.title('Convergence Trace') 149 | 150 | # test the gradient of f_df 151 | optimizer.check_grad() 152 | ``` 153 | 154 | ### MATLAB example code 155 | 156 | ```MATLAB 157 | % set model and training data parameters 158 | M = 20; % number visible units 159 | J = 10; % number hidden units 160 | D = 100000; % full data batch size 161 | N = floor(sqrt(D)/10.); % number minibatches 162 | % generate random training data 163 | v = randn(M,D); 164 | 165 | % create the cell array of subfunction specific arguments 166 | sub_refs = cell(N,1); 167 | for i = 1:N 168 | % extract a single minibatch of training data. 169 | sub_refs{i} = v(:,i:N:end); 170 | end 171 | 172 | % initialize parameters 173 | % Parameters can be stored as a vector, a matrix, or a cell array with a 174 | % vector or matrix in each cell. Here the parameters are 175 | % {[weight matrix], [hidden bias], [visible bias]}. 176 | theta_init = {randn(J,M), randn(J,1), randn(M,1)}; 177 | % initialize the optimizer 178 | optimizer = sfo(@f_df_autoencoder, theta_init, sub_refs); 179 | % run the optimizer for half a pass through the data 180 | theta = optimizer.optimize(0.5); 181 | % run the optimizer for another 20 passes through the data, continuing from 182 | % the theta value where the prior call to optimize() ended 183 | theta = optimizer.optimize(20); 184 | 185 | % plot the convergence trace 186 | plot(optimizer.hist_f_flat); 187 | xlabel('Iteration'); 188 | ylabel('Minibatch Function Value'); 189 | title('Convergence Trace'); 190 | 191 | % test the gradient of f_df 192 | optimizer.check_grad(); 193 | ``` 194 | 195 | The subfunction/minibatch objective function and gradient for the MATLAB code is defined as follows, 196 | ```MATLAB 197 | function [f, dfdtheta] = f_df_autoencoder(theta, v) 198 | % [f, dfdtheta] = f_df_autoencoder(theta, v) 199 | % Calculate L2 reconstruction error and gradient for an autoencoder 200 | % with sigmoid nonlinearity. 201 | % Parameters: 202 | % theta - A cell array containing 203 | % {[weight matrix], [hidden bias], [visible bias]}. 204 | % v - A [# visible, # datapoints] matrix containing training data. 205 | % v will be different for each subfunction. 206 | % Returns: 207 | % f - The L2 reconstruction error for data v and parameters theta. 208 | % df - A cell array containing the gradient of f with each of the 209 | % parameters in theta. 210 | 211 | W = theta{1}; 212 | b_h = theta{2}; 213 | b_v = theta{3}; 214 | 215 | h = 1./(1 + exp(-bsxfun(@plus, W * v, b_h))); 216 | v_hat = bsxfun(@plus, W' * h, b_v); 217 | f = sum(sum((v_hat - v).^2)) / size(v, 2); 218 | dv_hat = 2*(v_hat - v) / size(v, 2); 219 | db_v = sum(dv_hat, 2); 220 | dW = h * dv_hat'; 221 | dh = W * dv_hat; 222 | db_h = sum(dh.*h.*(1-h), 2); 223 | dW = dW + dh.*h.*(1-h) * v'; 224 | % give the gradients the same order as the parameters 225 | dfdtheta = {dW, db_h, db_v}; 226 | end 227 | ``` 228 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sohl-Dickstein/Sum-of-Functions-Optimizer/36c66dfef68b344eca0ca9bb56d0a7bed2075036/__init__.py -------------------------------------------------------------------------------- /generate_figures/README.md: -------------------------------------------------------------------------------- 1 | Reproducible Science 2 | ================================ 3 | 4 | The code in this directory reproduces all of the figures from the paper 5 | > Jascha Sohl-Dickstein, Ben Poole, and Surya Ganguli
6 | > An adaptive low dimensional quasi-Newton sum of functions optimizer
7 | > International Conference on Machine Learning (2014)
8 | > arXiv preprint arXiv:1311.2115 (2013)
9 | > http://arxiv.org/abs/1311.2115 10 | 11 | - **figure\_compare_optimizers.py** produces the convergence comparison of different optimizers in Figure 3. 12 | - **figure\_compare_sfo_N.py** produces the convergence comparison for SFO with different numbers of subfunctions or minibatches in Figure 2(c). 13 | - **figure\_compare_sfo_variations.py** produces the convergence comparison for SFO with different design choices in supplemental Figure C.1. 14 | - **figure\_overhead.py** produces the computational overhead analysis in Figure 2(a,b). 15 | - **figure\_cartoon.py** produces the cartoon illustration of the SFO algorithm in Figure 1. 16 | 17 | To include a new objective function in the convergence comparison figure: 18 | 19 | 1. Add a class to **models.py** which provides the objective function and gradient, and initialization, for the new objective. The *toy* class is a good template to modify. 20 | 2. Add the new objective class to *models_to_train* in **figure_compare_optimizers.py**. 21 | 3. (optional) Modify plot characteristics in *make_plot_single_model* in **convergence_utils.py**. 22 | 4. Run **figure_compare_optimizers.py**. 23 | 24 | To include a new optimizer in the convergence comparison figure: 25 | 26 | 1. Add a function implementing the optimizer to **optimization_wrapper.py**. The *SGD* function is a good template to modify. 27 | 2. Add the model to *optimizers_to_use* in **figure_compare_optimizers.py**. 28 | 3. (optional) Modify plot characteristics in *make_plot_single_model* in **convergence_utils.py**. (for instance, to select only the best performing hyperparameters, add the new optimizer to the list of *best_traces* function calls) 29 | 4. Run **figure_compare_optimizers.py**. 30 | 31 | For more documentation on SFO in general, see the README.md in the parent directory. 32 | 33 | Several of the figures rely on a subdirectory **figure_data/** with training data. This can be downloaded from https://www.dropbox.com/sh/h9z4djlgl2tagmu/GlVAJyErf8 . 34 | -------------------------------------------------------------------------------- /generate_figures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sohl-Dickstein/Sum-of-Functions-Optimizer/36c66dfef68b344eca0ca9bb56d0a7bed2075036/generate_figures/__init__.py -------------------------------------------------------------------------------- /generate_figures/adagrad.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2013 Jascha Sohl-Dickstein, Ben Poole 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | 13 | This is an implementation of the ADAGrad algorithm: 14 | Duchi, J., Hazan, E., & Singer, Y. (2010). 15 | Adaptive subgradient methods for online learning and stochastic optimization. 16 | Journal of Machine Learning Research 17 | http://www.eecs.berkeley.edu/Pubs/TechRpts/2010/EECS-2010-24.pdf 18 | """ 19 | 20 | from numpy import * 21 | import numpy as np 22 | 23 | class ADAGrad(object): 24 | 25 | def __init__(self, f_df, theta, subfunction_references, reps=1, learning_rate=0.1, args=(), kwargs={}): 26 | 27 | self.reps = reps 28 | self.learning_rate = learning_rate 29 | 30 | self.N = len(subfunction_references) 31 | self.sub_ref = subfunction_references 32 | self.f_df = f_df 33 | self.args = args 34 | self.kwargs = kwargs 35 | 36 | self.num_steps = 0 37 | self.theta = theta.copy().reshape((-1,1)) 38 | self.grad_history = np.zeros_like(self.theta) 39 | self.M = self.theta.shape[0] 40 | 41 | self.f = ones((self.N))*np.nan 42 | 43 | 44 | def optimize(self, num_passes = 10, num_steps = None): 45 | if num_steps==None: 46 | num_steps = num_passes*self.N 47 | for i in range(num_steps): 48 | if not self.optimization_step(): 49 | break 50 | #print 'L ', self.L 51 | return self.theta 52 | 53 | def optimization_step(self): 54 | idx = np.random.randint(self.N) 55 | gradii = np.zeros_like(self.theta) 56 | lossii = 0. 57 | for i in range(self.reps): 58 | lossi, gradi = self.f_df(self.theta, (self.sub_ref[idx], ), *self.args, **self.kwargs) 59 | lossii += lossi / self.reps 60 | gradii += gradi.reshape(gradii.shape) / self.reps 61 | 62 | self.num_steps += 1 63 | learning_rates = self.learning_rate / (np.sqrt(1./self.num_steps + self.grad_history)) 64 | learning_rates[np.isinf(learning_rates)] = self.learning_rate 65 | self.theta -= learning_rates * gradii 66 | self.grad_history += gradii**2 67 | self.f[idx] = lossii 68 | 69 | if not np.isfinite(lossii): 70 | print("Non-finite subfunction. Ending run.") 71 | return False 72 | return True 73 | 74 | -------------------------------------------------------------------------------- /generate_figures/convergence_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains various shared functions used to generate the convergence 3 | figures in the SFO paper. 4 | 5 | Copyright 2014 Jascha Sohl-Dickstein 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | import matplotlib 19 | matplotlib.use('Agg') # no displayed figures -- need to call before loading pylab 20 | import matplotlib.pyplot as plt 21 | 22 | import datetime 23 | import glob 24 | import numpy as np 25 | import re 26 | import warnings 27 | 28 | from collections import defaultdict 29 | from itertools import cycle 30 | 31 | 32 | def sorted_nicely(strings): 33 | "Sort strings the way humans are said to expect." 34 | return sorted(strings, key=natural_sort_key) 35 | 36 | def natural_sort_key(key): 37 | import re 38 | return [int(t) if t.isdigit() else t for t in re.split(r'(\d+)', key)] 39 | 40 | def best_traces(prop, hist_f, styledict, color, neighbor_dist = 1): 41 | """ 42 | Prunes the history to only the best trace in prop, and its 43 | neighbors within neighbor_dist in the sorted order. It sets 44 | only the best trace to a dark line, and makes them all color. 45 | """ 46 | 47 | if len(prop)>0: 48 | minnm = prop[0] 49 | prop = sorted_nicely(prop) 50 | # get the best one 51 | minf = np.inf 52 | for nm in prop: 53 | ff = np.asarray(hist_f[nm]) 54 | minf2 = ff[-1] #np.min(ff) 55 | print(nm, minf2) 56 | if minf2 < minf: 57 | minnm = nm 58 | minf = minf2 59 | ii = prop.index(minnm) 60 | for i in range(len(prop)): 61 | if np.abs(i-ii) > neighbor_dist: 62 | del hist_f[prop[i]] 63 | for jj in range(1,neighbor_dist+1): 64 | try: 65 | styledict[prop[ii-jj]] = {'color':color, 'ls':':'} 66 | except: 67 | print("failure around", prop[0]) 68 | try: 69 | styledict[prop[ii+jj]] = {'color':color, 'ls':'-.'} 70 | except: 71 | print("failure around", prop[-1]) 72 | styledict[prop[ii]] = {'color':color, 'ls':'-', 'linewidth':4} 73 | 74 | def make_plot_single_model(hist_f, hist_x_projection, hist_events, model_name, 75 | num_subfunctions, full_objective_period, display_events=False, 76 | display_trajectory=False, figsize=(4.5,3.5), external_legend=True, name_prefix=""): 77 | """ 78 | Plot the different optimizers against each other for a single 79 | model. 80 | """ 81 | 82 | # xlabel was getting cut off, this seems to fix? 83 | from matplotlib import rcParams 84 | rcParams.update({'figure.autolayout': True}) 85 | 86 | # set the title 87 | title = model_name 88 | if model_name == 'protein' or model_name == 'protein logistic regression': 89 | title = 'Logistic Regression, Protein Dataset' 90 | elif model_name == 'Hopfield': 91 | title = 'Ising / Hopfield with MPF Objective' 92 | elif 'hard' in model_name: 93 | title = 'Multi-Layer Perceptron, Rectified Linear' 94 | elif 'soft' in model_name: 95 | title = 'Multi-Layer Perceptron, Sigmoid' 96 | elif model_name == 'sumf': 97 | title = 'Sum of Norms' 98 | elif model_name == 'Pylearn_conv': 99 | title = 'Convolutional Network, CIFAR-10' 100 | elif model_name == 'GLM': 101 | title = 'GLM, soft rectifying nonlinearity' 102 | model_name = 'GLM_soft' 103 | 104 | # set up linestyle cycler 105 | styles = [ 106 | {'color':'r', 'ls':'--'}, 107 | {'color':'g', 'ls':'-'}, 108 | {'color':'b', 'ls':'-.'}, 109 | {'color':'k', 'ls':':'}, 110 | {'color':'y', 'ls':'--'}, 111 | {'color':'r', 'ls':'-'}, 112 | {'color':'g', 'ls':'-.'}, 113 | {'color':'b', 'ls':':'}, 114 | {'color':'k', 'ls':'--'}, 115 | {'color':'y', 'ls':'-'}, 116 | {'color':'r', 'ls':'-.'}, 117 | {'color':'g', 'ls':':'}, 118 | {'color':'b', 'ls':'--'}, 119 | {'color':'k', 'ls':'-'}, 120 | {'color':'y', 'ls':'-.'}, 121 | ] 122 | stylecycler = cycle(styles) 123 | styledict = defaultdict(lambda: next(stylecycler)) 124 | 125 | zorder = dict() 126 | sorted_nm = sorted(hist_f.keys(), key=lambda nm: np.asarray(hist_f[nm])[-1], reverse=True) 127 | for ii, nm in enumerate(sorted_nm): 128 | zorder[nm] = ii 129 | 130 | ## override the default cycled styles for specific optimizers 131 | # LBFGS 132 | prop = [nm for nm in hist_f.keys() if 'LBFGS' in nm] 133 | for nm in prop: 134 | styledict[nm] = {'color':'r', 'ls':'-', 'linewidth':3} 135 | if 'batch' in nm: 136 | styledict[nm]['ls'] = ':' 137 | # SAG 138 | prop = [nm for nm in hist_f.keys() if 'SAG' in nm] 139 | best_traces(prop, hist_f, styledict, 'g') 140 | # SGD 141 | prop = [nm for nm in hist_f.keys() if 'SGD' in nm and 'momentum' not in nm] 142 | best_traces(prop, hist_f, styledict, 'c') 143 | # SGD_momentum 144 | prop = [nm for nm in hist_f.keys() if 'SGD' in nm and 'momentum' in nm] 145 | best_traces(prop, hist_f, styledict, 'm', neighbor_dist=0) 146 | # ADA 147 | prop = [nm for nm in hist_f.keys() if 'ADA' in nm] 148 | best_traces(prop, hist_f, styledict, 'b') 149 | # SFO 150 | prop = [nm for nm in hist_f.keys() if nm == 'SFO' or nm == 'SFO standard'] 151 | for nm in prop: 152 | styledict[nm] = {'color':'k', 'ls':'-', 'linewidth':4} 153 | # SFO number minibatches 154 | prop = [nm for nm in hist_f.keys() if 'SFO' in nm and 'N=' in nm] 155 | nprop = len(prop) 156 | for ii, nm in enumerate(sorted_nicely(prop)): 157 | #styledict[nm] = {'color':'k', 'dashes':(7,nprop/(ii+1.),), 'linewidth':5.*(ii+1.)/nprop} 158 | styledict[nm] = {'color':(1. - ii/(nprop-1.), 0., 0.), 'ls':'-', 'linewidth':4.*(ii+1.)/nprop} 159 | zorder[nm] = ii 160 | # SFO number history terms 161 | prop = [nm for nm in hist_f.keys() if 'SFO' in nm and 'L=' in nm] 162 | nprop = len(prop) 163 | for ii, nm in enumerate(sorted_nicely(prop)): 164 | #styledict[nm] = {'color':'k', 'dashes':(7,nprop/(ii+1.),), 'linewidth':5.*(ii+1.)/nprop} 165 | styledict[nm] = {'color':(1. - ii/(nprop-1.), 0., 0.), 'ls':'-', 'linewidth':4.*(ii+1.)/nprop} 166 | zorder[nm] = ii 167 | 168 | # plot the learning trace, and save pdf 169 | fig = plt.figure(figsize=figsize) 170 | 171 | lines = [] 172 | labels = [] 173 | 174 | fewest_passes = 0. 175 | 176 | for nm in sorted_nicely(hist_f.keys()): 177 | ff = np.asarray(hist_f[nm]) 178 | ff[ff>1.5*ff[0]] = np.nan 179 | xx = np.arange(1, len(ff)+1).astype(float)*full_objective_period/num_subfunctions 180 | if max(np.max(xx), fewest_passes) > 2*min(np.max(xx), fewest_passes) or nm == 'LBFGS': 181 | # ignore cases that were terminated early eg because of bad learning rate 182 | # also, sometimes LBFGS terminates early, so don't judge based on that 183 | fewest_passes = max(np.max(xx), fewest_passes) 184 | else: 185 | fewest_passes = min(np.max(xx), fewest_passes) 186 | #assert(fewest_passes > 4.) 187 | line = plt.semilogy( xx, ff, label=nm, zorder=zorder[nm], **styledict[nm] ) 188 | lines.append(line[0]) 189 | labels.append(nm) 190 | if display_events: 191 | # add special events 192 | for jj, events in enumerate(hist_events[nm]): 193 | st = {'s':100} 194 | if events.get('natural gradient subspace update', False): 195 | st['marker'] = '>' 196 | st['c'] = 'r' 197 | elif events.get('collapse subspace', False): 198 | st['marker'] = '*' 199 | st['c'] = 'y' 200 | elif events.get('step_failure', False): 201 | st['marker'] = '<' 202 | st['c'] = 'c' 203 | else: 204 | continue 205 | plt.scatter(xx[jj], ff[jj], **st) 206 | 207 | 208 | plt.ylabel( 'Full Batch Objective' ) 209 | plt.xlabel( 'Effective Passes Through Data' ) 210 | plt.title(title) 211 | plt.grid() 212 | plt.axes().set_axisbelow(True) 213 | 214 | ax = plt.axis() 215 | plt.axis([0, fewest_passes, ax[2], ax[3]]) 216 | ax = plt.axis() 217 | if "Autoencoder" in title: 218 | plt.yticks(np.arange(10, 46, 5.0), ["%d"%tt for tt in np.arange(10, 46, 5.0)]) 219 | plt.axis([ax[0], ax[1], 13, 30]) 220 | # elif "ICA" in title: 221 | # plt.axis([ax[0], ax[1], 0.23, 140]) 222 | elif "hard" in title: 223 | plt.axis([ax[0], ax[1], 1e-13, 1e1]) 224 | elif "Perceptron" in title: 225 | plt.axis([ax[0], ax[1], 1e-7, 1e1]) 226 | elif "ICA" in title: 227 | plt.axis([ax[0], ax[1], 1e0, 1e3]) 228 | try: 229 | plt.tight_layout() 230 | except: 231 | warnings.warn('tight_layout failed. try running with an Agg backend.') 232 | 233 | # update the labels to prettier text 234 | labels = map(lambda x: str.replace(x, "SGD_momentum ", r"SGD+mom "), labels) 235 | labels = map(lambda x: str.replace(x, "SGD ", r"SGD eta="), labels) 236 | labels = map(lambda x: str.replace(x, "ADAGrad ", r"ADAGrad eta="), labels) 237 | labels = map(lambda x: str.replace(x, "SAG ", r"SAG L="), labels) 238 | labels = map(lambda x: str.replace(x, "L=", r"$L=$"), labels) 239 | labels = map(lambda x: str.replace(x, "eta=", r"$\eta=$"), labels) 240 | labels = map(lambda x: str.replace(x, "mu=", r"$\mu=$"), labels) 241 | def number_shaver(ch, regx = re.compile('(?1.5*ff[0]] = np.nan 279 | xx = np.arange(1, len(ff)+1).astype(float)*full_objective_period/num_subfunctions 280 | plt.semilogy( xx, ff-minf, label=nm, zorder=zorder[nm], **styledict[nm] ) 281 | plt.ylabel( 'Full Batch Objective - Minimum' ) 282 | plt.xlabel( 'Effective Passes Through Data' ) 283 | if not external_legend: 284 | plt.legend( loc='best' ) 285 | plt.title(title) 286 | plt.grid() 287 | plt.axes().set_axisbelow(True) 288 | ax = plt.axis() 289 | plt.axis([0, fewest_passes, ax[2], ax[3]]) 290 | try: 291 | plt.tight_layout() 292 | except: 293 | warnings.warn('tight_layout failed. try running with an Agg backend.') 294 | fig.savefig(('figure_' + name_prefix + model_name + '_diff.pdf').replace(' ', '-')) 295 | 296 | if display_trajectory: 297 | # SAG 298 | prop = [nm for nm in hist_f.keys() if 'SAG' in nm] 299 | best_traces(prop, hist_f, styledict, 'g', neighbor_dist=0) 300 | # SGD 301 | prop = [nm for nm in hist_f.keys() if 'SGD' in nm] 302 | best_traces(prop, hist_f, styledict, 'c', neighbor_dist=0) 303 | # ADA 304 | prop = [nm for nm in hist_f.keys() if 'ADA' in nm] 305 | best_traces(prop, hist_f, styledict, 'b', neighbor_dist=0) 306 | 307 | # plot the learning trajectory in low-d projections, and save pdf 308 | # make the line styles appropriate 309 | for nm in hist_f.keys(): 310 | styledict[nm].pop('ls', None) 311 | styledict[nm].pop('linewidth', None) 312 | #styledict[nm]['linestyle'] = 'None' 313 | styledict[nm]['marker'] = '.' # ',' 314 | styledict[nm]['alpha'] = 0.5 # ',' 315 | fig = plt.figure(figsize=figsize) 316 | nproj = 3 # could make larger 317 | for i1 in range(nproj): 318 | for i2 in range(nproj): 319 | plt.subplot(nproj, nproj, i1 + nproj*i2 + 1) 320 | for nm in sorted_nicely(hist_f.keys()): 321 | xp = np.asarray(hist_x_projection[nm]) 322 | plt.plot( xp[:,i1], xp[:,i2], label=nm, zorder=zorder[nm], **styledict[nm] ) 323 | if display_events: 324 | # add special events 325 | for jj, events in enumerate(hist_events[nm]): 326 | st = {'s':100} 327 | if events.get('natural gradient subspace update', False): 328 | st['marker'] = '>' 329 | st['c'] = 'r' 330 | elif events.get('collapse subspace', False): 331 | st['marker'] = '*' 332 | st['c'] = 'y' 333 | elif events.get('step_failure', False): 334 | st['marker'] = '<' 335 | st['c'] = 'c' 336 | else: 337 | continue 338 | plt.scatter(xp[jj,i1], xp[jj,i2], **st) 339 | if not external_legend: 340 | plt.legend( loc='best' ) 341 | plt.suptitle(title) 342 | try: 343 | plt.tight_layout() 344 | except: 345 | warnings.warn('tight_layout failed. try running with an Agg backend.') 346 | fig.savefig(('figure_' + name_prefix + model_name + '_trajectory.pdf').replace(' ', '-')) 347 | 348 | 349 | def make_plots(history_nested, *args, **kwargs): 350 | for model_name in history_nested: 351 | history = history_nested[model_name] 352 | fig = make_plot_single_model(history['f'], history['x_projection'], history['events'], model_name, *args, **kwargs) 353 | 354 | def load_results(fnames=None, base_fname='figure_data_'): 355 | """ 356 | Load the function value traces during optimization for the 357 | set of models and optimizers provided by fnames. Find all 358 | files with matching filenames in the current directory if 359 | fnames not passed in. 360 | 361 | Output is a dictionary of dictionaries, where the inner dictionary 362 | contains different optimizers for each model, and the outer dictionary 363 | contains different models. 364 | 365 | Note that the files loaded can be as granular as a single optimizer 366 | for a single model for each file. 367 | """ 368 | 369 | if fnames==None: 370 | fnames = glob.glob(base_fname + '*.npz') 371 | 372 | num_subfunctions = None 373 | full_objective_period = None 374 | 375 | history_nested = {} 376 | for fn in fnames: 377 | data = np.load(fn) 378 | if num_subfunctions is None: 379 | num_subfunctions = data['num_subfunctions'] 380 | full_objective_period = data['full_objective_period'] 381 | if not (num_subfunctions == data['num_subfunctions'] and full_objective_period == data['full_objective_period']): 382 | print "****************" 383 | print "WARNING: mixing data with different numbers of subfunctions or delays between evaluating the full objective" 384 | print "make sure you are doing this intentionally (eg, for the convergence vs., number subfunctions plot)" 385 | print "****************" 386 | model_name = data['model_name'].tostring() 387 | print("loading", model_name) 388 | if not model_name in history_nested: 389 | history_nested[model_name] = data['history'][()].copy() 390 | else: 391 | print("updating") 392 | for subkey in history_nested[model_name].keys(): 393 | print subkey 394 | history_nested[model_name][subkey].update(data['history'][()].copy()[subkey]) 395 | data.close() 396 | 397 | return history_nested, num_subfunctions, full_objective_period 398 | 399 | def save_results(trainer, base_fname='figure_data_', store_x=True): 400 | """ 401 | Save the function trace for different optimizers for a 402 | given model to a .npz file. 403 | """ 404 | if not store_x: 405 | # delete the saved final x value so we don't run out of memory 406 | trainer.history['x'] = defaultdict(list) 407 | 408 | fname = base_fname + trainer.model.name + ".npz" 409 | np.savez(fname, history=trainer.history, model_name=trainer.model.name, 410 | num_subfunctions = trainer.num_subfunctions, 411 | full_objective_period=trainer.full_objective_period) 412 | -------------------------------------------------------------------------------- /generate_figures/dropout/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sohl-Dickstein/Sum-of-Functions-Optimizer/36c66dfef68b344eca0ca9bb56d0a7bed2075036/generate_figures/dropout/.DS_Store -------------------------------------------------------------------------------- /generate_figures/dropout/.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *.txt 3 | *.png 4 | *.swp 5 | *.pyc 6 | -------------------------------------------------------------------------------- /generate_figures/dropout/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2012 Misha Denil 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | of the Software, and to permit persons to whom the Software is furnished to do 8 | so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | 21 | -------------------------------------------------------------------------------- /generate_figures/dropout/README: -------------------------------------------------------------------------------- 1 | Theano implementation of dropout. See: http://arxiv.org/abs/1207.0580 2 | 3 | Run with: 4 | 5 | ./mlp.py dropout 6 | 7 | for dropout, or 8 | 9 | ./mlp.py backprop 10 | 11 | for regular backprop with no dropout. 12 | 13 | Use: 14 | 15 | ./plot_results.sh results.png 16 | 17 | to visualize the results. 18 | 19 | Based on code from: 20 | - http://deeplearning.net/tutorial/mlp.html 21 | - http://deeplearning.net/tutorial/logreg.html 22 | 23 | Use the data here to make the units of the results comparable to Hinton's paper: 24 | - http://www.cs.ubc.ca/~mdenil/hidden/mnist_batches.npz 25 | 26 | -------------------------------------------------------------------------------- /generate_figures/dropout/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sohl-Dickstein/Sum-of-Functions-Optimizer/36c66dfef68b344eca0ca9bb56d0a7bed2075036/generate_figures/dropout/__init__.py -------------------------------------------------------------------------------- /generate_figures/dropout/deepae.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import time 5 | 6 | import theano 7 | import theano.tensor as T 8 | from theano.ifelse import ifelse 9 | import theano.printing 10 | import theano.tensor.shared_randomstreams 11 | 12 | from logistic_sgd import LogisticRegression 13 | from load_data import load_mnist 14 | 15 | class HiddenLayer(object): 16 | def __init__(self, rng, input, n_in, n_out, 17 | activation, W=None, b=None, 18 | use_bias=False, sparse_init=15): 19 | 20 | self.input = input 21 | self.activation = activation 22 | 23 | if W is None: 24 | if sparse_init is not None: 25 | W_values = np.zeros((n_in, n_out), dtype=theano.config.floatX) 26 | for i in xrange(n_out): 27 | for j in xrange(sparse_init): 28 | idx = rng.randint(0, n_in) 29 | # don't worry about getting exactly sparse_init nonzeroese 30 | #while W_values[idx, i] != 0.: 31 | # idx = rng.randint(0, n_in) 32 | W_values[idx, i] = rng.randn()/np.sqrt(sparse_init) * 1.2 33 | else: 34 | W_values = 1. * np.asarray(rng.standard_normal( 35 | size=(n_in, n_out)), dtype=theano.config.floatX) / np.float(n_in)**0.5 36 | W = theano.shared(value=W_values, name='W') 37 | if b is None: 38 | b_values = np.zeros((n_out,), dtype=theano.config.floatX) 39 | b = theano.shared(value=b_values, name='b') 40 | self.W = W 41 | self.b = b 42 | 43 | if use_bias: 44 | lin_output = T.dot(input, self.W) + self.b 45 | else: 46 | lin_output = T.dot(input, self.W) 47 | 48 | self.output = (lin_output if activation is None else activation(lin_output)) 49 | # parameters of the model 50 | if use_bias: 51 | self.params = [self.W, self.b] 52 | else: 53 | self.params = [self.W] 54 | 55 | 56 | class DeepAE(object): 57 | """ 58 | Deep Autoencoder. 59 | """ 60 | def __init__(self, 61 | rng, 62 | input, 63 | layer_sizes, 64 | use_bias=True, objective='crossent'): 65 | # Activation functions are all sigmoid except 66 | # "coding" layer that is linear. 67 | # TODO(ben): Make sure layer after coding layer is sigmoid. 68 | layer_acts = [T.nnet.sigmoid for l in layer_sizes[1:-1]] 69 | layer_acts = layer_acts + [None, T.nnet.sigmoid] + layer_acts 70 | # use for untied weights 71 | #layer_acts = layer_acts + [None, None] + layer_acts 72 | layer_sizes = layer_sizes + layer_sizes[:-1][::-1] 73 | print 'Layer sizes:',layer_sizes 74 | print 'Layer acts:',layer_acts 75 | 76 | # Set up all the hidden layers 77 | weight_matrix_sizes = zip(layer_sizes, layer_sizes[1:]) 78 | self.layers = [] 79 | next_layer_input = input 80 | idx = 0 81 | for n_in, n_out in weight_matrix_sizes: 82 | print (n_in,n_out), layer_acts[idx] 83 | #if idx == 1: 84 | # W_in = next_layer.W.T 85 | #else: 86 | # W_in = None 87 | next_layer = HiddenLayer(rng=rng, 88 | input=next_layer_input, 89 | activation=layer_acts[idx], 90 | n_in=n_in, n_out=n_out, 91 | use_bias=use_bias) 92 | self.layers.append(next_layer) 93 | next_layer_input = next_layer.output 94 | idx += 1 95 | xpred = next_layer_input 96 | # Compute cost function, making sure to sum across data examples 97 | # so that we can properly average across minibatches. 98 | if objective == 'crossent': 99 | self.cost = (-input * T.log(xpred) - (1 - input) * T.log(1 - xpred)).sum(axis=1).sum() 100 | else: 101 | self.cost = ((input-xpred)**2).sum(axis=1).sum() 102 | # Grab all the parameters together. 103 | #XXX 104 | self.params = [ param for layer in self.layers for param in layer.params ] 105 | #self.params = [self.layers[0].W] 106 | 107 | def convert_variable(x): 108 | if x.ndim == 1: 109 | return T.vector(x.name, dtype=x.dtype) 110 | else: 111 | return T.matrix(x.name, dtype=x.dtype) 112 | 113 | def build_f_df(layer_sizes, dropout=False, **kwargs): 114 | print '... building the model' 115 | x = T.matrix('x') # the data is presented as rasterized images 116 | y = T.ivector('y') # the labels are presented as 1D vector of 117 | # [int] labels 118 | # Make sure initialization is repeatable 119 | rng = np.random.RandomState(1234) 120 | # construct the MLP class 121 | dae = DeepAE(rng=rng, input=x, 122 | layer_sizes=layer_sizes, **kwargs) 123 | # Build the expresson for the cost function. 124 | cost = dae.cost 125 | gparams = [] 126 | for param in dae.params: 127 | gparam = T.grad(cost, param) 128 | gparams.append(gparam) 129 | symbolic_params = [convert_variable(param) for param in dae.params] 130 | givens = dict(zip(dae.params, symbolic_params)) 131 | f_df = theano.function(inputs=symbolic_params + [x], outputs=[cost] + gparams, givens=givens) 132 | return f_df, dae 133 | -------------------------------------------------------------------------------- /generate_figures/dropout/load_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cPickle 3 | import gzip 4 | import os 5 | import sys 6 | 7 | import theano 8 | import theano.tensor as T 9 | 10 | def _shared_dataset(data_xy): 11 | """ Function that loads the dataset into shared variables 12 | 13 | The reason we store our dataset in shared variables is to allow 14 | Theano to copy it into the GPU memory (when code is run on GPU). 15 | Since copying data into the GPU is slow, copying a minibatch everytime 16 | is needed (the default behaviour if the data is not in a shared 17 | variable) would lead to a large decrease in performance. 18 | """ 19 | data_x, data_y = data_xy 20 | shared_x = theano.shared(np.asarray(data_x, 21 | dtype=theano.config.floatX)) 22 | shared_y = theano.shared(np.asarray(data_y, 23 | dtype=theano.config.floatX)) 24 | # When storing data on the GPU it has to be stored as floats 25 | # therefore we will store the labels as ``floatX`` as well 26 | # (``shared_y`` does exactly that). But during our computations 27 | # we need them as ints (we use labels as index, and if they are 28 | # floats it doesn't make sense) therefore instead of returning 29 | # ``shared_y`` we will have to cast it to int. This little hack 30 | # lets ous get around this issue 31 | return shared_x, T.cast(shared_y, 'int32') 32 | 33 | def load_mnist(path): 34 | mnist = np.load(path) 35 | train_set_x = mnist['train_data'] 36 | train_set_y = mnist['train_labels'] 37 | test_set_x = mnist['test_data'] 38 | test_set_y = mnist['test_labels'] 39 | 40 | train_set_x, train_set_y = _shared_dataset((train_set_x, train_set_y)) 41 | test_set_x, test_set_y = _shared_dataset((test_set_x, test_set_y)) 42 | valid_set_x, valid_set_y = test_set_x, test_set_y 43 | 44 | rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y), 45 | (test_set_x, test_set_y)] 46 | return rval 47 | 48 | 49 | def load_umontreal_data(dataset): 50 | ''' Loads the dataset 51 | 52 | :type dataset: string 53 | :param dataset: the path to the dataset (here MNIST) 54 | ''' 55 | 56 | ############# 57 | # LOAD DATA # 58 | ############# 59 | 60 | # Download the MNIST dataset if it is not present 61 | data_dir, data_file = os.path.split(dataset) 62 | if (not os.path.isfile(dataset)) and data_file == 'mnist.pkl.gz': 63 | import urllib 64 | origin = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz' 65 | print 'Downloading data from %s' % origin 66 | urllib.urlretrieve(origin, dataset) 67 | 68 | print '... loading data' 69 | 70 | # Load the dataset 71 | f = gzip.open(dataset, 'rb') 72 | train_set, valid_set, test_set = cPickle.load(f) 73 | f.close() 74 | #train_set, valid_set, test_set format: tuple(input, target) 75 | #input is an np.ndarray of 2 dimensions (a matrix) 76 | #witch row's correspond to an example. target is a 77 | #np.ndarray of 1 dimensions (vector)) that have the same length as 78 | #the number of rows in the input. It should give the target 79 | #target to the example with the same index in the input. 80 | 81 | def _shared_dataset(data_xy): 82 | """ Function that loads the dataset into shared variables 83 | 84 | The reason we store our dataset in shared variables is to allow 85 | Theano to copy it into the GPU memory (when code is run on GPU). 86 | Since copying data into the GPU is slow, copying a minibatch everytime 87 | is needed (the default behaviour if the data is not in a shared 88 | variable) would lead to a large decrease in performance. 89 | """ 90 | data_x, data_y = data_xy 91 | shared_x = theano.shared(np.asarray(data_x, 92 | dtype=theano.config.floatX)) 93 | shared_y = theano.shared(np.asarray(data_y, 94 | dtype=theano.config.floatX)) 95 | # When storing data on the GPU it has to be stored as floats 96 | # therefore we will store the labels as ``floatX`` as well 97 | # (``shared_y`` does exactly that). But during our computations 98 | # we need them as ints (we use labels as index, and if they are 99 | # floats it doesn't make sense) therefore instead of returning 100 | # ``shared_y`` we will have to cast it to int. This little hack 101 | # lets ous get around this issue 102 | return shared_x, T.cast(shared_y, 'int32') 103 | 104 | test_set_x, test_set_y = _shared_dataset(test_set) 105 | valid_set_x, valid_set_y = _shared_dataset(valid_set) 106 | train_set_x, train_set_y = _shared_dataset(train_set) 107 | 108 | rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y), 109 | (test_set_x, test_set_y)] 110 | return rval 111 | 112 | -------------------------------------------------------------------------------- /generate_figures/dropout/logistic_sgd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import theano 4 | import theano.tensor as T 5 | 6 | class LogisticRegression(object): 7 | """Multi-class Logistic Regression Class 8 | 9 | The logistic regression is fully described by a weight matrix :math:`W` 10 | and bias vector :math:`b`. Classification is done by projecting data 11 | points onto a set of hyperplanes, the distance to which is used to 12 | determine a class membership probability. 13 | """ 14 | 15 | def __init__(self, input, n_in, n_out, W=None, b=None): 16 | """ Initialize the parameters of the logistic regression 17 | 18 | :type input: theano.tensor.TensorType 19 | :param input: symbolic variable that describes the input of the 20 | architecture (one minibatch) 21 | 22 | :type n_in: int 23 | :param n_in: number of input units, the dimension of the space in 24 | which the datapoints lie 25 | 26 | :type n_out: int 27 | :param n_out: number of output units, the dimension of the space in 28 | which the labels lie 29 | 30 | """ 31 | 32 | # initialize with 0 the weights W as a matrix of shape (n_in, n_out) 33 | if W is None: 34 | self.W = theano.shared( 35 | value=np.zeros((n_in, n_out), dtype=theano.config.floatX), 36 | name='W') 37 | else: 38 | self.W = W 39 | 40 | # initialize the baises b as a vector of n_out 0s 41 | if b is None: 42 | self.b = theano.shared( 43 | value=np.zeros((n_out,), dtype=theano.config.floatX), 44 | name='b') 45 | else: 46 | self.b = b 47 | 48 | # compute vector of class-membership probabilities in symbolic form 49 | self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b) 50 | 51 | # compute prediction as class whose probability is maximal in 52 | # symbolic form 53 | self.y_pred = T.argmax(self.p_y_given_x, axis=1) 54 | 55 | # parameters of the model 56 | self.params = [self.W, self.b] 57 | 58 | def negative_log_likelihood(self, y): 59 | """Return the mean of the negative log-likelihood of the prediction 60 | of this model under a given target distribution. 61 | 62 | .. math:: 63 | 64 | \frac{1}{|\mathcal{D}|} \mathcal{L} (\theta=\{W,b\}, \mathcal{D}) = 65 | \frac{1}{|\mathcal{D}|} \sum_{i=0}^{|\mathcal{D}|} \log(P(Y=y^{(i)}|x^{(i)}, W,b)) \\ 66 | \ell (\theta=\{W,b\}, \mathcal{D}) 67 | 68 | :type y: theano.tensor.TensorType 69 | :param y: corresponds to a vector that gives for each example the 70 | correct label 71 | 72 | Note: we use the mean instead of the sum so that 73 | the learning rate is less dependent on the batch size 74 | """ 75 | # y.shape[0] is (symbolically) the number of rows in y, i.e., 76 | # number of examples (call it n) in the minibatch 77 | # T.arange(y.shape[0]) is a symbolic vector which will contain 78 | # [0,1,2,... n-1] T.log(self.p_y_given_x) is a matrix of 79 | # Log-Probabilities (call it LP) with one row per example and 80 | # one column per class LP[T.arange(y.shape[0]),y] is a vector 81 | # v containing [LP[0,y[0]], LP[1,y[1]], LP[2,y[2]], ..., 82 | # LP[n-1,y[n-1]]] and T.mean(LP[T.arange(y.shape[0]),y]) is 83 | # the mean (across minibatch examples) of the elements in v, 84 | # i.e., the mean log-likelihood across the minibatch. 85 | return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y]) 86 | 87 | def errors(self, y): 88 | """Return a float representing the number of errors in the minibatch ; 89 | zero one loss over the size of the minibatch 90 | 91 | :type y: theano.tensor.TensorType 92 | :param y: corresponds to a vector that gives for each example the 93 | correct label 94 | """ 95 | 96 | # check if y has same dimension of y_pred 97 | if y.ndim != self.y_pred.ndim: 98 | raise TypeError('y should have the same shape as self.y_pred', 99 | ('y', target.type, 'y_pred', self.y_pred.type)) 100 | # check if y is of the correct datatype 101 | if y.dtype.startswith('int'): 102 | # the T.neq operator returns a vector of 0s and 1s, where 1 103 | # represents a mistake in prediction 104 | return T.sum(T.neq(self.y_pred, y)) 105 | else: 106 | raise NotImplementedError() 107 | 108 | 109 | -------------------------------------------------------------------------------- /generate_figures/dropout/mlp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import time 5 | 6 | import theano 7 | import theano.tensor as T 8 | from theano.ifelse import ifelse 9 | import theano.printing 10 | import theano.tensor.shared_randomstreams 11 | 12 | from logistic_sgd import LogisticRegression 13 | from load_data import load_mnist 14 | 15 | class HiddenLayer(object): 16 | def __init__(self, rng, input, n_in, n_out, 17 | activation, W=None, b=None, 18 | use_bias=False): 19 | 20 | self.input = input 21 | self.activation = activation 22 | 23 | if W is None: 24 | #W_values = np.asarray(0.01 * rng.standard_normal( 25 | # size=(n_in, n_out)), dtype=theano.config.floatX) 26 | W_values = 2. * np.asarray(rng.standard_normal( 27 | size=(n_in, n_out)), dtype=theano.config.floatX) / np.float(n_in)**0.5 28 | W = theano.shared(value=W_values, name='W') 29 | 30 | if b is None: 31 | b_values = np.zeros((n_out,), dtype=theano.config.floatX) 32 | b = theano.shared(value=b_values, name='b') 33 | 34 | self.W = W 35 | self.b = b 36 | 37 | if use_bias: 38 | lin_output = T.dot(input, self.W) + self.b 39 | else: 40 | lin_output = T.dot(input, self.W) 41 | 42 | self.output = (lin_output if activation is None else activation(lin_output)) 43 | 44 | # parameters of the model 45 | if use_bias: 46 | self.params = [self.W, self.b] 47 | else: 48 | self.params = [self.W] 49 | 50 | 51 | def _dropout_from_layer(rng, layer, p): 52 | """p is the probablity of dropping a unit 53 | """ 54 | srng = theano.tensor.shared_randomstreams.RandomStreams( 55 | rng.randint(999999)) 56 | # p=1-p because 1's indicate keep and p is prob of dropping 57 | mask = srng.binomial(n=1, p=1-p, size=layer.shape) 58 | # The cast is important because 59 | # int * float32 = float64 which pulls things off the gpu 60 | output = layer * T.cast(mask, theano.config.floatX) 61 | return output 62 | 63 | class DropoutHiddenLayer(HiddenLayer): 64 | def __init__(self, rng, input, n_in, n_out, 65 | activation, use_bias, W=None, b=None): 66 | super(DropoutHiddenLayer, self).__init__( 67 | rng=rng, input=input, n_in=n_in, n_out=n_out, W=W, b=b, 68 | activation=activation, use_bias=use_bias) 69 | 70 | self.output = _dropout_from_layer(rng, self.output, p=0.5) 71 | 72 | 73 | class MLP(object): 74 | """A multilayer perceptron with all the trappings required to do dropout 75 | training. 76 | 77 | """ 78 | def __init__(self, 79 | rng, 80 | input, 81 | layer_sizes, 82 | use_bias=True, 83 | rectifier=None): 84 | 85 | if rectifier == 'soft': 86 | rectified_linear_activation = lambda x: T.nnet.softplus(x) 87 | elif rectifier == 'hard': 88 | rectified_linear_activation = lambda x: T.maximum(0.0, x) 89 | 90 | # Set up all the hidden layers 91 | weight_matrix_sizes = zip(layer_sizes, layer_sizes[1:]) 92 | self.layers = [] 93 | self.dropout_layers = [] 94 | next_layer_input = input 95 | # dropout the input with prob 0.2 96 | next_dropout_layer_input = _dropout_from_layer(rng, input, p=0.2) 97 | for n_in, n_out in weight_matrix_sizes[:-1]: 98 | next_dropout_layer = DropoutHiddenLayer(rng=rng, 99 | input=next_dropout_layer_input, 100 | activation=rectified_linear_activation, 101 | n_in=n_in, n_out=n_out, use_bias=use_bias) 102 | self.dropout_layers.append(next_dropout_layer) 103 | next_dropout_layer_input = next_dropout_layer.output 104 | 105 | # Reuse the paramters from the dropout layer here, in a different 106 | # path through the graph. 107 | next_layer = HiddenLayer(rng=rng, 108 | input=next_layer_input, 109 | activation=rectified_linear_activation, 110 | W=next_dropout_layer.W * 0.5, 111 | b=next_dropout_layer.b, 112 | n_in=n_in, n_out=n_out, 113 | use_bias=use_bias) 114 | self.layers.append(next_layer) 115 | next_layer_input = next_layer.output 116 | 117 | # Set up the output layer 118 | n_in, n_out = weight_matrix_sizes[-1] 119 | dropout_output_layer = LogisticRegression( 120 | input=next_dropout_layer_input, 121 | n_in=n_in, n_out=n_out) 122 | self.dropout_layers.append(dropout_output_layer) 123 | 124 | # Again, reuse paramters in the dropout output. 125 | output_layer = LogisticRegression( 126 | input=next_layer_input, 127 | W=dropout_output_layer.W * 0.5, 128 | b=dropout_output_layer.b, 129 | n_in=n_in, n_out=n_out) 130 | self.layers.append(output_layer) 131 | 132 | # Use the negative log likelihood of the logistic regression layer as 133 | # the objective. 134 | self.dropout_negative_log_likelihood = self.dropout_layers[-1].negative_log_likelihood 135 | self.dropout_errors = self.dropout_layers[-1].errors 136 | 137 | self.negative_log_likelihood = self.layers[-1].negative_log_likelihood 138 | self.errors = self.layers[-1].errors 139 | 140 | # Grab all the parameters together. 141 | self.params = [ param for layer in self.dropout_layers for param in layer.params ] 142 | 143 | def convert_variable(x): 144 | if x.ndim == 1: 145 | return T.vector(x.name, dtype=x.dtype) 146 | else: 147 | return T.matrix(x.name, dtype=x.dtype) 148 | 149 | def f_df_split(fnc, *args, **kwargs): 150 | result = fnc(*args, **kwargs) 151 | return result[0], results[1:] 152 | 153 | def build_f_df(layer_sizes, dropout=False, use_bias=False, rectifier=None): 154 | print '... building the model' 155 | x = T.matrix('x') # the data is presented as rasterized images 156 | y = T.ivector('y') # the labels are presented as 1D vector of 157 | # [int] labels 158 | # Make sure initialization is repeatable 159 | rng = np.random.RandomState(1234) 160 | # construct the MLP class 161 | classifier = MLP(rng=rng, input=x, 162 | layer_sizes=layer_sizes, use_bias=use_bias, rectifier=rectifier) 163 | # Build the expresson for the cost function. 164 | cost = classifier.negative_log_likelihood(y) 165 | dropout_cost = classifier.dropout_negative_log_likelihood(y) 166 | # Compile theano function for testing. 167 | # test_model = theano.function(inputs=[index], 168 | # outputs=classifier.errors(y), 169 | # givens={ 170 | # x: test_set_x[index * batch_size:(index + 1) * batch_size], 171 | # y: test_set_y[index * batch_size:(index + 1) * batch_size]}) 172 | # Compute gradients of the model wrt parameters 173 | gparams = [] 174 | for param in classifier.params: 175 | # Use the right cost function here to train with or without dropout. 176 | gparam = T.grad(dropout_cost if dropout else cost, param) 177 | gparams.append(gparam) 178 | symbolic_params = [convert_variable(param) for param in classifier.params] 179 | givens = dict(zip(classifier.params, symbolic_params)) 180 | f_df = theano.function(inputs=symbolic_params + [x, y], outputs=[cost] + gparams, givens=givens) 181 | return f_df, classifier 182 | 183 | 184 | def test_mlp( 185 | initial_learning_rate, 186 | learning_rate_decay, 187 | squared_filter_length_limit, 188 | n_epochs, 189 | batch_size, 190 | dropout, 191 | results_file_name, 192 | layer_sizes, 193 | dataset, 194 | use_bias): 195 | """ 196 | The dataset is the one from the mlp demo on deeplearning.net. This training 197 | function is lifted from there almost exactly. 198 | 199 | :type dataset: string 200 | :param dataset: the path of the MNIST dataset file from 201 | http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz 202 | 203 | 204 | """ 205 | datasets = load_mnist(dataset) 206 | train_set_x, train_set_y = datasets[0] 207 | valid_set_x, valid_set_y = datasets[1] 208 | test_set_x, test_set_y = datasets[2] 209 | 210 | # compute number of minibatches for training, validation and testing 211 | n_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size 212 | n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] / batch_size 213 | n_test_batches = test_set_x.get_value(borrow=True).shape[0] / batch_size 214 | 215 | ###################### 216 | # BUILD ACTUAL MODEL # 217 | ###################### 218 | 219 | print '... building the model' 220 | 221 | # allocate symbolic variables for the data 222 | index = T.lscalar() # index to a [mini]batch 223 | epoch = T.scalar() 224 | x = T.matrix('x') # the data is presented as rasterized images 225 | y = T.ivector('y') # the labels are presented as 1D vector of 226 | # [int] labels 227 | learning_rate = theano.shared(np.asarray(initial_learning_rate, 228 | dtype=theano.config.floatX)) 229 | 230 | rng = np.random.RandomState(1234) 231 | 232 | # construct the MLP class 233 | classifier = MLP(rng=rng, input=x, 234 | layer_sizes=layer_sizes, use_bias=use_bias) 235 | 236 | # Build the expresson for the cost function. 237 | cost = classifier.negative_log_likelihood(y) 238 | dropout_cost = classifier.dropout_negative_log_likelihood(y) 239 | 240 | # Compile theano function for testing. 241 | test_model = theano.function(inputs=[index], 242 | outputs=classifier.errors(y), 243 | givens={ 244 | x: test_set_x[index * batch_size:(index + 1) * batch_size], 245 | y: test_set_y[index * batch_size:(index + 1) * batch_size]}) 246 | #theano.printing.pydotprint(test_model, outfile="test_file.png", 247 | # var_with_name_simple=True) 248 | 249 | # Compile theano function for validation. 250 | validate_model = theano.function(inputs=[index], 251 | outputs=classifier.errors(y), 252 | givens={ 253 | x: valid_set_x[index * batch_size:(index + 1) * batch_size], 254 | y: valid_set_y[index * batch_size:(index + 1) * batch_size]}) 255 | #theano.printing.pydotprint(validate_model, outfile="validate_file.png", 256 | # var_with_name_simple=True) 257 | 258 | # Compute gradients of the model wrt parameters 259 | gparams = [] 260 | for param in classifier.params: 261 | # Use the right cost function here to train with or without dropout. 262 | gparam = T.grad(dropout_cost if dropout else cost, param) 263 | gparams.append(gparam) 264 | 265 | # ... and allocate mmeory for momentum'd versions of the gradient 266 | gparams_mom = [] 267 | for param in classifier.params: 268 | gparam_mom = theano.shared(np.zeros(param.get_value(borrow=True).shape, 269 | dtype=theano.config.floatX)) 270 | gparams_mom.append(gparam_mom) 271 | 272 | # Compute momentum for the current epoch 273 | mom = ifelse(epoch < 500, 274 | 0.5*(1. - epoch/500.) + 0.99*(epoch/500.), 275 | 0.99) 276 | 277 | # Update the step direction using momentum 278 | updates = {} 279 | for gparam_mom, gparam in zip(gparams_mom, gparams): 280 | updates[gparam_mom] = mom * gparam_mom + (1. - mom) * gparam 281 | 282 | # ... and take a step along that direction 283 | for param, gparam_mom in zip(classifier.params, gparams_mom): 284 | stepped_param = param - (1.-mom) * learning_rate * gparam_mom 285 | 286 | # This is a silly hack to constrain the norms of the rows of the weight 287 | # matrices. This just checks if there are two dimensions to the 288 | # parameter and constrains it if so... maybe this is a bit silly but it 289 | # should work for now. 290 | if param.get_value(borrow=True).ndim == 2: 291 | squared_norms = T.sum(stepped_param**2, axis=1).reshape((stepped_param.shape[0],1)) 292 | scale = T.clip(T.sqrt(squared_filter_length_limit / squared_norms), 0., 1.) 293 | updates[param] = stepped_param * scale 294 | else: 295 | updates[param] = stepped_param 296 | 297 | 298 | # Compile theano function for training. This returns the training cost and 299 | # updates the model parameters. 300 | output = dropout_cost if dropout else cost 301 | train_model = theano.function(inputs=[epoch, index], outputs=output, 302 | updates=updates, 303 | givens={ 304 | x: train_set_x[index * batch_size:(index + 1) * batch_size], 305 | y: train_set_y[index * batch_size:(index + 1) * batch_size]}) 306 | #theano.printing.pydotprint(train_model, outfile="train_file.png", 307 | # var_with_name_simple=True) 308 | 309 | # Theano function to decay the learning rate, this is separate from the 310 | # training function because we only want to do this once each epoch instead 311 | # of after each minibatch. 312 | decay_learning_rate = theano.function(inputs=[], outputs=learning_rate, 313 | updates={learning_rate: learning_rate * learning_rate_decay}) 314 | 315 | ############### 316 | # TRAIN MODEL # 317 | ############### 318 | print '... training' 319 | 320 | best_params = None 321 | best_validation_errors = np.inf 322 | best_iter = 0 323 | test_score = 0. 324 | epoch_counter = 0 325 | start_time = time.clock() 326 | 327 | results_file = open(results_file_name, 'wb') 328 | 329 | while epoch_counter < n_epochs: 330 | # Train this epoch 331 | epoch_counter = epoch_counter + 1 332 | for minibatch_index in xrange(n_train_batches): 333 | minibatch_avg_cost = train_model(epoch_counter, minibatch_index) 334 | 335 | # Compute loss on validation set 336 | validation_losses = [validate_model(i) for i in xrange(n_valid_batches)] 337 | this_validation_errors = np.sum(validation_losses) 338 | 339 | # Report and save progress. 340 | print "epoch {}, test error {}, learning_rate={}{}".format( 341 | epoch_counter, this_validation_errors, 342 | learning_rate.get_value(borrow=True), 343 | " **" if this_validation_errors < best_validation_errors else "") 344 | 345 | best_validation_errors = min(best_validation_errors, 346 | this_validation_errors) 347 | results_file.write("{0}\n".format(this_validation_errors)) 348 | results_file.flush() 349 | 350 | new_learning_rate = decay_learning_rate() 351 | 352 | end_time = time.clock() 353 | print(('Optimization complete. Best validation score of %f %% ' 354 | 'obtained at iteration %i, with test performance %f %%') % 355 | (best_validation_errors * 100., best_iter, test_score * 100.)) 356 | print >> sys.stderr, ('The code for file ' + 357 | os.path.split(__file__)[1] + 358 | ' ran for %.2fm' % ((end_time - start_time) / 60.)) 359 | 360 | 361 | if __name__ == '__main__': 362 | import sys 363 | 364 | initial_learning_rate = 1.0 365 | learning_rate_decay = 0.998 366 | squared_filter_length_limit = 15.0 367 | n_epochs = 3000 368 | batch_size = 100 369 | layer_sizes = [ 28*28, 1200, 1200, 10 ] 370 | dataset = 'data/mnist_batches.npz' 371 | #dataset = 'data/mnist.pkl.gz' 372 | 373 | build_f_df(layer_sizes, use_bias=False) 374 | 375 | if len(sys.argv) < 2: 376 | print "Usage: {0} [dropout|backprop]".format(sys.argv[0]) 377 | exit(1) 378 | 379 | elif sys.argv[1] == "dropout": 380 | dropout = True 381 | results_file_name = "results_dropout.txt" 382 | 383 | elif sys.argv[1] == "backprop": 384 | dropout = False 385 | results_file_name = "results_backprop.txt" 386 | 387 | else: 388 | print "I don't know how to '{0}'".format(sys.argv[1]) 389 | exit(1) 390 | 391 | test_mlp(initial_learning_rate=initial_learning_rate, 392 | learning_rate_decay=learning_rate_decay, 393 | squared_filter_length_limit=squared_filter_length_limit, 394 | n_epochs=n_epochs, 395 | batch_size=batch_size, 396 | layer_sizes=layer_sizes, 397 | dropout=dropout, 398 | dataset=dataset, 399 | results_file_name=results_file_name, 400 | use_bias=False) 401 | 402 | -------------------------------------------------------------------------------- /generate_figures/dropout/plot_results.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ]; then 4 | TERM="dumb" 5 | PLOT_FILE=/dev/stdout 6 | else 7 | TERM="png" 8 | PLOT_FILE="$1" 9 | fi 10 | 11 | SMOOTHING=0.0 12 | 13 | ( 14 | gnuplot 2>/dev/null < 1e5: 78 | return "$10^{%d}$"%int(np.log10(tt)) 79 | return "$%d$"%tt 80 | 81 | def plot_shared(M, M_arr, N_arr, T_arr, v_fixed, v_change): 82 | figsize=(2.5,2.,) 83 | 84 | idx = np.flatnonzero(M_arr == M) 85 | ord = np.argsort(N_arr[idx]) 86 | idx = idx[ord] 87 | if idx.shape[0] > 1: 88 | plt.figure(figsize=figsize) 89 | plt.plot(N_arr[idx], T_arr[idx], 'x--') 90 | plt.grid() 91 | ax = plt.axis() 92 | plt.axis([0, ax[1], 0, ax[3]]) 93 | nn = np.linspace(0, np.max(T_arr[idx]), 5) 94 | nnt = ["$%d$"%tt for tt in nn] 95 | nnt[1] = '' 96 | nnt[3] = '' 97 | plt.yticks(nn, nnt) 98 | nn = np.linspace(0, np.max(N_arr[idx]), 5) 99 | nnt = [convert_num(tt) for tt in nn] 100 | nnt[1] = '' 101 | nnt[2] = '' 102 | nnt[3] = '' 103 | plt.xticks(nn, nnt) 104 | plt.axes().set_axisbelow(True) 105 | plt.xlabel('$%s$'%v_change) 106 | plt.ylabel('Overhead (s)') 107 | if M > 10**3: 108 | plt.title('Fixed $%s=10^{%d}$'%(v_fixed, int(np.log10(M)))) 109 | else: 110 | plt.title('Fixed $%s=%d$'%(v_fixed, M)) 111 | try: 112 | plt.tight_layout() 113 | except: 114 | warnings.warn('tight_layout failed. try running with an Agg backend.') 115 | plt.savefig('figure_overhead_fixed%s.pdf'%(v_fixed)) 116 | 117 | 118 | def make_plots(M_arr, N_arr, T_arr): 119 | for M in np.unique(M_arr): 120 | plot_shared(M, M_arr, N_arr, T_arr, 'M', 'N') 121 | for N in np.unique(N_arr): 122 | plot_shared(N, N_arr, M_arr, T_arr, 'N', 'M') 123 | 124 | 125 | if __name__ == '__main__': 126 | M_arr, N_arr, T_arr = explore_MN() 127 | make_plots(M_arr, N_arr, T_arr) 128 | 129 | """ 130 | import figure_overhead 131 | reload(figure_overhead) 132 | 133 | M_arr, N_arr, T_arr = figure_overhead.explore_MN() 134 | figure_overhead.make_plots(M_arr, N_arr, T_arr) 135 | 136 | 137 | 138 | """ 139 | -------------------------------------------------------------------------------- /generate_figures/figures_cae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | """ 4 | cae.py 5 | 6 | u pythonic library for Contractive Auto-Encoders. This is 7 | for people who want to give CAEs a quick try and for people 8 | who want to understand how they are implemented. For this 9 | purpose we tried to make the code as simple and clean as possible. 10 | The only dependency is numpy, which is used to perform all 11 | expensive operations. The code is quite fast, however much better 12 | performance can be achieved using the Theano version of this code. 13 | 14 | Created by Yann N. Dauphin, Salah Rifai on 2012-01-17. 15 | Copyright (c) 2012 Yann N. Dauphin, Salah Rifai. All rights reserved. 16 | """ 17 | 18 | import numpy as np 19 | 20 | 21 | class CAE(object): 22 | """ 23 | A Contractive Auto-Encoder (CAE) with sigmoid input units and sigmoid 24 | hidden units. 25 | """ 26 | def __init__(self, 27 | n_hiddens=1024, 28 | jacobi_penalty=0.1, 29 | W=None, 30 | hidbias=None, 31 | visbias=None): 32 | """ 33 | Initialize a CAE. 34 | 35 | Parameters 36 | ---------- 37 | n_hiddens : int, optional 38 | Number of binary hidden units 39 | jacobi_penalty : float, optional 40 | Scalar by which to multiply the gradients coming from the jacobian 41 | penalty. 42 | W : array-like, shape (n_inputs, n_hiddens), optional 43 | Weight matrix, where n_inputs in the number of input 44 | units and n_hiddens is the number of hidden units. 45 | hidbias : array-like, shape (n_hiddens,), optional 46 | Biases of the hidden units 47 | visbias : array-like, shape (n_inputs,), optional 48 | Biases of the input units 49 | """ 50 | self.n_hiddens = n_hiddens 51 | self.jacobi_penalty = jacobi_penalty 52 | self.W = W 53 | self.hidbias = hidbias 54 | self.visbias = visbias 55 | 56 | 57 | def _sigmoid(self, x): 58 | """ 59 | Implements the logistic function. 60 | 61 | Parameters 62 | ---------- 63 | x: array-like, shape (M, N) 64 | 65 | Returns 66 | ------- 67 | x_new: array-like, shape (M, N) 68 | """ 69 | return 1. / (1. + np.exp(-np.maximum(np.minimum(x, 100), -100))) 70 | 71 | def encode(self, x): 72 | """ 73 | Computes the hidden code for the input {\bf x}. 74 | 75 | Parameters 76 | ---------- 77 | x: array-like, shape (n_examples, n_inputs) 78 | 79 | Returns 80 | ------- 81 | h: array-like, shape (n_examples, n_hiddens) 82 | """ 83 | return self._sigmoid(np.dot(x, self.W) + self.hidbias) 84 | 85 | def decode(self, h): 86 | """ 87 | Compute the reconstruction from the hidden code {\bf h}. 88 | 89 | Parameters 90 | ---------- 91 | h: array-like, shape (n_examples, n_hiddens) 92 | 93 | Returns 94 | ------- 95 | x: array-like, shape (n_examples, n_inputs) 96 | """ 97 | return self._sigmoid(np.dot(h, self.W.T) + self.visbias) 98 | 99 | def reconstruct(self, x): 100 | """ 101 | Compute the reconstruction of the input {\bf x}. 102 | 103 | Parameters 104 | ---------- 105 | x: array-like, shape (n_examples, n_inputs) 106 | 107 | Returns 108 | ------- 109 | x_new: array-like, shape (n_examples, n_inputs) 110 | """ 111 | return self.decode(self.encode(x)) 112 | 113 | def loss(self, x, h=None, r=None): 114 | """ 115 | Computes the error of the model with respect 116 | to the total cost. 117 | 118 | Parameters 119 | ---------- 120 | x: array-like, shape (n_examples, n_inputs) 121 | h: array-like, shape (n_examples, n_hiddens), optional 122 | r: array-like, shape (n_examples, n_inputs), optional 123 | 124 | Returns 125 | ------- 126 | loss: array-like, shape (n_examples,) 127 | """ 128 | if h == None: 129 | h = self.encode(x) 130 | if r == None: 131 | r = self.decode(h) 132 | 133 | def _reconstruction_loss(h, r): 134 | """ 135 | Computes the error of the model with respect 136 | to the reconstruction (L2) cost. 137 | """ 138 | return 1/2. * ((r-x)**2).sum()/x.shape[0] 139 | 140 | def _jacobi_loss(h): 141 | """ 142 | Computes the error of the model with respect 143 | the Frobenius norm of the jacobian. 144 | """ 145 | return ((h *(1-h))**2 * (self.W**2).sum(0)).sum()/x.shape[0] 146 | recon_loss = _reconstruction_loss(h, r) 147 | jacobi_loss = _jacobi_loss(h) 148 | return (recon_loss + self.jacobi_penalty * jacobi_loss) 149 | 150 | def get_params(self): 151 | return dict(W=self.W, hidbias=self.hidbias, visbias=self.visbias) 152 | 153 | def set_params(self, theta): 154 | self.W = theta['W'] 155 | self.hidbias = theta['hidbias'] 156 | self.visbias = theta['visbias'] 157 | 158 | def f_df(self, theta, x): 159 | """ 160 | Compute objective and gradient of the CAE objective using the 161 | examples {\bf x}. 162 | 163 | Parameters 164 | ---------- 165 | x: array-like, shape (n_examples, n_inputs) 166 | 167 | Parameters 168 | ---------- 169 | loss: array-like, shape (n_examples,) 170 | Value of the loss function for each example before the step. 171 | """ 172 | self.set_params(theta) 173 | h = self.encode(x) 174 | r = self.decode(h) 175 | def _contraction_jacobian(): 176 | """ 177 | Compute the gradient of the contraction cost w.r.t parameters. 178 | """ 179 | a = 2*(h * (1 - h))**2 180 | d = ((1 - 2 * h) * a * (self.W**2).sum(0)[None, :]) 181 | b = np.dot(x.T / x.shape[0], d) 182 | c = a.mean(0) * self.W 183 | return (b + c), d.mean(0) 184 | 185 | def _reconstruction_jacobian(): 186 | """ 187 | Compute the gradient of the reconstruction cost w.r.t parameters. 188 | """ 189 | dr = (r - x) / x.shape[0] 190 | dr *= r * (1-r) 191 | dd = np.dot(dr.T, h) 192 | dh = np.dot(dr, self.W) * h * (1. - h) 193 | de = np.dot(x.T, dh) 194 | return (dd + de), dr.sum(0), dh.sum(0) 195 | 196 | W_rec, c_rec, b_rec = _reconstruction_jacobian() 197 | W_con, b_con = _contraction_jacobian() 198 | dW = W_rec + self.jacobi_penalty * W_con 199 | dhidbias = b_rec + self.jacobi_penalty * b_con 200 | dvisbias = c_rec 201 | return self.loss(x, h, r), dict(W=dW, hidbias=dhidbias, visbias=dvisbias); 202 | 203 | def init_weights(self, n_input, dtype=np.float32): 204 | self.W = np.asarray(np.random.uniform( 205 | #low=-4*np.sqrt(6./(n_input+self.n_hiddens)), 206 | #high=4*np.sqrt(6./(n_input+self.n_hiddens)), 207 | low=-1./np.sqrt(self.n_hiddens), 208 | high=1./np.sqrt(self.n_hiddens), 209 | size=(n_input, self.n_hiddens)), dtype=dtype) 210 | self.hidbias = np.zeros(self.n_hiddens, dtype=dtype) 211 | self.visbias = np.zeros(n_input, dtype=dtype) 212 | return self.get_params() 213 | -------------------------------------------------------------------------------- /generate_figures/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model classes for each of the demo cases. Each class contains 3 | an objective function f_df, initial parameters theta_init, a 4 | reference for each subfunction subfunction_references, and a 5 | set of full_objective_references that are evaluated every update step 6 | to make the plots of objective function vs. learning iteration. 7 | 8 | This is designed to be called by figure_convergence.py. 9 | 10 | 11 | Copyright 2014 Jascha Sohl-Dickstein 12 | Licensed under the Apache License, Version 2.0 (the "License"); 13 | you may not use this file except in compliance with the License. 14 | You may obtain a copy of the License at 15 | http://www.apache.org/licenses/LICENSE-2.0 16 | Unless required by applicable law or agreed to in writing, software 17 | distributed under the License is distributed on an "AS IS" BASIS, 18 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | See the License for the specific language governing permissions and 20 | limitations under the License. 21 | """ 22 | 23 | import numpy as np 24 | import scipy.special 25 | import warnings 26 | import random 27 | from figures_cae import CAE 28 | from os.path import join 29 | 30 | try: 31 | import pyGLM.simGLM as glm 32 | import pyGLM.gabor as gabor 33 | except: 34 | pass 35 | 36 | # numpy < 1.7 does not have np.random.choice 37 | def my_random_choice(n, k, replace): 38 | perm = np.random.permutation(n) 39 | return perm[:k] 40 | if hasattr(np.random, 'choice'): 41 | random_choice = np.random.choice 42 | else: 43 | random_choice = my_random_choice 44 | 45 | 46 | class toy: 47 | """ 48 | Toy problem. Sum of squared errors from random means, raised 49 | to random powers. 50 | """ 51 | 52 | def __init__(self, num_subfunctions=100, num_dims=10): 53 | self.name = '||x-u||^a, a in U[1.5,4.5]' 54 | 55 | # create the array of subfunction identifiers 56 | self.subfunction_references = [] 57 | N = num_subfunctions 58 | for i in range(N): 59 | npow = np.random.rand()*3. + 1.5 60 | mn = np.random.randn(num_dims,1) 61 | self.subfunction_references.append([npow,mn]) 62 | self.full_objective_references = self.subfunction_references 63 | 64 | ## initialize parameters 65 | self.theta_init = np.random.randn(num_dims,1) 66 | 67 | def f_df(self, x, args): 68 | npow = args[0]/2. 69 | mn = args[1] 70 | f = np.sum(((x-mn)**2)**npow) 71 | df = npow*((x-mn)**2)**(npow-1.)*2*(x-mn) 72 | scl = 1. / np.prod(x.shape) 73 | return f*scl, df*scl 74 | 75 | 76 | class Hopfield: 77 | def __init__(self, num_subfunctions=100, reg=1., scale_by_N=True): 78 | """ 79 | Train a Hopfield network/Ising model using MPF. 80 | 81 | Adapted from code by Chris Hillar, Kilian Koepsell, Jascha Sohl-Dickstein, 2011 82 | 83 | TODO insert Hopfield and MPF references. 84 | """ 85 | self.name = 'Hopfield' 86 | self.reg = reg/num_subfunctions 87 | 88 | # Load data 89 | X, _ = load_mnist() 90 | 91 | # binarize data 92 | X = (np.sign(X-0.5)+1)/2 93 | 94 | # only keep units which actually change state 95 | gd = ((np.sum(X,axis=1) > 0) & (np.sum(1-X,axis=1) > 0)) 96 | X = X[gd,:] 97 | # TODO -- discard units with correlation of exactly 1? 98 | 99 | # break the data up into minibatches 100 | self.subfunction_references = [] 101 | for mb in range(num_subfunctions): 102 | self.subfunction_references.append(X[:, mb::num_subfunctions].T) 103 | #self.full_objective_references = (X[:, random_choice(X.shape[1], 10000, replace=False)].copy().T,) 104 | self.full_objective_references = self.subfunction_references 105 | 106 | if scale_by_N: 107 | self.scl = float(num_subfunctions) / float(X.shape[1]) 108 | else: 109 | self.scl = 100. / float(X.shape[1]) 110 | 111 | # parameter initialization 112 | self.theta_init = np.random.randn(X.shape[0], X.shape[0])/np.sqrt(X.shape[0])/10. 113 | self.theta_init = (self.theta_init + self.theta_init.T)/2. 114 | 115 | 116 | def f_df(self, J, X): 117 | J = (J + J.T)/2. 118 | X = np.atleast_2d(X) 119 | S = 2 * X - 1 120 | Kfull = np.exp(-S * np.dot(X, J.T) + .5 * np.diag(J)[None, :]) 121 | dJ = -np.dot(X.T, Kfull * S) + .5 * np.diag(Kfull.sum(0)) 122 | dJ = (dJ + dJ.T)/2. 123 | dJ = np.nan_to_num(dJ) 124 | 125 | K = Kfull.sum() 126 | if not np.isfinite(K): 127 | K = 1e50 128 | 129 | K *= self.scl 130 | dJ *= self.scl 131 | 132 | K += self.reg * np.sum(J**2) 133 | dJ += 2. * self.reg * J 134 | return K, dJ 135 | 136 | 137 | class logistic: 138 | """ 139 | logistic regression on "protein" dataset 140 | """ 141 | 142 | def __init__(self, num_subfunctions=100, scale_by_N=True): 143 | self.name = 'protein logistic regression' 144 | 145 | try: 146 | data = np.loadtxt('figure_data/bio_train.dat') 147 | except: 148 | raise Exception("Missing data. Download from and place in figure_data subdirectory.") 149 | 150 | target = data[:,[2]] 151 | feat = data[:,3:] 152 | feat = self.whiten(feat) 153 | feat = np.hstack((feat, np.ones((data.shape[0],1)))) 154 | 155 | # create the array of subfunction identifiers 156 | self.subfunction_references = [] 157 | N = num_subfunctions 158 | nper = float(feat.shape[0])/float(N) 159 | if scale_by_N: 160 | lam = 1./nper**2 161 | scl = 1./nper 162 | else: 163 | default_N = 100. 164 | nnp = float(feat.shape[0])/default_N 165 | lam = (1./nnp**2) * (default_N / float(N)) 166 | scl = 1./nnp 167 | for i in range(N): 168 | i_s = int(np.floor(i*nper)) 169 | i_f = int(np.floor((i+1)*nper)) 170 | if i == N-1: 171 | # don't drop any at the end 172 | i_f = target.shape[0] 173 | l_targ = target[i_s:i_f,:] 174 | l_feat = feat[i_s:i_f,:] 175 | self.subfunction_references.append([l_targ, l_feat, lam, scl, i]) 176 | self.full_objective_references = self.subfunction_references 177 | 178 | self.theta_init = np.random.randn(feat.shape[1],1)/np.sqrt(feat.shape[1])/10. # parameters initialization 179 | 180 | # remove first order dependencies from X, and scale to unit norm 181 | def whiten(self, X,fudge=1E-10): 182 | max_whiten_lines = 10000 183 | 184 | # zero mean 185 | X -= np.mean(X, axis=0).reshape((1,-1)) 186 | 187 | # the matrix X should be observations-by-components 188 | # get the covariance matrix 189 | Xsu = X[:max_whiten_lines,:] 190 | Xcov = np.dot(Xsu.T,Xsu)/Xsu.shape[0] 191 | # eigenvalue decomposition of the covariance matrix 192 | try: 193 | d,V = np.linalg.eigh(Xcov) 194 | except: 195 | print("could not get eigenvectors and eigenvalues using numpy.linalg.eigh of ", Xcov.shape, Xcov) 196 | d,V = np.linalg.eig(Xcov) 197 | d = np.nan_to_num(d+fudge) 198 | d[d==0] = 1 199 | V = np.nan_to_num(V) 200 | 201 | # a fudge factor can be used so that eigenvectors associated with 202 | # small eigenvalues do not get overamplified. 203 | # TODO(jascha) D could be a vector not a matrix 204 | D = np.diag(1./np.sqrt(d)) 205 | 206 | D = np.nan_to_num(D) 207 | 208 | # whitening matrix 209 | W = np.dot(np.dot(V,D),V.T) 210 | # multiply by the whitening matrix 211 | Y = np.dot(X,W) 212 | return Y 213 | 214 | def sigmoid(self, u): 215 | return 1./(1.+np.exp(-u)) 216 | def f_df(self, x, args): 217 | target = args[0] 218 | feat = args[1] 219 | lam = args[2] 220 | scl = args[3] 221 | 222 | feat = feat*(2*target - 1) 223 | ip = -np.dot(feat, x.reshape((-1,1))) 224 | et = np.exp(ip) 225 | logL = np.log(1. + et) 226 | etrat = et/(1.+et) 227 | bd = np.nonzero(ip[:,0]>50)[0] 228 | logL[bd,:] = ip[bd,:] 229 | etrat[bd,:] = 1. 230 | logL = np.sum(logL) 231 | dlogL = -np.dot(feat.T, etrat) 232 | 233 | logL *= scl 234 | dlogL *= scl 235 | 236 | reg = lam*np.sum(x**2) 237 | dreg = 2.*lam*x 238 | 239 | return logL + reg, dlogL+dreg 240 | 241 | 242 | class ICA: 243 | """ 244 | ICA with Student's t-experts on MNIST images. 245 | """ 246 | 247 | def __init__(self, num_subfunctions=100): 248 | self.name = 'ICA' 249 | 250 | # Load data 251 | #X = load_cifar10_imagesonly() 252 | X, _ = load_mnist() 253 | 254 | # do PCA to eliminate degenerate dimensions and whiten 255 | C = np.dot(X, X.T) / X.shape[1] 256 | w, V = np.linalg.eigh(C) 257 | # take only the non-negligible eigenvalues 258 | mw = np.max(np.real(w)) 259 | max_ratio = 1e4 260 | gd = np.nonzero(np.real(w) > mw/max_ratio)[0] 261 | # # whiten 262 | # P = V[gd,:]*(np.real(w[gd])**(-0.5)).reshape((-1,1)) 263 | # don't whiten -- make the problem harder 264 | P = V[gd,:] 265 | X = np.dot(P, X) 266 | X /= np.std(X) 267 | 268 | # break the data up into minibatches 269 | self.subfunction_references = [] 270 | for mb in range(num_subfunctions): 271 | self.subfunction_references.append(X[:, mb::num_subfunctions]) 272 | # compute the full objective on all the data 273 | self.full_objective_references = self.subfunction_references 274 | # # the subset of the data used to compute the full objective function value 275 | # idx = random_choice(X.shape[1], 10000, replace=False) 276 | # self.full_objective_references = (X[:,idx].copy(),) 277 | 278 | ## initialize parameters 279 | num_dims = X.shape[0] 280 | self.theta_init = {'W':np.random.randn(num_dims, num_dims)/np.sqrt(num_dims), 281 | 'logalpha':np.random.randn(num_dims,1).ravel()} 282 | 283 | # rescale the objective and gradient so that the same hyperparameter ranges work for 284 | # ICA as for the other objectives 285 | self.scale = 1. / self.subfunction_references[0].shape[1] / 100. 286 | 287 | 288 | def f_df(self, params, X): 289 | """ 290 | ICA objective function and gradient, using a Student's t-distribution prior. 291 | The energy function has form: 292 | E = \sum_i \alpha_i \log( 1 + (\sum_j W_{ij} x_j )^2 ) 293 | 294 | params is a dictionary containing the filters W and the log of the 295 | X is the training data, with each column corresponding to a sample. 296 | 297 | L is the average negative log likelihood. 298 | dL is its gradient. 299 | """ 300 | W = params['W'] 301 | logalpha = params['logalpha'].reshape((-1,1)) 302 | alpha = np.exp(logalpha)+0.5 303 | 304 | ## calculate the energy 305 | ff = np.dot(W, X) 306 | ff2 = ff**2 307 | off2 = 1 + ff2 308 | lff2 = np.log( off2 ) 309 | alff2 = lff2 * alpha.reshape((-1,1)) 310 | E = np.sum(alff2) 311 | 312 | ## calculate the energy gradient 313 | # could just sum instead of using the rscl 1s vector. this is functionality 314 | # left over from MPF MATLAB code. May want it again in a future project though. 315 | rscl = np.ones((X.shape[1],1)) 316 | lt = (ff/off2) * alpha.reshape((-1,1)) 317 | dEdW = 2 * np.dot(lt * rscl.T, X.T) 318 | dEdalpha = np.dot(lff2, rscl) 319 | dEdlogalpha = (alpha-0.5) * dEdalpha 320 | 321 | ## calculate log Z 322 | nu = alpha * 2. - 1. 323 | #logZ = -np.log(scipy.special.gamma((nu + 1.) / 2.)) + 0.5 * np.log(np.pi) + \ 324 | # np.log(scipy.special.gamma(nu/2.)) 325 | logZ = -scipy.special.gammaln((nu + 1.) / 2.) + 0.5 * np.log(np.pi) + \ 326 | scipy.special.gammaln((nu/2.)) 327 | logZ = np.sum(logZ) 328 | ## DEBUG slogdet has memory leak! 329 | ## eg, call "a = np.linalg.slogdet(random.randn(5000,5000))" 330 | ## repeatedly, and watch memory usage. So, we do this with an 331 | ## explicit eigendecomposition instead 332 | ## logZ += -np.linalg.slogdet(W)[1] 333 | W2 = np.dot(W.T, W) 334 | W2eig, _ = np.linalg.eig(W2) 335 | logZ += -np.sum(np.log(W2eig))/2. 336 | 337 | ## calculate gradient of log Z 338 | # log determinant contribution 339 | dlogZdW = -np.linalg.inv(W).T 340 | if np.min(nu) < 0: 341 | dlogZdnu = np.zeros(nu.shape) 342 | warnings.warn('not a normalizable distribution!') 343 | E = np.inf 344 | else: 345 | dlogZdnu = -scipy.special.psi((nu + 1) / 2 )/2 + \ 346 | scipy.special.psi( nu/2 )/2 347 | dlogZdalpha = 2. * dlogZdnu 348 | dlogZdlogalpha = (alpha-0.5) * dlogZdalpha 349 | 350 | ## full objective and gradient 351 | L = (E + logZ) * self.scale 352 | dLdW = (dEdW + dlogZdW) * self.scale 353 | dLdlogalpha = (dEdlogalpha + dlogZdlogalpha) * self.scale 354 | 355 | ddL = {'W':dLdW, 'logalpha':dLdlogalpha.ravel()} 356 | 357 | if not np.isfinite(L): 358 | warnings.warn('numerical problems') 359 | L = np.inf 360 | 361 | return L, ddL 362 | 363 | 364 | class DeepAE: 365 | """ 366 | Deep Autoencoder from Hinton, G. E. and Salakhutdinov, R. R. (2006) 367 | """ 368 | def __init__(self, num_subfunctions=50, num_dims=10, objective='l2'): 369 | # don't introduce a Theano dependency until we have to 370 | from utils import _tonp 371 | 372 | self.name = 'DeepAE' 373 | layer_sizes = [ 28*28, 1000, 500, 250, 30] 374 | #layer_sizes = [ 28*28, 20] 375 | # Load data 376 | X, y = load_mnist() 377 | # break the data up into minibatches 378 | self.subfunction_references = [] 379 | for mb in range(num_subfunctions): 380 | self.subfunction_references.append([X[:, mb::num_subfunctions], y[mb::num_subfunctions]]) 381 | # evaluate on subset of training data 382 | self.n_full = 10000 383 | idx = random_choice(X.shape[1], self.n_full, replace=False) 384 | ##use all the training data for a smoother plot 385 | #idx = np.array(range(X.shape[1])) 386 | self.full_objective_references = [[X[:,idx].copy(), y[idx].copy()]] 387 | from dropout.deepae import build_f_df # here so theano not required for import 388 | self.theano_f_df, self.model = build_f_df(layer_sizes, use_bias=True, objective=objective) 389 | crossent_params = False 390 | if crossent_params: 391 | history = np.load('/home/poole/Sum-of-Functions-Optimizer/sfo_output.npz') 392 | out = dict(history=history['arr_0']) 393 | params = out['history'].item()['x']['SFO'] 394 | self.theta_init = params 395 | else: 396 | self.theta_init = [param.get_value() for param in self.model.params] 397 | 398 | def f_df(self, theta, args, gpu_batch_size=128): 399 | X = args[0].T 400 | y = args[1] 401 | rem = np.mod(len(X), gpu_batch_size) 402 | n_batches = (len(X) - rem) / gpu_batch_size 403 | splits = np.split(np.arange(len(X) - rem), n_batches) 404 | if rem > 0: 405 | splits.append(np.arange(len(X)-rem, len(X))) 406 | sum_results = None 407 | for split in splits: 408 | theano_args = theta + [X[split]] 409 | # Convert to float32 so that this works on GPU 410 | theano_args = [arg.astype(np.float32) for arg in theano_args] 411 | results = self.theano_f_df(*theano_args) 412 | results = [_tonp(result) for result in results] 413 | if sum_results is None: 414 | sum_results = results 415 | else: 416 | sum_results = [cur_res + new_res for cur_res, new_res in zip(results, sum_results)] 417 | # Divide by number of datapoints. 418 | sum_results = [result/len(X) for result in sum_results] 419 | return sum_results[0], sum_results[1:] 420 | 421 | class MLP: 422 | """ 423 | Multi-layer-perceptron 424 | """ 425 | def __init__(self, num_subfunctions=100, num_dims=10, rectifier='soft'): 426 | self.name = 'MLP' 427 | #layer_sizes = [ 28*28, 1200, 10 ] 428 | #layer_sizes = [ 28*28, 1200, 10 ] 429 | #layer_sizes = [ 28*28, 500, 120, num_dims ] 430 | #layer_sizes = [ 28*28, 120, 12, num_dims ] 431 | layer_sizes = [ 28*28, 1200, 1200, num_dims ] 432 | # Load data 433 | X, y = load_mnist() 434 | # break the data up into minibatches 435 | self.subfunction_references = [] 436 | for mb in range(num_subfunctions): 437 | self.subfunction_references.append([X[:, mb::num_subfunctions], y[mb::num_subfunctions]]) 438 | # evaluate on subset of training data 439 | idx = random_choice(X.shape[1], 5000, replace=False) 440 | #use all the training data for a smoother plot 441 | #idx = np.array(range(X.shape[1])) 442 | self.full_objective_references = [[X[:,idx].copy(), y[idx].copy()]] 443 | from dropout.mlp import build_f_df # here so theano not required for import 444 | self.theano_f_df, self.model = build_f_df(layer_sizes, rectifier=rectifier, 445 | use_bias=True) 446 | self.theta_init = [param.get_value() for param in self.model.params] 447 | def f_df(self, theta, args): 448 | X = args[0].T 449 | y = args[1] 450 | theano_args = theta + [X, y] 451 | results = self.theano_f_df(*theano_args) 452 | return results[0], results[1:] 453 | 454 | class MLP_hard(MLP): 455 | """ 456 | Multi-layer-perceptron with rectified-linear nonlinearity 457 | """ 458 | def __init__(self, num_subfunctions=100, num_dims=10): 459 | MLP.__init__(self, num_subfunctions=num_subfunctions, num_dims=num_dims, rectifier='hard') 460 | self.name += ' hard' 461 | class MLP_soft(MLP): 462 | """ 463 | Multi-layer-perceptron with sigmoid nonlinearity 464 | """ 465 | def __init__(self, num_subfunctions=100, num_dims=10): 466 | MLP.__init__(self, num_subfunctions=num_subfunctions, num_dims=num_dims, rectifier='soft') 467 | self.name += ' soft' 468 | 469 | class ContractiveAutoencoder: 470 | """ 471 | Contractive autoencoder on MNIST dataset. 472 | """ 473 | def __init__(self, num_subfunctions=100, num_dims=10): 474 | self.name = 'Contractive Autoencoder' 475 | 476 | # Load data 477 | X, y = load_mnist() 478 | # break the data up into minibatches 479 | self.subfunction_references = [] 480 | for mb in range(num_subfunctions): 481 | self.subfunction_references.append(X[:, mb::num_subfunctions]) 482 | #self.full_objective_references = (X[:,np.random.choice(X.shape[1], 1000, replace=False)].copy(),) 483 | self.full_objective_references = (X[:, random_choice(X.shape[1], 1000, replace=False)].copy(),) 484 | #self.full_objective_references = (X.copy(),) 485 | # Initialize CAE model 486 | self.cae = CAE(n_hiddens=256, jacobi_penalty=1.0) 487 | # Initialize parameters 488 | self.theta_init = self.cae.init_weights(X.shape[0], dtype=np.float64) 489 | 490 | def f_df(self, theta, X): 491 | return self.cae.f_df(theta, X.T) 492 | 493 | class PylearnModel: 494 | def __init__(self, filename, load_fnc, num_subfunctions=100, num_dims=10): 495 | # Import here so we don't depend on pylearn elsewhere 496 | from nnet.model_gradient import load_model 497 | self.name = 'Pylearn' 498 | X, y = load_fnc() 499 | batch_size = X.shape[0] / num_subfunctions 500 | mg = load_model(filename, batch_size=batch_size) 501 | self.mg = mg 502 | self.model = mg.model 503 | self._f_df = mg.f_df 504 | 505 | 506 | # break the data up into minibatches 507 | self.subfunction_references = [] 508 | for mb in range(num_subfunctions): 509 | self.subfunction_references.append([X[ mb::num_subfunctions,...], y[mb::num_subfunctions,...]]) 510 | # evaluate on subset of training data 511 | idx = random_choice(X.shape[0], 1000, replace=False) 512 | #use all the training data for a smoother plot 513 | #idx = np.array(range(X.shape[1])) 514 | self.full_objective_references = [[X[idx,...].copy(), y[idx,...].copy()]] 515 | self.theta_init = [param.get_value() for param in self.mg.params] 516 | 517 | def f_df(self, thetas, args): 518 | thetas32 = [theta.astype(np.float32) for theta in thetas] 519 | return self._f_df(thetas32, args) 520 | 521 | class MNISTConvNet(PylearnModel): 522 | def __init__(self, num_subfunctions=100, num_dims=10): 523 | fn = 'nnet/mnist.yaml' 524 | #super(ConvNet, self).__init__(fn, num_subfunctions, num_dims) 525 | PylearnModel.__init__(self,fn, load_mnist, num_subfunctions, num_dims) 526 | self.name += '_conv' 527 | 528 | class CIFARConvNet(PylearnModel): 529 | def __init__(self, num_subfunctions=100, num_dims=10): 530 | fn = 'nnet/conv.yaml' 531 | #super(ConvNet, self).__init__(fn, num_subfunctions, num_dims) 532 | PylearnModel.__init__(self,fn, load_cifar, num_subfunctions, num_dims) 533 | self.name += '_conv' 534 | 535 | 536 | class GLM: 537 | """ 538 | Train a GLM on simulated data. 539 | """ 540 | 541 | def __init__(self, num_subfunctions_ratio=0.05, baseDir='/home/nirum/data/retina/glm-feb-19/'): 542 | #def __init__(self, num_subfunctions_ratio=0.05, baseDir='/home/nirum/data/retina/glm-feb-19/small'): 543 | self.name = 'GLM' 544 | 545 | print('Initializing parameters...') 546 | 547 | ## FOR CUSTOM SIZES 548 | #self.params = glm.setParameters(m=5000, dh=10, ds=50) # small 549 | #self.params = glm.setParameters(m=5000, dh=10, ds=49) # small 550 | #self.params = glm.setParameters(m=20000, dh=100, ds=500) # large 551 | #self.params = glm.setParameters(m=20000, dh=25, ds=256, n=5) # huge 552 | #self.params = glm.setParameters(m=1e5, dh=10, ds=256, n=50) # Jascha huge 553 | #self.params = glm.setParameters(m=1e5, n=100, ds=256, dh=10) # shared 554 | 555 | ## FROM EXTERNAL DATA FILES 556 | # load sizes of external data 557 | shapes = np.load(join(baseDir, 'metadata.npz')) 558 | 559 | # set up GLM parameters 560 | self.params = glm.setParameters(dh=40, ds=shapes['stimSlicedShape'][1], n=shapes['rateShape'][1]) 561 | #self.params = glm.setParameters(dh=40, ds=shapes['stimSlicedShape'][1], n=2) 562 | 563 | ## initialize parameters 564 | print('Generating model...') 565 | #self.theta_true = glm.generateModel(self.params) 566 | self.theta_init = glm.generateModel(self.params) 567 | for key in self.theta_init.keys(): 568 | self.theta_init[key] /= 1e3 569 | 570 | #print('Simulating model to generate data...') 571 | #self.data = glm.generateData(self.theta_true, self.params) 572 | 573 | # load external data files as memmapped arrays 574 | print('Loading external data...') 575 | self.data = glm.loadExternalData('stimulus_sliced.dat', 'rates.dat', shapes, baseDir=baseDir) 576 | 577 | # all the data 578 | batch_size = self.data['x'].shape[0] 579 | #batch_size = 10000 580 | 581 | print('ready to go!') 582 | 583 | #trueMinimum, trueGrad = self.f_df(self.theta_true, (0, batch_size)) 584 | #print('Minimum for true parameters %g'%(trueMinimum)) 585 | 586 | #print('Norm of the gradient at the minimum:') 587 | #for key in trueGrad.keys(): 588 | #print('grad[' + key + ']: %g'%(np.linalg.norm(trueGrad[key]))) 589 | 590 | # print out some network information 591 | #glm.visualizeNetwork(self.theta_true, self.params, self.data) 592 | 593 | # break the data up into minibatches 594 | 595 | self.N = int(np.ceil(np.sqrt(batch_size)*num_subfunctions_ratio)) 596 | self.subfunction_references = [] 597 | samples_per_subfunction = int(np.floor(batch_size/self.N)) 598 | for mb in range(self.N): 599 | print(mb, self.N) 600 | start_idx = mb*samples_per_subfunction 601 | end_idx = (mb+1)*samples_per_subfunction 602 | self.subfunction_references.append((start_idx, end_idx,)) 603 | 604 | self.full_objective_references = self.subfunction_references 605 | print('initialized ...') 606 | #self.full_objective_references = random.sample(self.subfunction_references, int(num_subfunctions/10)) 607 | 608 | 609 | def f_df(self, theta, idx_range): 610 | """ 611 | objective assuming Poisson noise 612 | 613 | function [fval grad] = objPoissonGLM(theta, datapath) 614 | Computes the Poisson log-likelihood objective and gradient 615 | for the generalized linear model (GLM) 616 | """ 617 | 618 | data_subf = dict() 619 | for key in self.data.keys(): 620 | data_subf[key] = np.array(self.data[key][idx_range[0]:idx_range[1]]) 621 | 622 | return glm.f_df(theta, data_subf, self.params) 623 | 624 | 625 | 626 | # def load_fingerprints(nsamples=None): 627 | # """ load fingerprint image 628 | # 1 <= m <= 10 (subjects), 1 <= n <= 8 (fingers) 629 | # """ 630 | # from PIL import Image 631 | 632 | # img_list = [] 633 | # for m = range(1,11): 634 | # for n = range(1,9): 635 | # fname = os.path.join('figure_data/DB2_B', '1%02d_%d.tif'%(m,n)) 636 | # arr = np.array(Image.open(fname)) # unit8 637 | # img = arr.astype(np.double) 638 | # img = img[100:-64,28:-28] 639 | # return img[::2,::2] 640 | 641 | 642 | 643 | 644 | 645 | 646 | def load_cifar10_imagesonly(nsamples=None): 647 | X = np.load('figure_data/cifar10_images.npy') 648 | if nsamples == None: 649 | nsamples = X.shape[0] 650 | perm = random_choice(X.shape[0], nsamples, replace=False) 651 | X = X[perm, :] 652 | return X.T 653 | 654 | def load_cifar(nsamples=None): 655 | try: 656 | from pylearn2.utils import serial 657 | except: 658 | raise Exception("pylearn2 must be installed.") 659 | try: 660 | dset = serial.load('figure_data/train.pkl') 661 | except: 662 | raise Exception("Missing data. Download CIFAR!") 663 | X = dset.X 664 | y = dset.y 665 | # Convert to one hot 666 | one_hot = np.zeros((y.shape[0], 10), dtype='float32') 667 | for i in range(y.shape[0]): 668 | one_hot[i, y[i]] = 1 669 | 670 | if nsamples == None: 671 | nsamples = X.shape[0] 672 | perm = random_choice(X.shape[0], nsamples, replace=False) 673 | X = X[perm, :] 674 | X = X.reshape(-1, 3,32,32) 675 | X = X.astype(np.float32) 676 | y = one_hot[perm, :] 677 | return X,y 678 | 679 | 680 | def load_mnist(nsamples=None): 681 | """ 682 | Load the MNIST dataset. 683 | """ 684 | 685 | try: 686 | data = np.load('figure_data/mnist_train.npz') 687 | except: 688 | raise Exception("Missing data. Download mnist_train.npz from and place in figure_data subdirectory.") 689 | X = data['X'] 690 | y = data['y'] 691 | if nsamples == None: 692 | nsamples = X.shape[0] 693 | perm = random_choice(X.shape[0], nsamples, replace=False) 694 | X = X[perm, :] 695 | X = X.T 696 | y = y[perm] 697 | return X, y 698 | -------------------------------------------------------------------------------- /generate_figures/nnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sohl-Dickstein/Sum-of-Functions-Optimizer/36c66dfef68b344eca0ca9bb56d0a7bed2075036/generate_figures/nnet/__init__.py -------------------------------------------------------------------------------- /generate_figures/nnet/conv.yaml: -------------------------------------------------------------------------------- 1 | { 2 | model: !obj:pylearn2.models.mlp.MLP { 3 | batch_size: $batch_size, 4 | layers: [ 5 | !obj:pylearn2.models.mlp.ConvRectifiedLinear { 6 | layer_name: 'h0', 7 | #pad: 4, 8 | #tied_b: 1, 9 | W_lr_scale: .05, 10 | b_lr_scale: .05, 11 | output_channels: 48, 12 | #num_pieces: 2, 13 | kernel_shape: [8, 8], 14 | pool_shape: [4, 4], 15 | pool_stride: [2, 2], 16 | irange: .05, 17 | #max_kernel_norm: .9, 18 | }, 19 | !obj:pylearn2.models.mlp.ConvRectifiedLinear { 20 | layer_name: 'h1', 21 | #pad: 3, 22 | #tied_b: 1, 23 | W_lr_scale: .05, 24 | b_lr_scale: .05, 25 | output_channels: 128, 26 | #num_pieces: 2, 27 | kernel_shape: [8, 8], 28 | pool_shape: [4, 4], 29 | pool_stride: [2, 2], 30 | irange: .05, 31 | # max_kernel_norm: 1.9365, 32 | }, 33 | #!obj:pylearn2.models.maxout.MaxoutConvC01B { 34 | # pad: 3, 35 | # layer_name: 'h2', 36 | # tied_b: 1, 37 | # W_lr_scale: .05, 38 | # b_lr_scale: .05, 39 | # num_channels: 128, 40 | # num_pieces: 2, 41 | # kernel_shape: [5, 5], 42 | # pool_shape: [2, 2], 43 | # pool_stride: [2, 2], 44 | # irange: .005, 45 | # max_kernel_norm: 1.9365, 46 | #}, 47 | !obj:pylearn2.models.mlp.RectifiedLinear { 48 | layer_name: 'h3', 49 | irange: .005, 50 | #num_units: 240, 51 | #num_pieces: 5, 52 | dim: 240, 53 | #max_col_norm: 1.9 54 | }, 55 | !obj:pylearn2.models.mlp.Softmax { 56 | #max_col_norm: 1.9365, 57 | layer_name: 'y', 58 | n_classes: 10, 59 | irange: .05 60 | } 61 | ], 62 | input_space: !obj:pylearn2.space.Conv2DSpace { 63 | shape: [32, 32], 64 | num_channels: 3, 65 | axes: ['c', 0, 1, 'b'], 66 | }, 67 | }, 68 | algorithm: !obj:pylearn2.training_algorithms.sgd.SGD { 69 | learning_rate: .1, 70 | learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum { 71 | init_momentum: 0.5, 72 | }, 73 | #cost: !obj:pylearn2.costs.mlp.dropout.Dropout { 74 | # input_include_probs: { 'h0' : .8 }, 75 | # input_scales: { 'h0': 1. } 76 | #}, 77 | termination_criterion: !obj:pylearn2.termination_criteria.MonitorBased { 78 | channel_name: "valid_y_misclass", 79 | prop_decrease: 0., 80 | N: 100 81 | }, 82 | }, 83 | save_freq: 1 84 | } 85 | -------------------------------------------------------------------------------- /generate_figures/nnet/mnist.yaml: -------------------------------------------------------------------------------- 1 | !obj:pylearn2.train.Train { 2 | dataset: &train !obj:pylearn2.datasets.mnist.MNIST { 3 | which_set: 'train', 4 | one_hot: 1, 5 | axes: ['c', 0, 1, 'b'], 6 | start: 0, 7 | stop: 50000 8 | }, 9 | model: !obj:pylearn2.models.mlp.MLP { 10 | batch_size: 128, 11 | layers: [ 12 | !obj:pylearn2.models.maxout.MaxoutConvC01B { 13 | layer_name: 'h0', 14 | pad: 0, 15 | num_channels: 48, 16 | num_pieces: 2, 17 | kernel_shape: [8, 8], 18 | pool_shape: [4, 4], 19 | pool_stride: [2, 2], 20 | irange: .005, 21 | max_kernel_norm: .9, 22 | }, 23 | !obj:pylearn2.models.maxout.MaxoutConvC01B { 24 | layer_name: 'h1', 25 | pad: 3, 26 | num_channels: 48, 27 | num_pieces: 2, 28 | kernel_shape: [8, 8], 29 | pool_shape: [4, 4], 30 | pool_stride: [2, 2], 31 | irange: .005, 32 | max_kernel_norm: 1.9365, 33 | }, 34 | !obj:pylearn2.models.maxout.MaxoutConvC01B { 35 | pad: 3, 36 | layer_name: 'h2', 37 | num_channels: 24, 38 | num_pieces: 4, 39 | kernel_shape: [5, 5], 40 | pool_shape: [2, 2], 41 | pool_stride: [2, 2], 42 | irange: .005, 43 | max_kernel_norm: 1.9365, 44 | }, 45 | !obj:pylearn2.models.mlp.Softmax { 46 | max_col_norm: 1.9365, 47 | layer_name: 'y', 48 | n_classes: 10, 49 | irange: .005 50 | } 51 | ], 52 | input_space: !obj:pylearn2.space.Conv2DSpace { 53 | shape: [28, 28], 54 | num_channels: 1, 55 | axes: ['c', 0, 1, 'b'], 56 | }, 57 | }, 58 | algorithm: !obj:pylearn2.training_algorithms.sgd.SGD { 59 | learning_rate: .05, 60 | learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum { 61 | init_momentum: .5, 62 | }, 63 | monitoring_dataset: 64 | { 65 | 'valid' : !obj:pylearn2.datasets.mnist.MNIST { 66 | axes: ['c', 0, 1, 'b'], 67 | which_set: 'train', 68 | one_hot: 1, 69 | start: 50000, 70 | stop: 60000 71 | }, 72 | }, 73 | cost: !obj:pylearn2.costs.mlp.dropout.Dropout { 74 | input_include_probs: { 'h0' : .8 }, 75 | input_scales: { 'h0': 1. } 76 | }, 77 | termination_criterion: !obj:pylearn2.termination_criteria.MonitorBased { 78 | channel_name: "valid_y_misclass", 79 | prop_decrease: 0., 80 | N: 100 81 | }, 82 | update_callbacks: !obj:pylearn2.training_algorithms.sgd.ExponentialDecay { 83 | decay_factor: 1.00004, 84 | min_lr: .000001 85 | } 86 | }, 87 | extensions: [ 88 | !obj:pylearn2.train_extensions.best_params.MonitorBasedSaveBest { 89 | channel_name: 'valid_y_misclass', 90 | save_path: "${PYLEARN2_TRAIN_DIR}mnist_best.pkl" 91 | }, 92 | !obj:pylearn2.training_algorithms.learning_rule.MomentumAdjustor { 93 | start: 1, 94 | saturate: 250, 95 | final_momentum: .7 96 | } 97 | ] 98 | } 99 | -------------------------------------------------------------------------------- /generate_figures/nnet/model_gradient.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | from collections import OrderedDict 5 | from itertools import izip 6 | from string import Template 7 | 8 | from pylearn2.utils import serial, safe_zip 9 | from pylearn2.utils.data_specs import DataSpecsMapping 10 | from pylearn2.config.yaml_parse import load_path, load 11 | from theano.sandbox.cuda import CudaNdarray 12 | 13 | def _tonp(x): 14 | if type(x) not in [CudaNdarray, np.array, np.ndarray]: 15 | x = x.eval() 16 | return np.array(x) 17 | 18 | def load_model(filename, batch_size=100): 19 | out=Template(open(filename, 'r').read()).substitute({'batch_size':batch_size}) 20 | stuff = load(out) 21 | model = stuff['model'] 22 | #model.batch_size = batch_size 23 | #model.set_batch_size(batch_size) 24 | cost = stuff['algorithm'].cost 25 | if cost is None: 26 | cost = model.get_default_cost() 27 | mg = ModelGradient(model, cost,batch_size=batch_size) 28 | return mg 29 | 30 | 31 | class ModelGradient: 32 | def __init__(self, model, cost=None, batch_size=100): 33 | self.model = model 34 | self.model.set_batch_size(batch_size) 35 | self.model._test_batch_size = batch_size 36 | print 'it should really be ', batch_size 37 | self.cost = cost 38 | self.batch_size = batch_size 39 | self.setup() 40 | 41 | def setup(self): 42 | self.X = T.matrix('X') 43 | self.Y = T.matrix('Y') 44 | 45 | # Taken from pylearn2/training_algorithms/sgd.py 46 | 47 | 48 | data_specs = self.cost.get_data_specs(self.model) 49 | mapping = DataSpecsMapping(data_specs) 50 | space_tuple = mapping.flatten(data_specs[0], return_tuple=True) 51 | source_tuple = mapping.flatten(data_specs[1], return_tuple=True) 52 | 53 | # Build a flat tuple of Theano Variables, one for each space. 54 | # We want that so that if the same space/source is specified 55 | # more than once in data_specs, only one Theano Variable 56 | # is generated for it, and the corresponding value is passed 57 | # only once to the compiled Theano function. 58 | theano_args = [] 59 | for space, source in safe_zip(space_tuple, source_tuple): 60 | name = '%s[%s]' % (self.__class__.__name__, source) 61 | arg = space.make_theano_batch(name=name, batch_size = self.batch_size) 62 | theano_args.append(arg) 63 | print 'BATCH SIZE=',self.batch_size 64 | theano_args = tuple(theano_args) 65 | 66 | # Methods of `self.cost` need args to be passed in a format compatible 67 | # with data_specs 68 | nested_args = mapping.nest(theano_args) 69 | print self.cost 70 | fixed_var_descr = self.cost.get_fixed_var_descr(self.model, nested_args) 71 | print self.cost 72 | self.on_load_batch = fixed_var_descr.on_load_batch 73 | params = list(self.model.get_params()) 74 | self.X = nested_args[0] 75 | self.Y = nested_args[1] 76 | init_grads, updates = self.cost.get_gradients(self.model, nested_args) 77 | 78 | params = self.model.get_params() 79 | # We need to replace parameters with purely symbolic variables in case some are shared 80 | # Create gradient and cost functions 81 | self.params = params 82 | symbolic_params = [self._convert_variable(param) for param in params] 83 | givens = dict(zip(params, symbolic_params)) 84 | costfn = self.model.cost_from_X((self.X, self.Y)) 85 | gradfns = [init_grads[param] for param in params] 86 | #self.symbolic_params = symbolic_params 87 | #self._loss = theano.function(symbolic_para[self.X, self.Y], self.model.cost_from_X((self.X, self.Y)))#, givens=givens) 88 | #1/0 89 | print 'Compiling function...' 90 | self.theano_f_df = theano.function(inputs=symbolic_params + [self.X, self.Y], outputs=[costfn] + gradfns, givens=givens) 91 | print 'done' 92 | # self.grads = theano.function(symbolic_params + [self.X, self.Y], [init_grads[param] for param in params], givens=givens) 93 | # self._loss = theano.function(symbolic_params + [self.X, self.Y], self.model.cost(self.X, self.Y), givens=givens) 94 | # Maps params -> their derivative 95 | 96 | def f_df(self, theta, args): 97 | X = args[0] 98 | y = args[1] 99 | X = np.transpose(X,(1,2,3,0)) 100 | y = y 101 | nsamples = X.shape[-1] 102 | nbatches = nsamples / self.batch_size 103 | # lets hope it's actually divisible 104 | #XXX(ben): fix this 105 | cost = 0.0 106 | thetas = None 107 | idxs = np.array_split(np.arange(nsamples), nbatches) 108 | for idx in idxs: 109 | theano_args = theta + [X[...,idx], y[idx,...]] 110 | results = self.theano_f_df(*theano_args) 111 | results = [_tonp(result) for result in results] 112 | if thetas is None: 113 | thetas = results[1:] 114 | else: 115 | thetas = [np.array(t) + np.array(result) for t,result in zip(thetas,results[1:])] 116 | cost += results[0]/nbatches 117 | 118 | thetas = [np.array(theta)/nbatches for theta in thetas] 119 | 120 | # if X.shape[1] != self.batch_size: 121 | 122 | 123 | # print X.shape, y.shape 124 | # theano_args = theta + [X,y] 125 | # results = self.theano_f_df(*theano_args) 126 | # return results[0], results[1:] 127 | return cost, thetas 128 | 129 | def _convert_variable(self, x): 130 | return T.TensorType(x.dtype, x.broadcastable)(x.name) #'int32', broadcastable=())('myvar') 131 | 132 | 133 | 134 | if __name__ == '__main__': 135 | # Load train obj 136 | f = np.load('test.npz') 137 | model = f['model'].item() 138 | cost = f['cost'].item() 139 | m = ModelGradient(model, cost) 140 | p = model.weights.shape.eval()[0] 141 | X = np.random.randn(100,p).astype(np.float32) 142 | theta = m.get_params() 143 | m.f_df(theta, X) 144 | 1/0 145 | 146 | # Load cost function 147 | W = np.random.randn(10,10) 148 | W = theano.shared(W) 149 | 150 | cost = (W**2).sum() 151 | grad = T.grad(cost, W) 152 | 153 | params = [W] 154 | 1/0 155 | -------------------------------------------------------------------------------- /generate_figures/optimization_wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trains the model using a variety of optimization algorithms. 3 | This class also wraps the objective and gradient of the model, 4 | so that it can store a history of the objective during 5 | optimization. 6 | 7 | This is slower than directly calling the optimizers, because 8 | it periodically evaluates (and stores) the FULL objective rather 9 | than always evaluating only a single subfunction per update step. 10 | 11 | Designed to be used by figure_comparison*.py. 12 | 13 | Copyright 2014 Jascha Sohl-Dickstein 14 | Licensed under the Apache License, Version 2.0 (the "License"); 15 | you may not use this file except in compliance with the License. 16 | You may obtain a copy of the License at 17 | http://www.apache.org/licenses/LICENSE-2.0 18 | Unless required by applicable law or agreed to in writing, software 19 | distributed under the License is distributed on an "AS IS" BASIS, 20 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 21 | See the License for the specific language governing permissions and 22 | limitations under the License. 23 | """ 24 | 25 | # allow SFO to be imported despite being in the parent directoy 26 | import sys 27 | sys.path.append("..") 28 | sys.path.append(".") 29 | 30 | from sfo import SFO 31 | from sag import SAG 32 | from adagrad import ADAGrad 33 | from collections import defaultdict 34 | from itertools import product 35 | import numpy as np 36 | from scipy.optimize import fmin_l_bfgs_b 37 | 38 | # numpy < 1.7 does not have np.random.choice 39 | def my_random_choice(n, k, replace): 40 | perm = np.random.permutation(n) 41 | return perm[:k] 42 | if hasattr(np.random, 'choice'): 43 | random_choice = np.random.choice 44 | else: 45 | random_choice = my_random_choice 46 | 47 | 48 | class train: 49 | """ 50 | Trains the model using a variety of optimization algorithms. 51 | This class also wraps the objective and gradient of the model, 52 | so that it can evaluate and store the full objective for each 53 | step in the optimization. 54 | 55 | This is WAY SLOWER than just calling the optimizers, because 56 | it evaluates the FULL objective and gradient instead of a single 57 | subfunction several times per pass. 58 | 59 | Designed to be used by figure_convergence.py. 60 | """ 61 | 62 | def __init__(self, model, calculate_full_objective=True, num_projection_dims=5, full_objective_per_pass=4): 63 | """ 64 | Trains the model using a variety of optimization algorithms. 65 | This class also wraps the objective and gradient of the model, 66 | so that it can evaluate and store the full objective for each 67 | step in the optimization. 68 | 69 | This is WAY SLOWER than just calling the optimizers, because 70 | it evaluates the FULL objective and gradient instead of a single 71 | subfunction several times per pass. 72 | 73 | Designed to be used by figure_convergence.py. 74 | """ 75 | 76 | self.model = model 77 | self.history = {'f':defaultdict(list), 'x_projection':defaultdict(list), 'events':defaultdict(list), 'x':defaultdict(list)} 78 | 79 | # we use SFO to flatten/unflatten parameters for the other optimizers 80 | self.x_map = SFO(self.model.f_df, self.model.theta_init, self.model.subfunction_references) 81 | self.xinit_flat = self.x_map.theta_original_to_flat(self.model.theta_init) 82 | self.calculate_full_objective = calculate_full_objective 83 | 84 | M = self.xinit_flat.shape[0] 85 | self.x_projection_matrix = np.random.randn(num_projection_dims, M)/np.sqrt(M) 86 | 87 | self.num_subfunctions = len(self.model.subfunction_references) 88 | self.full_objective_period = int(self.num_subfunctions/full_objective_per_pass) 89 | 90 | 91 | def f_df_wrapper(self, *args, **kwargs): 92 | """ 93 | This (slightly hacky) function stands between the optimizer and the objective function. 94 | It evaluates the objective on the full function every full_objective_function times a 95 | subfunction is evaluated, and stores the history of the full objective function value. 96 | """ 97 | 98 | ## call the true subfunction objective function, passing through all parameters 99 | f, df = self.model.f_df(*args, **kwargs) 100 | 101 | if len(self.history['f'][self.learner_name]) == 0: 102 | # this is the first time step for this learner 103 | self.last_f = np.inf 104 | self.last_idx = -1 105 | self.nsteps_this_learner = 0 106 | 107 | self.nsteps_this_learner += 1 108 | # only record the step every once every self.full_objective_period steps 109 | if np.mod(self.nsteps_this_learner, self.full_objective_period) != 1 and self.full_objective_period > 1: 110 | return f, df 111 | 112 | # the full objective function on all subfunctions 113 | if self.calculate_full_objective: 114 | new_f = 0. 115 | for ref in self.model.full_objective_references: 116 | new_f += self.model.f_df(args[0], ref)[0] 117 | else: 118 | new_f = f 119 | 120 | events = dict() # holds anything special about this step 121 | # a unique identifier for the current subfunction 122 | new_idx = id(args[1]) 123 | if 'SFO' in self.learner_name: 124 | events = dict(self.optimizer.events) 125 | # append the full objective value, projections, etc to the history 126 | self.history['f'][self.learner_name].append(new_f) 127 | x_proj = np.dot(self.x_projection_matrix, self.x_map.theta_original_to_flat(args[0])).ravel() 128 | self.history['x_projection'][self.learner_name].append(x_proj) 129 | self.history['events'][self.learner_name].append(events) 130 | self.history['x'][self.learner_name] = args[0] 131 | print("full f %g"%(new_f)) 132 | # store the prior values 133 | self.last_f = new_f 134 | self.last_idx = new_idx 135 | 136 | return f, df 137 | 138 | 139 | def f_df_wrapper_flattened(self, x_flat, subfunction_references, *args, **kwargs): 140 | """ 141 | Calculate the subfunction objective and gradient. 142 | Takes a 1d parameter vector, and returns a 1d gradient, even 143 | if the parameters for f_df are a list or a dictionary. 144 | x_flat should be the flattened version of the parameters. 145 | """ 146 | 147 | x = self.x_map.theta_flat_to_original(x_flat) 148 | f = 0. 149 | df = 0. 150 | for sr in subfunction_references: 151 | fl, dfl = self.f_df_wrapper(x, sr, *args, **kwargs) 152 | dfl_flat = self.x_map.theta_original_to_flat(dfl) 153 | f += fl 154 | df += dfl_flat 155 | return f, df.ravel() 156 | 157 | 158 | def SGD(self, num_passes=20): 159 | """ Train model using SGD with various learning rates """ 160 | 161 | # get the number of minibatches 162 | N = len(self.model.subfunction_references) 163 | # step through all the hyperparameters. eta is step length. 164 | for eta in 10**np.linspace(-5,2,8): 165 | # label this convergence trace using the optimizer name and hyperparameter 166 | self.learner_name = "SGD %.4f"%eta 167 | print("\n\n" + self.learner_name) 168 | 169 | # initialize the parameters 170 | x = self.xinit_flat.copy() 171 | ## perform stochastic gradient descent 172 | for _ in range(num_passes*N): # number of minibatch evaluations 173 | # choose a minibatch at random 174 | idx = np.random.randint(N) 175 | sr = self.model.subfunction_references[idx] 176 | # evaluate the objective and gradient for that minibatch 177 | fl, dfl = self.f_df_wrapper_flattened(x.reshape((-1,1)), (sr,)) 178 | # update the parameters 179 | x -= dfl.reshape(x.shape) * eta 180 | # if the objective has diverged, skip the rest of the run for this hyperparameter 181 | if not np.isfinite(fl): 182 | print("Non-finite subfunction.") 183 | break 184 | 185 | 186 | def LBFGS(self, num_passes=20): 187 | """ Train model using LBFGS """ 188 | 189 | self.learner_name = "LBFGS" 190 | print("\n\n" + self.learner_name) 191 | _, _, _ = fmin_l_bfgs_b( 192 | self.f_df_wrapper_flattened, 193 | self.xinit_flat.copy(), 194 | disp=1, 195 | args=(self.model.subfunction_references, ), 196 | maxfun=num_passes) 197 | 198 | 199 | def SFO(self, num_passes=20, learner_name='SFO', **kwargs): 200 | """ Train model using SFO.""" 201 | self.learner_name = learner_name 202 | print("\n\n" + self.learner_name) 203 | 204 | self.optimizer = SFO(self.f_df_wrapper, self.model.theta_init, self.model.subfunction_references, **kwargs) 205 | # # check the gradients 206 | # self.optimizer.check_grad() 207 | x = self.optimizer.optimize(num_passes=num_passes) 208 | 209 | 210 | def SFO_variations(self, num_passes=20): 211 | """ 212 | Train model using several variations on the standard SFO algorithm. 213 | """ 214 | 215 | np.random.seed(0) # make experiments repeatable 216 | self.learner_name = 'SFO standard' 217 | print("\n\n" + self.learner_name) 218 | self.optimizer = SFO(self.f_df_wrapper, self.model.theta_init, self.model.subfunction_references) 219 | x = self.optimizer.optimize(num_passes=num_passes) 220 | 221 | np.random.seed(0) # make experiments repeatable 222 | self.learner_name = 'SFO all active' 223 | print("\n\n" + self.learner_name) 224 | self.optimizer = SFO(self.f_df_wrapper, self.model.theta_init, self.model.subfunction_references, 225 | init_subf=len(self.model.subfunction_references)) 226 | x = self.optimizer.optimize(num_passes=num_passes) 227 | 228 | np.random.seed(0) # make experiments repeatable 229 | self.learner_name = 'SFO rank 1' 230 | print("\n\n" + self.learner_name) 231 | self.optimizer = SFO(self.f_df_wrapper, self.model.theta_init, self.model.subfunction_references, 232 | hessian_algorithm='rank1') 233 | x = self.optimizer.optimize(num_passes=num_passes) 234 | 235 | self.learner_name = 'SFO random' 236 | print("\n\n" + self.learner_name) 237 | self.optimizer = SFO(self.f_df_wrapper, self.model.theta_init, self.model.subfunction_references, 238 | subfunction_selection='random' 239 | ) 240 | x = self.optimizer.optimize(num_passes=num_passes) 241 | 242 | self.learner_name = 'SFO cyclic' 243 | print("\n\n" + self.learner_name) 244 | self.optimizer = SFO(self.f_df_wrapper, self.model.theta_init, self.model.subfunction_references, 245 | subfunction_selection='cyclic' 246 | ) 247 | x = self.optimizer.optimize(num_passes=num_passes) 248 | 249 | 250 | def SAG(self, num_passes=20): 251 | """ Train model using SAG with line search, for various initial Lipschitz """ 252 | 253 | # larger L is easier, so start large 254 | for L in 10**(-np.linspace(-3, 3, 7)): 255 | self.learner_name = "SAG %.4f"%L 256 | #learner_name = "SAG (diverges)" 257 | print("\n\n" + self.learner_name) 258 | self.optimizer = SAG(self.f_df_wrapper_flattened, self.xinit_flat.copy(), self.model.subfunction_references, L=L) 259 | x = self.optimizer.optimize(num_passes=num_passes) 260 | print(np.mean(self.optimizer.f), "average value at last evaluation") 261 | 262 | 263 | def LBFGS_minibatch(self, num_passes=20, data_fraction=0.1, num_steps=10): 264 | """ Perform LBFGS on minibatches of size data_fraction of the full datastep, and with num_steps LBFGS steps per minibatch.""" 265 | 266 | self.learner_name = "LBFGS minibatch" 267 | 268 | 269 | x = self.xinit_flat.copy() 270 | for epoch in range(num_passes): 271 | idx = random_choice(len(self.model.subfunction_references), 272 | int(data_fraction*len(self.model.subfunction_references)), 273 | replace=False) 274 | sr = [] 275 | for ii in idx: 276 | sr.append(self.model.subfunction_references[ii]) 277 | x, _, _ = fmin_l_bfgs_b( 278 | self.f_df_wrapper_flattened, 279 | x, 280 | args=(sr, ), 281 | disp=1, 282 | maxfun=num_steps) 283 | 284 | 285 | def SGD_momentum(self, num_passes=20): 286 | """ Train model using SGD with various learning rates and momentums""" 287 | 288 | learning_rates = 10**np.linspace(-5,2,8) 289 | momentums = np.array([0.5, 0.9, 0.95, 0.99]) 290 | params = product(learning_rates, momentums) 291 | N = len(self.model.subfunction_references) 292 | for eta, momentum in params: 293 | self.learner_name = "SGD_momentum eta=%.5f, mu=%.2f" % (eta, momentum) 294 | print("\n\n" + self.learner_name) 295 | f = np.ones((N))*np.nan 296 | x = self.xinit_flat.copy() 297 | # Prevous step 298 | inc = 0.0 299 | for epoch in range(num_passes): 300 | for minibatch in range(N): 301 | idx = np.random.randint(N) 302 | sr = self.model.subfunction_references[idx] 303 | fl, dfl = self.f_df_wrapper_flattened(x.reshape((-1,1)), (sr,)) 304 | inc = momentum * inc - eta * dfl.reshape(x.shape) 305 | x += inc 306 | f[idx] = fl 307 | if not np.isfinite(fl): 308 | print("Non-finite subfunction. Ending run.") 309 | break 310 | if not np.isfinite(fl): 311 | print("Non-finite subfunction. Ending run.") 312 | break 313 | print(np.mean(f[np.isfinite(f)]), "average finite value at last evaluation") 314 | 315 | 316 | def ADA(self, num_passes=20): 317 | """ Train model using ADAgrad with various learning rates """ 318 | 319 | for eta in 10**np.linspace(-3,1,5): 320 | self.learner_name = "ADAGrad %.4f"%eta 321 | print("\n\n" + self.learner_name) 322 | self.optimizer = ADAGrad(self.f_df_wrapper_flattened, self.xinit_flat.copy(), self.model.subfunction_references, learning_rate=eta) 323 | x = self.optimizer.optimize(num_passes=num_passes) 324 | print(np.mean(self.optimizer.f), "average value at last evaluation") 325 | 326 | -------------------------------------------------------------------------------- /generate_figures/sag.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2014 Jascha Sohl-Dickstein 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | 13 | This is an implementation of the Stochastic Average Gradient (SAG) algorithm: 14 | Le Roux, Nicolas, Mark Schmidt, and Francis Bach. 15 | "A Stochastic Gradient Method with an Exponential Convergence _Rate for Finite Training Sets." 16 | Advances in Neural Information Processing Systems 25. 2012. 17 | http://books.nips.cc/papers/files/nips25/NIPS2012_1246.pdf 18 | """ 19 | 20 | from numpy import * 21 | import scipy.linalg 22 | 23 | class SAG(object): 24 | 25 | def __init__(self, f_df, theta, subfunction_references, args=(), kwargs={}, L=1., L_freq=0): 26 | """ 27 | L is the Lipschitz constant. Smaller corresponds to faster (but 28 | noisier) learning. If L_freq is greater than 0, then every L_freq 29 | steps L will be adjusted as described in the SAG paper. 30 | """ 31 | self.L = L 32 | self.L_freq = L_freq # how often to perform the extra function evaluations in order to test L 33 | 34 | self.N = len(subfunction_references) 35 | self.sub_ref = subfunction_references 36 | self.M = theta.shape[0] 37 | self.f_df = f_df 38 | self.args = args 39 | self.kwargs = kwargs 40 | 41 | self.num_steps = 0 42 | self.theta = theta.copy().reshape((-1,1)) 43 | 44 | self.f = ones((self.N))*nan 45 | self.df = zeros((self.M,self.N)) 46 | 47 | self.grad_sum = sum( self.df, axis=1 ).reshape((-1,1)) 48 | 49 | def optimize(self, num_passes = 10, num_steps = None): 50 | if num_steps==None: 51 | num_steps = num_passes*self.N 52 | for i in range(num_steps): 53 | if not self.optimization_step(): 54 | break 55 | return self.theta 56 | 57 | def optimization_step(self): 58 | # choose a subfunction at random 59 | ind = int(floor(self.N*random.rand())) 60 | 61 | # calculate the objective function and gradient for the subfunction 62 | fl, dfl = self.f_df(self.theta, (self.sub_ref[ind], ), *self.args, **self.kwargs) 63 | dfl = dfl.reshape(-1,1) 64 | # store them 65 | self.f[ind] = fl 66 | self.grad_sum += (dfl - self.df[:,[ind]]) # TODO this may slowly accumulate errors. occasionally do the full sum? 67 | self.df[:,ind] = dfl.flat 68 | 69 | # only adjust the learning rate with frequency L_freq, to reduce computational load 70 | if self.L_freq > 0 and mod(self.num_steps, self.L_freq) == 0: 71 | # adapt the learning rate 72 | self.L *= 2**(-1. / self.N) 73 | theta_shift = self.theta - dfl / self.L 74 | # evaluate the objective function at theta_shift 75 | fl_shift, _ = self.f_df(theta_shift, (self.sub_ref[ind], ), *self.args, **self.kwargs) 76 | # test whether the change in objective satisfies Lip. inequality, otherwise increase constant 77 | if fl_shift - fl > -sum((dfl)**2) / (2.*self.L): 78 | self.L *= 2. 79 | 80 | self.num_steps += 1 81 | 82 | # take a gradient descent step 83 | div = min([self.num_steps, self.N]) 84 | delta_theta = -self.grad_sum / self.L / div 85 | self.theta += delta_theta 86 | 87 | if not isfinite(fl): 88 | print("Non-finite subfunction. Ending run.") 89 | return False 90 | return True 91 | -------------------------------------------------------------------------------- /generate_figures/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from theano.sandbox.cuda import CudaNdarraySharedVariable as CudaNdarray 3 | 4 | #XXX: Shamefully stolen from pyautodiff 5 | # https://github.com/jaberg/pyautodiff/blob/master/autodiff/fmin_scipy.py 6 | 7 | def _tonp(x): 8 | try: 9 | x = x.eval() 10 | except: 11 | pass 12 | #if not np.isscalar(x) and type(x) not in [CudaNdarray, np.array, np.ndarray]: 13 | #x = x.eval() 14 | return np.array(x) 15 | #return np.array(x) 16 | #elif type(x) in [np.array, np.ndarray]: 17 | # return x 18 | #else: 19 | # return np.array(x.eval()) 20 | 21 | def vector_from_args(raw_args): 22 | args = [_tonp(a) for a in raw_args] 23 | args_sizes = [w.size for w in args] 24 | x_size = sum(args_sizes) 25 | x = np.empty(x_size, dtype=args[0].dtype) # has to be float64 for fmin_l_bfgs_b 26 | i = 0 27 | for w in args: 28 | x[i: i + w.size] = w.flatten() 29 | i += w.size 30 | return x 31 | 32 | def args_from_vector(x, orig_args): 33 | #if type(orig_args[0]) != np.ndarray: 34 | # orig_args = [a.eval() for a in orig_args] 35 | # unpack x_opt -> args-like structure `args_opt` 36 | rval = [] 37 | i = 0 38 | for w in orig_args: 39 | size = _tonp(w.size) 40 | rval.append(x[i: i + size].reshape(_tonp(w.shape)).astype(w.dtype)) 41 | i += size 42 | return rval 43 | -------------------------------------------------------------------------------- /sfo.m: -------------------------------------------------------------------------------- 1 | % Implements the Sum of Functions Optimizer (SFO), as described in the paper: 2 | % Jascha Sohl-Dickstein, Ben Poole, and Surya Ganguli 3 | % An adaptive low dimensional quasi-Newton sum of functions optimizer 4 | % arXiv preprint arXiv:1311.2115 (2013) 5 | % http://arxiv.org/abs/1311.2115 6 | % 7 | % Sample code is provided in sfo_demo.m 8 | % 9 | % Useful functions in this class are: 10 | % obj = sfo(f_df, theta, subfunction_references, varargin) 11 | % Initializes the optimizer class. 12 | % Parameters: 13 | % f_df - Returns the function value and gradient for a single subfunction 14 | % call. Should have the form 15 | % [f, dfdtheta] = f_df(theta, subfunction_references{idx}, 16 | % varargin{:}) 17 | % where idx is the index of a single subfunction. 18 | % theta - The initial parameters to be used for optimization. theta can 19 | % be either a vector, a matrix, or a cell array with a vector or 20 | % matrix in every cell. The gradient returned by f_df should have the 21 | % same form as theta. 22 | % subfunction_references - A cell array containing an identifying element 23 | % for each subfunction. The elements in this list could be, eg, 24 | % matrices containing minibatches, or indices identifying the 25 | % subfunction, or filenames from which target data should be read. 26 | % If each subfunction corresponds to a minibatch, then the number of 27 | % subfunctions should be approximately 28 | % [number subfunctions] = sqrt([dataset size])/10. 29 | % varargin - Any additional parameters will be passed through to f_df 30 | % each time it is called. 31 | % Returns: 32 | % obj - sfo class instance. 33 | % theta = optimize(num_passes) 34 | % Optimize the objective function. 35 | % Parameters: 36 | % num_passes - The number of effective passes through 37 | % subfunction_references to perform. 38 | % Returns: 39 | % theta - The estimated parameter vector after num_passes of optimization. 40 | % check_grad() 41 | % Numerically checks the gradient of f_df for each element in 42 | % subfunction_references. 43 | % 44 | % Copyright 2014 Jascha Sohl-Dickstein 45 | % 46 | % Licensed under the Apache License, Version 2.0 (the "License"); 47 | % you may not use this file except in compliance with the License. 48 | % You may obtain a copy of the License at 49 | 50 | % http://www.apache.org/licenses/LICENSE-2.0 51 | 52 | % Unless required by applicable law or agreed to in writing, software 53 | % distributed under the License is distributed on an "AS IS" BASIS, 54 | % WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 55 | % See the License for the specific language governing permissions and 56 | % limitations under the License. 57 | 58 | classdef sfo < handle 59 | 60 | properties 61 | display = 2; 62 | f_df; 63 | args; 64 | max_history = 10; 65 | max_gradient_noise = 1; 66 | hess_max_dev = 1e8; 67 | hessian_init = 1e5; 68 | N; 69 | subfunction_references; 70 | % theta, in its original format; 71 | theta_original; 72 | % theta, flattented into a 1d array; 73 | theta; 74 | % theta from the previous learning step -- initialize to theta; 75 | theta_prior_step; 76 | % number of data dimensions; 77 | M; 78 | % the update steps will be rescaled by this; 79 | step_scale = 1; 80 | % 'very small' for various tasks, most importantly identifying when; 81 | % update steps or gradient changes are too small to be used for Hessian 82 | % updates without incurring large numerical errors.; 83 | eps = 1e-12; 84 | 85 | % The shortest step length allowed. Any update 86 | % steps shorter than this will be made this length. Set this so as to 87 | % prevent numerical errors when computing the difference in gradients 88 | % before and after a step. 89 | minimum_step_length = 1e-8; 90 | % The length of the longest allowed update step, 91 | % relative to the average length of prior update steps. Takes effect 92 | % after the first full pass through the data. 93 | max_step_length_ratio = 10; 94 | 95 | % the min & max dimenstionality for the subspace; 96 | K_min; 97 | K_max; 98 | % the current dimensionality of the subspace; 99 | K_current = 1; 100 | % obj.P holds the subspace; 101 | P; 102 | 103 | % store the minimum & maximum eigenvalue from each approximate; 104 | % Hessian; 105 | min_eig_sub; 106 | max_eig_sub; 107 | 108 | % store the total time spent in optimization, & the amount of time; 109 | % spent in the objective function; 110 | time_pass = 0.; 111 | time_func = 0.; 112 | 113 | % how many steps since the active set size was increased; 114 | iter_since_active_growth = 0; 115 | 116 | % which subfunctions are active; 117 | init_subf = 2; 118 | active; 119 | 120 | % the total path length traveled during optimization; 121 | total_distance = 0.; 122 | % number of function evaluations for each subfunction; 123 | eval_count; 124 | 125 | % theta projected into current working subspace; 126 | theta_proj; 127 | % holds the last position & the last gradient for all the objective functions; 128 | last_theta; 129 | last_df; 130 | % the history of theta changes for each subfunction; 131 | hist_deltatheta; 132 | % the history of gradient changes for each subfunction; 133 | hist_deltadf; 134 | % the history of function values for each subfunction; 135 | hist_f; 136 | % a flat history of all returned subfunction values for debugging/diagnostics; 137 | hist_f_flat = []; 138 | 139 | % the approximate Hessian for each subfunction is stored; 140 | % as dot(b(:.:.index), b(:.:.inedx).') 141 | b; 142 | 143 | % the full Hessian (sum over all the subfunctions); 144 | full_H = 0; 145 | 146 | % parameters that are passed through to f_df 147 | varargin_stored = {}; 148 | 149 | % the predicted improvement in the total objective from the current update step 150 | f_predicted_total_improvement = 0; 151 | end 152 | 153 | methods 154 | 155 | function obj = sfo(f_df, theta, subfunction_references, varargin) 156 | % obj = sfo(f_df, theta, subfunction_references, varargin) 157 | % Initializes the optimizer class. 158 | % Parameters: 159 | % f_df - Returns the function value and gradient for a single subfunction 160 | % call. Should have the form 161 | % [f, dfdtheta] = f_df(theta, subfunction_references{idx}, 162 | % varargin{:}) 163 | % where idx is the index of a single subfunction. 164 | % theta - The initial parameters to be used for optimization. theta can 165 | % be either a vector, a matrix, or a cell array with a vector or 166 | % matrix in every cell. The gradient returned by f_df should have the 167 | % same form as theta. 168 | % subfunction_references - A cell array containing an identifying element 169 | % for each subfunction. The elements in this list could be, eg, 170 | % matrices containing minibatches, or indices identifying the 171 | % subfunction, or filenames from which target data should be read. 172 | % varargin - Any additional parameters will be passed through to f_df 173 | % each time it is called. 174 | % Returns: 175 | % obj - sfo class instance. 176 | 177 | obj.N = length(subfunction_references); 178 | obj.theta_original = theta; 179 | obj.theta = obj.theta_original_to_flat(obj.theta_original); 180 | obj.theta_prior_step = obj.theta; 181 | obj.M = length(obj.theta); 182 | obj.f_df = f_df; 183 | obj.varargin_stored = varargin; 184 | obj.subfunction_references = subfunction_references; 185 | 186 | subspace_dimensionality = 2.*obj.N+2; % 2 to include current location; 187 | % subspace can't be larger than the full space; 188 | subspace_dimensionality = min([subspace_dimensionality, obj.M]); 189 | 190 | 191 | % the min & max dimenstionality for the subspace; 192 | obj.K_min = subspace_dimensionality; 193 | obj.K_max = ceil(obj.K_min.*1.5); 194 | obj.K_max = min([obj.K_max, obj.M]); 195 | % obj.P holds the subspace; 196 | obj.P = zeros(obj.M,obj.K_max); 197 | 198 | % store the minimum & maximum eigenvalue from each approximate; 199 | % Hessian; 200 | obj.min_eig_sub = zeros(obj.N,1); 201 | obj.max_eig_sub = zeros(obj.N,1); 202 | 203 | % which subfunctions are active; 204 | obj.active = false(obj.N,1); 205 | obj.init_subf = min(obj.N, obj.init_subf); 206 | inds = randperm(obj.N, obj.init_subf); 207 | obj.active(inds) = true; 208 | obj.min_eig_sub(inds) = obj.hessian_init; 209 | obj.max_eig_sub(inds) = obj.hessian_init; 210 | 211 | % number of function evaluations for each subfunction; 212 | obj.eval_count = zeros(obj.N,1); 213 | 214 | % set the first column of the subspace to be the initial; 215 | % theta; 216 | rr = sqrt(sum(obj.theta.^2)); 217 | if rr > 0; 218 | obj.P(:,1) = obj.theta/rr; 219 | else 220 | % initial theta is 0 -- initialize randomly; 221 | obj.P(:,1) = randn(obj.M,1); 222 | obj.P(:,1) = obj.P(:,1) / sqrt(sum(obj.P(:,1).^2)); 223 | end 224 | 225 | if obj.M == obj.K_max 226 | % if the subspace spans the full space, then (j)ust make; 227 | % P the identity matrix; 228 | if obj.display > 1; 229 | fprintf('subspace spans full space'); 230 | end 231 | obj.P = eye(obj.M); 232 | obj.K_current = obj.M+1; 233 | end 234 | 235 | 236 | % theta projected into current working subspace; 237 | obj.theta_proj = obj.P' * obj.theta; 238 | % holds the last position & the last gradient for all the objective functions; 239 | obj.last_theta = obj.theta_proj * ones(1,obj.N); 240 | obj.last_df = zeros(obj.K_max,obj.N); 241 | % the history of theta changes for each subfunction; 242 | obj.hist_deltatheta = zeros(obj.K_max,obj.max_history,obj.N); 243 | % the history of gradient changes for each subfunction; 244 | obj.hist_deltadf = zeros(obj.K_max,obj.max_history,obj.N); 245 | % the history of function values for each subfunction; 246 | obj.hist_f = ones(obj.N, obj.max_history).*nan; 247 | 248 | % the approximate Hessian for each subfunction is stored; 249 | % as dot(obj.b(:.:.index), obj.b(:.:.inedx).') 250 | obj.b = zeros(obj.K_max,2.*obj.max_history,obj.N); 251 | 252 | if obj.N < 25 && obj.display > 0 253 | fprintf( '\n\nIn experiments, performance suffered when the data was broken up into fewer\nthan 25 minibatches (and performance saturated after about 50 minibatches).\nSee Figure 2c. You may want to use more than the current %d minibatches.\n\n', obj.N); 254 | end 255 | end 256 | 257 | 258 | function theta = optimize(obj, num_passes) 259 | % theta = optimize(num_passes) 260 | % Optimize the objective function. 261 | % Parameters: 262 | % num_passes - The number of effective passes through 263 | % subfunction_references to perform. 264 | % Returns: 265 | % theta - The estimated parameter vector after num_passes of optimization. 266 | 267 | num_steps = ceil(num_passes.*obj.N); 268 | for i = 1:num_steps 269 | if obj.display > 1 270 | fprintf('pass %g, step %d,', sum(obj.eval_count)/obj.N, i); 271 | end 272 | obj.optimization_step(); 273 | if obj.display > 1 274 | fprintf('active %d/%d, sfo time %g s, func time %g s, f %f, %f\n', sum(obj.active), size(obj.active, 1), obj.time_pass - obj.time_func, obj.time_func, obj.hist_f_flat(end), mean(obj.hist_f(obj.eval_count>0,1))); 275 | end 276 | end 277 | if obj.display > 0 278 | fprintf('active %d/%d, pass %g, sfo time %g s, func time %g s, %f\n', sum(obj.active), size(obj.active, 1), sum(obj.eval_count)/obj.N, obj.time_pass - obj.time_func, obj.time_func, mean(obj.hist_f(obj.eval_count>0,1))); 279 | end 280 | 281 | theta = obj.theta_flat_to_original(obj.theta); 282 | end 283 | 284 | function check_grad(obj) 285 | % A diagnostic function to check the gradients for the subfunctions. It 286 | % checks the subfunctions in random order, & the dimensions of each 287 | % subfunction in random order. This way, a representitive set of 288 | % gradients can be checked quickly, even for high dimensional objectives. 289 | 290 | % step size to use for gradient check; 291 | small_diff = obj.eps.*1e6; 292 | fprintf('Testing step size %g\n', small_diff); 293 | 294 | for i = randperm(obj.N) 295 | [fl, ~, dfl] = obj.f_df_wrapper(obj.theta, i); 296 | ep = zeros(obj.M,1); 297 | dfl_obs = zeros(obj.M,1); 298 | dfl_err = zeros(obj.M,1); 299 | for j = randperm(obj.M) 300 | ep(j) = small_diff; 301 | fl2 = obj.f_df_wrapper(obj.theta + ep, i); 302 | dfl_obs(j) = (fl2 - fl)/small_diff; 303 | dfl_err(j) = dfl_obs(j) - dfl(j); 304 | if abs(dfl_err(j)) > small_diff * 1e4 305 | fprintf('large diff '); 306 | else 307 | fprintf(' '); 308 | end 309 | fprintf(' gradient subfunction %d, dimension %d, analytic %g, finite diff %g, error %g\n', i, j, dfl(j), dfl_obs(j), dfl_err(j)); 310 | ep(j) = 0.; 311 | end 312 | gerr = sqrt((sum(dfl - dfl_obs).^2)); 313 | fprintf('subfunction %g, total L2 gradient error %g\n', i, gerr); 314 | end 315 | end 316 | 317 | 318 | function apply_subspace_transformation(obj,T_left,T_right) 319 | % Apply change-of-subspace transformation. This function is called when; 320 | % the subspace is collapsed to project into the new lower dimensional; 321 | % subspace.; 322 | % T_left - The covariant subspace to subspace projection matrix.; 323 | % T_right - The contravariant subspace projection matrix.; 324 | 325 | % (note that currently T_left = T_right always since the subspace is; 326 | % orthogonal. This will change if eg the code is adapted to also; 327 | % incorporate a 'natural gradient' based parameter space transformation.); 328 | 329 | [tt, ss] = size(T_left); 330 | 331 | % project history terms into new subspace; 332 | obj.last_df = (T_right.') * (obj.last_df); 333 | obj.last_theta = (T_left) * (obj.last_theta); 334 | obj.hist_deltadf = obj.reshape_wrapper(T_right.' * obj.reshape_wrapper(obj.hist_deltadf, [ss,-1]), [tt,-1,obj.N]); 335 | obj.hist_deltatheta = obj.reshape_wrapper(T_left * obj.reshape_wrapper(obj.hist_deltatheta, [ss, -1]), [tt,-1,obj.N]); 336 | % project stored hessian for each subfunction in to new subspace; 337 | obj.b = obj.reshape_wrapper(T_right.' * obj.reshape_wrapper(obj.b, [ss,-1]), [tt, 2.*obj.max_history,obj.N]); 338 | 339 | %% To avoid slow accumulation of numerical errors, recompute full_H; 340 | %% & theta_proj when the subspace is collapsed. Should not be a; 341 | %% leading time cost.; 342 | % theta projected into current working subspace; 343 | obj.theta_proj = (obj.P.') * (obj.theta); 344 | % full approximate hessian; 345 | obj.full_H = real(obj.reshape_wrapper(obj.b,[ss,-1]) * obj.reshape_wrapper(obj.b,[ss,-1]).'); 346 | end 347 | 348 | function reorthogonalize_subspace(obj) 349 | % check if the subspace has become non-orthogonal 350 | subspace_eigs = eig(obj.P.' * obj.P); 351 | % TODO(jascha) this may be a stricter cutoff than we need 352 | if max(subspace_eigs) <= 1 + obj.eps 353 | return 354 | end 355 | 356 | if obj.display > 2 357 | fprintf('subspace has become non-orthogonal. Performing QR.\n'); 358 | end 359 | [Porth, ~] = qr(obj.P(:,1:obj.K_current), 0); 360 | Pl = zeros(obj.K_max, obj.K_max); 361 | Pl(:,1:obj.K_current) = obj.P.' * Porth; 362 | % update the subspace; 363 | obj.P(:,1:obj.K_current) = Porth; 364 | % Pl is the projection matrix from old to new basis. apply it to all the history; 365 | % terms; 366 | obj.apply_subspace_transformation(Pl.', Pl); 367 | end 368 | 369 | function collapse_subspace(obj, xl) 370 | % Collapse the subspace to its smallest dimensionality.; 371 | 372 | % xl is a new direction that may not be in the history yet, so we pass; 373 | % it in explicitly to make sure it's included.; 374 | 375 | if obj.display > 2 376 | fprintf('collapsing subspace\n'); 377 | end 378 | 379 | % the projection matrix from old to new subspace; 380 | Pl = zeros(obj.K_max,obj.K_max); 381 | 382 | % yy will hold all the directions to pack into the subspace.; 383 | % initialize it with random noise, so that it still spans K_min; 384 | % dimensions even if not all the subfunctions are active yet; 385 | yy = randn(obj.K_max,obj.K_min); 386 | % the most recent position & gradient for all active subfunctions; 387 | % as well as the current position & gradient (which will not be saved in the history yet); 388 | yz = [obj.last_df(:,obj.active), obj.last_theta(:,obj.active), xl, (obj.P.') * (obj.theta)]; 389 | yy(:,1:size(yz, 2)) = yz; 390 | [Pl(:,1:obj.K_min), ~] = qr(yy, 0); 391 | 392 | % update the subspace; 393 | obj.P = (obj.P) * (Pl); 394 | 395 | % Pl is the projection matrix from old to new basis. apply it to all the history; 396 | % terms; 397 | obj.apply_subspace_transformation(Pl.', Pl); 398 | 399 | % update the stored subspace size; 400 | obj.K_current = obj.K_min; 401 | 402 | % re-orthogonalize the subspace if it's accumulated small errors 403 | obj.reorthogonalize_subspace(); 404 | end 405 | 406 | 407 | function update_subspace(obj, x_in) 408 | % Update the low dimensional subspace by adding a new direction.; 409 | % x_in - The new vector to incorporate into the subspace.; 410 | 411 | if obj.K_current >= obj.M 412 | % no need to update the subspace if it spans the full space; 413 | return; 414 | end 415 | if sum(~isfinite(x_in)) > 0 416 | % bad vector! bail.; 417 | return; 418 | end 419 | x_in_length = sqrt(sum(x_in.^2)); 420 | if x_in_length < obj.eps 421 | % if the new vector is too short, nothing to do; 422 | return; 423 | end 424 | % make x unit length; 425 | xnew = x_in/x_in_length; 426 | 427 | % Find the component of x pointing out of the existing subspace.; 428 | % We need to do this multiple times for numerical stability.; 429 | for i = 1:3 430 | xnew = xnew - obj.P * (obj.P.' * xnew); 431 | ss = sqrt(sum(xnew.^2)); 432 | if ss < obj.eps 433 | % it barely points out of the existing subspace; 434 | % no need to add a new direction to the subspace; 435 | return; 436 | end 437 | % make it unit length; 438 | xnew = xnew / ss; 439 | % if it was already largely orthogonal then numerical; 440 | % stability will be good enough; 441 | % TODO replace this with a more principled test; 442 | if ss > 0.1 443 | break; 444 | end 445 | end 446 | 447 | % add a new column to the subspace containing the new direction; 448 | obj.P(:,obj.K_current+1) = xnew; 449 | obj.K_current = obj.K_current + 1; 450 | 451 | if obj.K_current >= obj.K_max 452 | % the subspace has exceeded its maximum allowed size -- collapse it; 453 | % xl may not be in the history yet, so we pass it in explicitly to make; 454 | % sure it's used; 455 | xl = (obj.P.') * (x_in); 456 | obj.collapse_subspace(xl); 457 | end 458 | end 459 | 460 | function full_H_combined = get_full_H_with_diagonal(obj) 461 | % Get the full approximate Hessian, including the diagonal terms.; 462 | % (note that obj.full_H is stored without including the diagonal terms); 463 | 464 | full_H_combined = obj.full_H + eye(obj.K_max).*sum(obj.min_eig_sub(obj.active)); 465 | end 466 | 467 | function f_pred = get_predicted_subf(obj, indx, theta_proj) 468 | % Get the predicted value of subfunction idx at theta_proj; 469 | % (where theat_proj is in the subspace); 470 | 471 | dtheta = theta_proj - obj.last_theta(:,indx); 472 | bdtheta = obj.b(:,:,indx).' * dtheta; 473 | Hdtheta = real(obj.b(:,:,indx) * bdtheta); 474 | Hdtheta = Hdtheta + dtheta.*obj.min_eig_sub(indx); % the diagonal contribution 475 | %df_pred = obj.last_df(:,indx) + Hdtheta; 476 | f_pred = obj.hist_f(indx,1) + obj.last_df(:,indx).' * dtheta + 0.5.*(dtheta.') * (Hdtheta); 477 | end 478 | 479 | 480 | function update_history(obj, indx, theta_proj, f, df_proj) 481 | % Update history of position differences & gradient differences; 482 | % for subfunction indx.; 483 | 484 | % there needs to be at least one earlier measurement from this; 485 | % subfunction to compute position & gradient differences.; 486 | if obj.eval_count(indx) > 1 487 | % differences in gradient & position; 488 | ddf = df_proj - obj.last_df(:,indx); 489 | ddt = theta_proj - obj.last_theta(:,indx); 490 | % length of gradient & position change vectors; 491 | lddt = sqrt(sum(ddt.^2)); 492 | lddf = sqrt(sum(ddf.^2)); 493 | 494 | corr_ddf_ddt = ddf.' * ddt / (lddt*lddf); 495 | 496 | if obj.display > 3 && corr_ddf_ddt < 0 497 | fprintf('Warning! Negative dgradient dtheta inner product. Adding it anyway.'); 498 | end 499 | if lddt < obj.eps 500 | if obj.display > 2 501 | fprintf('Largest change in theta too small (%g). Not adding to history.', lddt); 502 | end 503 | elseif lddf < obj.eps 504 | if obj.display > 2 505 | fprintf('Largest change in gradient too small (%g). Not adding to history.', lddf); 506 | end 507 | elseif abs(corr_ddf_ddt) < obj.eps 508 | if obj.display > 2 509 | fprintf('Inner product between dgradient and dtheta too small (%g). Not adding to history.', corr_ddf_ddt); 510 | end 511 | else 512 | if obj.display > 3 513 | fprintf('subf ||dtheta|| %g, subf ||ddf|| %g, corr(ddf,dtheta) %g,', lddt, lddf, sum(ddt.*ddf)/(lddt.*lddf)); 514 | end 515 | 516 | % shift the history by one timestep; 517 | obj.hist_deltatheta(:,2:end,indx) = obj.hist_deltatheta(:,1:end-1,indx); 518 | % store the difference in theta since the subfunction was last evaluated; 519 | obj.hist_deltatheta(:,1,indx) = ddt; 520 | % do the same thing for the change in gradient; 521 | obj.hist_deltadf(:,2:end,indx) = obj.hist_deltadf(:,1:end-1,indx); 522 | obj.hist_deltadf(:,1,indx) = ddf; 523 | end 524 | end 525 | 526 | obj.last_theta(:,indx) = theta_proj; 527 | obj.last_df(:,indx) = df_proj; 528 | obj.hist_f(indx,2:end) = obj.hist_f(indx,1:end-1); 529 | obj.hist_f(indx,1) = f; 530 | end 531 | 532 | 533 | function update_hessian(obj,indx) 534 | % Update the Hessian approximation for a single subfunction.; 535 | % indx - The index of the target subfunction for Hessian update.; 536 | 537 | gd = find(sum(obj.hist_deltatheta(:,:,indx).^2, 1)>0); 538 | num_gd = length(gd); 539 | if num_gd == 0 540 | % if no history, initialize with the median eigenvalue from full Hessian; 541 | if obj.display > 2 542 | fprintf(' no history '); 543 | end 544 | obj.b(:,:,indx) = 0.; 545 | H = obj.get_full_H_with_diagonal(); 546 | [U, ~] = obj.eigh_wrapper(H); 547 | obj.min_eig_sub(indx) = median(U)/sum(obj.active); 548 | obj.max_eig_sub(indx) = obj.min_eig_sub(indx); 549 | if obj.eval_count(indx) > 2 550 | if obj.display > 2 || sum(obj.eval_count) < 5 551 | fprintf('Subfunction evaluated %d times, but has no stored history.', obj.eval_count(indx)); 552 | end 553 | if sum(obj.eval_count) < 5 554 | fprintf('You probably need to initialize SFO with a smaller hessian_init value. Scaling down the Hessian to try to recover. You are better off correcting the hessian_init value though!'); 555 | obj.min_eig_sub(indx) = obj.min_eig_sub(indx) / 10.; 556 | end 557 | end 558 | return; 559 | end 560 | 561 | % work in the subspace defined by this subfunction's history for this; 562 | [P_hist, ~] = qr([obj.hist_deltatheta(:,gd,indx),obj.hist_deltadf(:,gd,indx)], 0); 563 | deltatheta_P = P_hist.' * obj.hist_deltatheta(:,gd,indx); 564 | deltadf_P = P_hist.' * obj.hist_deltadf(:,gd,indx); 565 | 566 | %% get an approximation to the smallest eigenvalue.; 567 | %% This will be used as the diagonal initialization for BFGS.; 568 | % calculate Hessian using pinv & squared equation. (j)ust to get; 569 | % smallest eigenvalue.; 570 | % df = H dx; 571 | % df^T df = dx^T H^T H dx = dx^T H^2 dx 572 | pdelthet = pinv(deltatheta_P); 573 | dd = (deltadf_P) * (pdelthet); 574 | H2 = (dd.') * (dd); 575 | [H2w, ~] = obj.eigh_wrapper(H2); 576 | H2w = sqrt(abs(H2w)); 577 | 578 | % only the top ~ num_gd eigenvalues are expected to be well defined; 579 | H2w = sort(H2w, 'descend'); 580 | H2w = H2w(1:num_gd); 581 | 582 | if min(H2w) == 0 || sum(~isfinite(H2w)) > 0 583 | % there was a failure using this history. either deltadf was; 584 | % degenerate (0 case), | deltatheta was (non-finite case).; 585 | % Initialize using other subfunctions; 586 | H2w(:) = max(obj.min_eig_sub(obj.active)); 587 | if obj.display > 3 588 | fprintf('ill-conditioned history'); 589 | end 590 | end 591 | 592 | obj.min_eig_sub(indx) = min(H2w); 593 | obj.max_eig_sub(indx) = max(H2w); 594 | 595 | if obj.min_eig_sub(indx) < obj.max_eig_sub(indx)/obj.hess_max_dev 596 | % constrain using allowed ratio; 597 | obj.min_eig_sub(indx) = obj.max_eig_sub(indx)/obj.hess_max_dev; 598 | if obj.display > 3 599 | fprintf('constraining Hessian initialization'); 600 | end 601 | end 602 | 603 | %% recalculate Hessian; 604 | % number of history terms; 605 | num_hist = size(deltatheta_P, 2); 606 | % the new hessian will be (b_p) * (b_p.') + eye().*obj.min_eig_sub(indx); 607 | b_p = zeros(size(P_hist, 2), num_hist*2); 608 | % step through the history; 609 | for hist_i = num_hist:-1:1 610 | s = deltatheta_P(:,hist_i); 611 | y = deltadf_P(:,hist_i); 612 | 613 | % for numerical stability; 614 | rscl = sqrt(sum(s.^2)); 615 | s = s/rscl; 616 | y = y/rscl; 617 | 618 | % the BFGS step proper 619 | Hs = s.*obj.min_eig_sub(indx) + b_p * ((b_p.') * (s)); 620 | term1 = y / sqrt(sum(y.*s)); 621 | sHs = sum(s.*Hs); 622 | term2 = sqrt(complex(-1.)) .* Hs / sqrt(sHs); 623 | if sum(~isfinite(term1)) > 0 || sum(~isfinite(term2)) > 0 624 | obj.min_eig_sub(indx) = max(H2w); 625 | if obj.display > 1 626 | fprintf('invalid bfgs history term. should never get here!'); 627 | end 628 | continue; 629 | end 630 | b_p(:,2*(hist_i-1)+2) = term1; 631 | b_p(:,2*(hist_i-1)+1) = term2; 632 | end 633 | 634 | H = real((b_p) * (b_p.')) + eye(size(b_p, 1))*obj.min_eig_sub(indx); 635 | % constrain it to be positive definite; 636 | [U, V] = obj.eigh_wrapper(H); 637 | if max(U) <= 0. 638 | % if there aren't any positive eigenvalues, then; 639 | % set them all to be the same conservative diagonal value; 640 | U(:) = obj.max_eig_sub(indx); 641 | if obj.display > 3 642 | fprintf('no positive eigenvalues after BFGS'); 643 | end 644 | end 645 | % set any too-small eigenvalues to the median positive; 646 | % eigenvalue; 647 | U_median = median(U(U>0)); 648 | U(U<(max(abs(U))/obj.hess_max_dev)) = U_median; 649 | 650 | % the Hessian after it's been forced to be positive definite; 651 | H_posdef = bsxfun(@times, V, U') * V.'; 652 | 653 | % now break it apart into matrices b & a diagonal term again; 654 | B_pos = H_posdef - eye(size(b_p, 1))*obj.min_eig_sub(indx); 655 | [U, V] = obj.eigh_wrapper(B_pos); 656 | b_p = bsxfun(@times, V, sqrt(obj.reshape_wrapper(U, [1,-1]))); 657 | 658 | obj.b(:,:,indx) = 0.; 659 | obj.b(:,1:size(b_p,2),indx) = (P_hist) * (b_p); 660 | end 661 | 662 | 663 | function theta_flat = theta_original_to_flat(obj, theta_original) 664 | % Convert from the original parameter format into a 1d array 665 | % The original format can be an array, | a cell array full of 666 | % arrays 667 | 668 | if iscell(theta_original) 669 | theta_length = 0; 670 | for theta_array = theta_original(:)' 671 | % iterate over cells 672 | theta_length = theta_length + numel(theta_array{1}); 673 | end 674 | theta_flat = zeros(theta_length,1); 675 | i_start = 1; 676 | for theta_array = theta_original(:)' 677 | % iterate over cells 678 | i_end = i_start + numel(theta_array{1})-1; 679 | theta_flat(i_start:i_end) = theta_array{1}(:); 680 | i_start = i_end+1; 681 | end 682 | else 683 | theta_flat = theta_original(:); 684 | end 685 | end 686 | function theta_new = theta_flat_to_original(obj, theta_flat) 687 | % Convert from a 1d array into the original parameter format.; 688 | 689 | if iscell(obj.theta_original) 690 | theta_new = cell(size(obj.theta_original)); 691 | i_start = 1; 692 | jj = 1; 693 | for theta_array = obj.theta_original(:)' 694 | % iterate over cells 695 | i_end = i_start + numel(theta_array{1})-1; 696 | theta_new{jj} = obj.reshape_wrapper(theta_flat(i_start:i_end), size(theta_array{1})); 697 | i_start = i_end + 1; 698 | jj = jj + 1; 699 | end 700 | else 701 | theta_new = obj.reshape_wrapper(theta_flat, size(obj.theta_original)); 702 | end 703 | end 704 | 705 | 706 | function [f, df_proj, df_full] = f_df_wrapper(obj, theta_in, idx) 707 | % A wrapper around the subfunction objective f_df, that handles the transformation; 708 | % into & out of the flattened parameterization used internally by SFO.; 709 | 710 | theta_local = obj.theta_flat_to_original(theta_in); 711 | % evaluate; 712 | t = tic(); 713 | [f, df_full] = obj.f_df(theta_local, obj.subfunction_references{idx}, obj.varargin_stored{:}); 714 | time_diff = toc(t); 715 | obj.time_func = obj.time_func + time_diff; % time spent in function evaluation; 716 | df_full = obj.theta_original_to_flat(df_full); 717 | % update the subspace with the new gradient direction; 718 | obj.update_subspace(df_full); 719 | % gradient projected into the current subspace; 720 | df_proj = ( obj.P.') * (df_full ); 721 | % keep a record of function evaluations; 722 | obj.hist_f_flat = [obj.hist_f_flat, f]; 723 | obj.eval_count(idx) = obj.eval_count(idx) + 1; 724 | end 725 | 726 | 727 | function indx = get_target_index(obj) 728 | % Choose which subfunction to update this iteration. 729 | 730 | % if an active subfunction has one evaluation, get a second 731 | % so we can have a Hessian estimate 732 | gd = find((obj.eval_count == 1) & obj.active); 733 | if ~isempty(gd) 734 | indx = gd(randperm(length(gd), 1)); 735 | return 736 | end 737 | % If an active subfunction has less than two observations, then; 738 | % evaluate it. We want to get to two evaluations per subfunction; 739 | % as quickly as possibly so that it's possible to estimate a Hessian; 740 | % for it 741 | gd = find((obj.eval_count < 2) & obj.active); 742 | if ~isempty(gd) 743 | indx = gd(randperm(length(gd), 1)); 744 | return 745 | end 746 | 747 | % use the subfunction evaluated farthest; 748 | % either weighted by the total Hessian, or by the Hessian 749 | % just for that subfunction 750 | if randn() < 0 751 | max_dist = -1; 752 | indx = -1; 753 | for i = 1:obj.N 754 | dtheta = obj.theta_proj - obj.last_theta(:,i); 755 | bdtheta = obj.b(:,:,i).' * dtheta; 756 | dist = sum(bdtheta.^2) + sum(dtheta.^2)*obj.min_eig_sub(i); 757 | if (dist > max_dist) && obj.active(i) 758 | max_dist = dist; 759 | indx = i; 760 | end 761 | end 762 | else 763 | % from the current location, weighted by the Hessian; 764 | % difference between current theta & most recent evaluation; 765 | % for all subfunctions; 766 | dtheta = bsxfun(@plus, obj.theta_proj, -obj.last_theta); 767 | % full Hessian; 768 | full_H_combined = obj.get_full_H_with_diagonal(); 769 | % squared distance; 770 | distance = sum(dtheta.*(full_H_combined * dtheta), 1); 771 | % sort the distances from largest to smallest; 772 | [~, dist_ord] = sort(-distance); 773 | % & keep only the indices that belong to active subfunctions; 774 | dist_ord = dist_ord(obj.active(dist_ord)); 775 | % & choose the active subfunction from farthest away; 776 | indx = dist_ord(1); 777 | if max(distance(obj.active)) < obj.eps && sum(~obj.active)>0 && obj.eval_count(indx)>0 778 | if obj.display > 2 779 | fprintf('all active subfunctions evaluated here. expanding active set.'); 780 | end 781 | inactive = find(~obj.active); 782 | indx = inactive(randperm(length(inactive), 1)); 783 | obj.active(indx) = true; 784 | end 785 | end 786 | 787 | end 788 | 789 | 790 | function [step_failure, f, df_proj] = handle_step_failure(obj, f, df_proj, indx) 791 | % Check whether an update step failed. Update current position if it did.; 792 | 793 | % check to see whether the step should be a failure; 794 | step_failure = false; 795 | if ~isfinite(f) || sum(~isfinite(df_proj))>0 796 | % step is a failure if function | gradient is non-finite; 797 | step_failure = true; 798 | elseif obj.eval_count(indx) == 1 799 | % the step is a candidate for failure if it's a new subfunction, & it's; 800 | % much larger than expected; 801 | if max(obj.eval_count) > 1 802 | if f > mean(obj.hist_f(obj.eval_count>1,1)) + 3*std(obj.hist_f(obj.eval_count>1,1)) 803 | step_failure = true; 804 | end 805 | end 806 | elseif f > obj.hist_f(indx,1) 807 | % if this subfunction has increased in value, then look whether it's larger; 808 | % than its predicted value by enough to trigger a failure; 809 | % calculate the predicted value of this subfunction; 810 | f_pred = obj.get_predicted_subf(indx, obj.theta_proj); 811 | % if the subfunction exceeds its predicted value by more than the predicted average gain; 812 | % in the other subfunctions, then mark the step as a failure 813 | % (note that it's been about N steps since this has been evaluated, & that this subfunction can lay; 814 | % claim to about 1/N fraction of the objective change); 815 | predicted_improvement_others = obj.f_predicted_total_improvement - (obj.hist_f(indx,1) - f_pred); 816 | if f - f_pred > predicted_improvement_others 817 | step_failure = true; 818 | end 819 | end 820 | 821 | if ~step_failure 822 | % decay the step_scale back towards 1; 823 | obj.step_scale = 1./obj.N + obj.step_scale .* (1. - 1./obj.N); 824 | else 825 | % shorten the step length; 826 | obj.step_scale = obj.step_scale / 2.; 827 | 828 | % the subspace may be updated during the function calls; 829 | % so store this in the full space; 830 | df = (obj.P) * (df_proj); 831 | 832 | [f_lastpos, df_lastpos_proj] = obj.f_df_wrapper(obj.theta_prior_step, indx); 833 | df_lastpos = (obj.P) * (df_lastpos_proj); 834 | 835 | %% if the function value exploded, then back it off until it's a; 836 | %% reasonable order of magnitude before adding anything to the history; 837 | f_pred = obj.get_predicted_subf(indx, obj.theta_proj); 838 | if isfinite(obj.hist_f(indx,1)) 839 | predicted_f_diff = abs(f_pred - obj.hist_f(indx,1)); 840 | else 841 | predicted_f_diff = abs(f - f_lastpos); 842 | end 843 | if ~isfinite(predicted_f_diff) || predicted_f_diff < obj.eps 844 | predicted_f_diff = obj.eps; 845 | end 846 | 847 | for i_ls = 1:10 848 | if f - f_lastpos < 10*predicted_f_diff 849 | % the failed update is already with an order of magnitude; 850 | % of the target update value -- no backoff required; 851 | break; 852 | end 853 | if obj.display > 4 854 | fprintf('ls %d f_diff %g predicted_f_diff %g ', i_ls, f - f_lastpos, predicted_f_diff); 855 | end 856 | % make the step length a factor of 100 shorter; 857 | obj.theta = 0.99.*obj.theta_prior_step + 0.01.*obj.theta; 858 | obj.theta_proj = (obj.P.') * (obj.theta); 859 | % & recompute f & df at this new location; 860 | [f, df_proj] = obj.f_df_wrapper(obj.theta, indx); 861 | df = (obj.P) * (df_proj); 862 | end 863 | 864 | % we're done with function calls. can move these back into the subspace.; 865 | df_proj = (obj.P.') * (df); 866 | df_lastpos_proj = (obj.P.') * (df_lastpos); 867 | 868 | if f < f_lastpos 869 | % the original objective was better -- but add the newly evaluated point to the history; 870 | % (j)ust so it's not a wasted function call; 871 | theta_lastpos_proj = (obj.P.') * (obj.theta_prior_step); 872 | obj.update_history(indx, theta_lastpos_proj, f_lastpos, df_lastpos_proj); 873 | if obj.display > 2 874 | fprintf('step failed, but last position was even worse ( f %g, std f %g), ', f_lastpos, std(obj.hist_f(obj.eval_count>0,1))); 875 | end 876 | else 877 | % add the change in theta & the change in gradient to the history for this subfunction; 878 | % before failing over to the last position; 879 | if isfinite(f) && sum(~isfinite(df_proj))==0 880 | obj.update_history(indx, obj.theta_proj, f, df_proj); 881 | end 882 | if obj.display > 2 883 | fprintf('step failed, proposed f %g, std f %g, ', f, std(obj.hist_f(obj.eval_count>0,1))); 884 | end 885 | if (obj.display > -1) && (sum(obj.eval_count>1) < 2) 886 | fprintf([ '\nStep failed on the very first subfunction. This is\n' ... 887 | 'either due to an incorrect gradient, or a very large\n' ... 888 | 'Hessian. Try:\n' ... 889 | ' - Calling check_grad() (see README.md for details)\n' ... 890 | ' - Setting sfo.hessian_init to a larger value.\n']); 891 | end 892 | f = f_lastpos; 893 | df_proj = df_lastpos_proj; 894 | obj.theta = obj.theta_prior_step; 895 | obj.theta_proj = (obj.P.') * (obj.theta); 896 | end 897 | end 898 | 899 | % don't let steps get so short that they don't provide any usable Hessian information; 900 | % TODO use a more principled cutoff here; 901 | obj.step_scale = max([obj.step_scale, 1e-5]); 902 | end 903 | 904 | 905 | function expand_active_subfunctions(obj, full_H_inv, step_failure) 906 | % expand the set of active subfunctions as appropriate; 907 | 908 | % power in the average gradient direction; 909 | df_avg = mean(obj.last_df(:,obj.active), 2); 910 | p_df_avg = sum(df_avg .* (full_H_inv * df_avg)); 911 | % power of the standard error; 912 | ldfs = obj.last_df(:,obj.active) - df_avg*ones(1,sum(obj.active)); 913 | num_active = sum(obj.active); 914 | p_df_sum = sum(sum(ldfs .* (full_H_inv * ldfs))) / num_active / (num_active - 1); 915 | % if the standard errror in the estimated gradient is the same order of magnitude as the gradient; 916 | % we want to increase the size of the active set; 917 | increase_desirable = (p_df_sum >= p_df_avg.*obj.max_gradient_noise); 918 | % increase the active set on step failure; 919 | increase_desirable = increase_desirable || step_failure; 920 | % increase the active set if we've done a full pass without updating it; 921 | increase_desirable = increase_desirable || (obj.iter_since_active_growth > num_active); 922 | % make sure all the subfunctions have enough evaluations for a Hessian approximation; 923 | % before bringing in new subfunctions; 924 | eligibile_for_increase = (min(obj.eval_count(obj.active)) >= 2); 925 | % one more iteration has passed since the active set was last expanded; 926 | obj.iter_since_active_growth = obj.iter_since_active_growth + 1; 927 | if increase_desirable && eligibile_for_increase && sum(~obj.active) > 0 928 | % the index of the new subfunction to activate; 929 | new_gd = find(~obj.active); 930 | new_gd = new_gd(randperm(length(new_gd), 1)); 931 | if ~isempty(new_gd) 932 | obj.iter_since_active_growth = 0; 933 | obj.active(new_gd) = true; 934 | end 935 | end 936 | end 937 | 938 | 939 | function optimization_step(obj) 940 | % Perform a single optimization step. This function is typically called by SFO.optimize().; 941 | 942 | time_pass_start = tic(); 943 | 944 | %% choose an index to update; 945 | indx = obj.get_target_index(); 946 | 947 | if obj.display > 2 948 | fprintf('||dtheta|| %g, ', sqrt(sum((obj.theta - obj.theta_prior_step).^2))); 949 | fprintf('index %d, last f %g, ', indx, obj.hist_f(indx,1)); 950 | fprintf('step scale %g, ', obj.step_scale); 951 | end 952 | if obj.display > 8 953 | C = obj.P.' * obj.P; 954 | eC = eig(C); 955 | fprintf('mne %g, mxe %g, ', min(eC(eC>0)), max(eC)); 956 | end 957 | 958 | % evaluate subfunction value & gradient at new position; 959 | [f, df_proj] = obj.f_df_wrapper(obj.theta, indx); 960 | 961 | % check for a failed update step, & adjust f, df, & obj.theta; 962 | % as appropriate if one occurs.; 963 | [step_failure, f, df_proj] = obj.handle_step_failure(f, df_proj, indx); 964 | 965 | % add the change in theta & the change in gradient to the history for this subfunction; 966 | obj.update_history(indx, obj.theta_proj, f, df_proj); 967 | 968 | % increment the total distance traveled using the last update; 969 | obj.total_distance = obj.total_distance + sqrt(sum((obj.theta - obj.theta_prior_step).^2)); 970 | 971 | % the current contribution from this subfunction to the total Hessian approximation; 972 | H_pre_update = real(obj.b(:,:,indx) * obj.b(:,:,indx).'); 973 | %% update this subfunction's Hessian estimate; 974 | obj.update_hessian(indx); 975 | % the new contribution from this subfunction to the total approximate hessian; 976 | H_new = real(obj.b(:,:,indx) * obj.b(:,:,indx).'); 977 | % update total Hessian using this subfunction's updated contribution; 978 | obj.full_H = obj.full_H + H_new - H_pre_update; 979 | 980 | % calculate the total gradient, total Hessian, & total function value at the current location; 981 | full_df = 0.; 982 | for i = 1:obj.N 983 | dtheta = obj.theta_proj - obj.last_theta(:,i); 984 | bdtheta = obj.b(:,:,i).' * dtheta; 985 | Hdtheta = real(obj.b(:,:,i) * bdtheta); 986 | Hdtheta = Hdtheta + dtheta.*obj.min_eig_sub(i); % the diagonal contribution; 987 | full_df = full_df + Hdtheta + obj.last_df(:,i); 988 | end 989 | full_H_combined = obj.get_full_H_with_diagonal(); 990 | % TODO - Use Woodbury identity instead of recalculating full inverse; 991 | full_H_inv = inv(full_H_combined); 992 | 993 | % calculate an update step; 994 | dtheta_proj = -(full_H_inv) * (full_df) .* obj.step_scale; 995 | 996 | dtheta_proj_length = sqrt(sum(dtheta_proj(:).^2)); 997 | if dtheta_proj_length < obj.minimum_step_length 998 | dtheta_proj = dtheta_proj*obj.minimum_step_length/dtheta_proj_length; 999 | dtheta_proj_length = obj.minimum_step_length; 1000 | if obj.display > 3 1001 | fprintf('forcing minimum step length'); 1002 | end 1003 | end 1004 | if sum(obj.eval_count) > obj.N && dtheta_proj_length > obj.eps 1005 | % only allow a step to be up to a factor of max_step_length_ratio longer than the 1006 | % average step length 1007 | avg_length = obj.total_distance / sum(obj.eval_count); 1008 | length_ratio = dtheta_proj_length / avg_length; 1009 | ratio_scale = obj.max_step_length_ratio; 1010 | if length_ratio > ratio_scale 1011 | if obj.display > 3 1012 | fprintf('truncating step length from %g to %g', dtheta_proj_length, ratio_scale*avg_length); 1013 | end 1014 | dtheta_proj_length = dtheta_proj_length/(length_ratio/ratio_scale); 1015 | dtheta_proj = dtheta_proj/(length_ratio/ratio_scale); 1016 | end 1017 | end 1018 | 1019 | % the update to theta, in the full dimensional space; 1020 | dtheta = (obj.P) * (dtheta_proj); 1021 | 1022 | % backup the prior position, in case this is a failed step; 1023 | obj.theta_prior_step = obj.theta; 1024 | % update theta to the new location; 1025 | obj.theta = obj.theta + dtheta; 1026 | obj.theta_proj = obj.theta_proj + dtheta_proj; 1027 | % the predicted improvement from this update step; 1028 | obj.f_predicted_total_improvement = 0.5 .* dtheta_proj.' * (full_H_combined * dtheta_proj); 1029 | 1030 | %% expand the set of active subfunctions as appropriate; 1031 | obj.expand_active_subfunctions(full_H_inv, step_failure); 1032 | 1033 | % record how much time was taken by this learning step; 1034 | time_diff = toc(time_pass_start); 1035 | obj.time_pass = obj.time_pass + time_diff; 1036 | end 1037 | end 1038 | 1039 | 1040 | methods(Static) 1041 | 1042 | function A = reshape_wrapper(A, shape) 1043 | % a wrapper for reshape which duplicates the numpy behavior, and sets 1044 | % any -1 dimensions to the appropriate length 1045 | 1046 | total_dims = numel(A); 1047 | total_assigned = prod(shape(shape>0)); 1048 | shape(shape==-1) = total_dims/total_assigned; 1049 | A = reshape(A, shape); 1050 | end 1051 | 1052 | 1053 | function [U, V] = eigh_wrapper(A) 1054 | % A wrapper which duplicates the order and format of the numpy 1055 | % eigh routine. (note, eigh further assumes symmetric matrix. don't 1056 | % think there's an equivalent MATLAB function?) 1057 | 1058 | % Note: this function enforces A to be symmetric 1059 | 1060 | [V,U] = eig(0.5 * (A + A')); 1061 | U = diag(U); 1062 | end 1063 | 1064 | end 1065 | end 1066 | -------------------------------------------------------------------------------- /sfo_demo.m: -------------------------------------------------------------------------------- 1 | % Demonstrates usage of the Sum of Functions Optimizer (SFO) MATLAB 2 | % package. See sfo.m and 3 | % https://github.com/Sohl-Dickstein/Sum-of-Functions-Optimizer 4 | % for additional documentation. 5 | % 6 | % Copyright 2014 Jascha Sohl-Dickstein 7 | % 8 | % Licensed under the Apache License, Version 2.0 (the "License"); 9 | % you may not use this file except in compliance with the License. 10 | % You may obtain a copy of the License at 11 | % 12 | % http://www.apache.org/licenses/LICENSE-2.0 13 | % 14 | % Unless required by applicable law or agreed to in writing, software 15 | % distributed under the License is distributed on an "AS IS" BASIS, 16 | % WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | % See the License for the specific language governing permissions and 18 | % limitations under the License. 19 | 20 | function sfo_demo() 21 | % set model and training data parameters 22 | M = 20; % number visible units 23 | J = 10; % number hidden units 24 | D = 100000; % full data batch size 25 | N = floor(sqrt(D)/10.); % number minibatches 26 | % generate random training data 27 | v = randn(M,D); 28 | 29 | % create the cell array of subfunction specific arguments 30 | sub_refs = cell(N,1); 31 | for i = 1:N 32 | % extract a single minibatch of training data. 33 | sub_refs{i} = v(:,i:N:end); 34 | end 35 | 36 | % initialize parameters 37 | % Parameters can be stored as a vector, a matrix, or a cell array with a 38 | % vector or matrix in each cell. Here the parameters are 39 | % {[weight matrix], [hidden bias], [visible bias]}. 40 | theta_init = {randn(J,M), randn(J,1), randn(M,1)}; 41 | % initialize the optimizer 42 | optimizer = sfo(@f_df_autoencoder, theta_init, sub_refs); 43 | % uncomment the following line to test the gradient of f_df 44 | %optimizer.check_grad(); 45 | % run the optimizer for half a pass through the data 46 | theta = optimizer.optimize(0.5); 47 | % run the optimizer for another 20 passes through the data, continuing from 48 | % the theta value where the prior call to optimize() ended 49 | theta = optimizer.optimize(20); 50 | % plot the convergence trace 51 | plot(optimizer.hist_f_flat); 52 | xlabel('Iteration'); 53 | ylabel('Minibatch Function Value'); 54 | title('Convergence Trace'); 55 | end 56 | 57 | % define an objective function and gradient 58 | function [f, dfdtheta] = f_df_autoencoder(theta, v) 59 | % [f, dfdtheta] = f_df_autoencoder(theta, v) 60 | % Calculate L2 reconstruction error and gradient for an autoencoder 61 | % with sigmoid nonlinearity. 62 | % Parameters: 63 | % theta - A cell array containing 64 | % {[weight matrix], [hidden bias], [visible bias]}. 65 | % v - A [# visible, # datapoints] matrix containing training data. 66 | % v will be different for each subfunction. 67 | % Returns: 68 | % f - The L2 reconstruction error for data v and parameters theta. 69 | % df - A cell array containing the gradient of f with each of the 70 | % parameters in theta. 71 | 72 | W = theta{1}; 73 | b_h = theta{2}; 74 | b_v = theta{3}; 75 | 76 | h = 1./(1 + exp(-bsxfun(@plus, W * v, b_h))); 77 | v_hat = bsxfun(@plus, W' * h, b_v); 78 | f = sum(sum((v_hat - v).^2)) / size(v, 2); 79 | dv_hat = 2*(v_hat - v) / size(v, 2); 80 | db_v = sum(dv_hat, 2); 81 | dW = h * dv_hat'; 82 | dh = W * dv_hat; 83 | db_h = sum(dh.*h.*(1-h), 2); 84 | dW = dW + dh.*h.*(1-h) * v'; 85 | % give the gradients the same order as the parameters 86 | dfdtheta = {dW, db_h, db_v}; 87 | end 88 | -------------------------------------------------------------------------------- /sfo_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train an autoencoder using SFO. 3 | 4 | Demonstrates usage of the Sum of Functions Optimizer (SFO) Python 5 | package. See sfo.py and 6 | https://github.com/Sohl-Dickstein/Sum-of-Functions-Optimizer 7 | for additional documentation. 8 | 9 | Copyright 2014 Jascha Sohl-Dickstein 10 | 11 | Licensed under the Apache License, Version 2.0 (the "License"); 12 | you may not use this file except in compliance with the License. 13 | You may obtain a copy of the License at 14 | 15 | http://www.apache.org/licenses/LICENSE-2.0 16 | 17 | Unless required by applicable law or agreed to in writing, software 18 | distributed under the License is distributed on an "AS IS" BASIS, 19 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | See the License for the specific language governing permissions and 21 | limitations under the License. 22 | """ 23 | 24 | import matplotlib.pyplot as plt 25 | import numpy as np 26 | from numpy.random import randn 27 | from sfo import SFO 28 | 29 | # define an objective function and gradient 30 | def f_df(theta, v): 31 | """ 32 | Calculate reconstruction error and gradient for an autoencoder with sigmoid 33 | nonlinearity. 34 | v contains the training data, and will be different for each subfunction. 35 | """ 36 | h = 1./(1. + np.exp(-(np.dot(theta['W'], v) + theta['b_h']))) 37 | v_hat = np.dot(theta['W'].T, h) + theta['b_v'] 38 | f = np.sum((v_hat - v)**2) / v.shape[1] 39 | dv_hat = 2.*(v_hat - v) / v.shape[1] 40 | db_v = np.sum(dv_hat, axis=1).reshape((-1,1)) 41 | dW = np.dot(h, dv_hat.T) 42 | dh = np.dot(theta['W'], dv_hat) 43 | db_h = np.sum(dh*h*(1.-h), axis=1).reshape((-1,1)) 44 | dW += np.dot(dh*h*(1.-h), v.T) 45 | dfdtheta = {'W':dW, 'b_h':db_h, 'b_v':db_v} 46 | return f, dfdtheta 47 | 48 | # set model and training data parameters 49 | M = 20 # number visible units 50 | J = 10 # number hidden units 51 | D = 100000 # full data batch size 52 | N = int(np.sqrt(D)/10.) # number minibatches 53 | # generate random training data 54 | v = randn(M,D) 55 | 56 | # create the array of subfunction specific arguments 57 | sub_refs = [] 58 | for i in range(N): 59 | # extract a single minibatch of training data. 60 | sub_refs.append(v[:,i::N]) 61 | 62 | # initialize parameters 63 | theta_init = {'W':randn(J,M), 'b_h':randn(J,1), 'b_v':randn(M,1)} 64 | # initialize the optimizer 65 | optimizer = SFO(f_df, theta_init, sub_refs) 66 | # # uncomment the following line to test the gradient of f_df 67 | # optimizer.check_grad() 68 | # run the optimizer for 1 pass through the data 69 | theta = optimizer.optimize(num_passes=1) 70 | # continue running the optimizer for another 20 passes through the data 71 | theta = optimizer.optimize(num_passes=20) 72 | 73 | # plot the convergence trace 74 | plt.plot(np.array(optimizer.hist_f_flat)) 75 | plt.xlabel('Iteration') 76 | plt.ylabel('Minibatch Function Value') 77 | plt.title('Convergence Trace') 78 | plt.show() 79 | --------------------------------------------------------------------------------