├── .gitignore
├── README.md
├── __init__.py
├── generate_figures
│   ├── README.md
│   ├── __init__.py
│   ├── adagrad.py
│   ├── convergence_utils.py
│   ├── dropout
│   │   ├── .DS_Store
│   │   ├── .gitignore
│   │   ├── LICENSE
│   │   ├── README
│   │   ├── __init__.py
│   │   ├── deepae.py
│   │   ├── load_data.py
│   │   ├── logistic_sgd.py
│   │   ├── mlp.py
│   │   └── plot_results.sh
│   ├── figure_cartoon.py
│   ├── figure_compare_optimizers.py
│   ├── figure_compare_sfo_L.py
│   ├── figure_compare_sfo_N.py
│   ├── figure_compare_sfo_variations.py
│   ├── figure_overhead.py
│   ├── figures_cae.py
│   ├── models.py
│   ├── nnet
│   │   ├── __init__.py
│   │   ├── conv.yaml
│   │   ├── mnist.yaml
│   │   └── model_gradient.py
│   ├── optimization_wrapper.py
│   ├── sag.py
│   └── utils.py
├── sfo.m
├── sfo.py
├── sfo_demo.m
└── sfo_demo.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.py[cod]
2 |
3 | # C extensions
4 | *.so
5 |
6 | # Packages
7 | *.egg
8 | *.egg-info
9 | dist
10 | build
11 | eggs
12 | parts
13 | bin
14 | var
15 | sdist
16 | develop-eggs
17 | .installed.cfg
18 | lib
19 | lib64
20 | __pycache__
21 |
22 | # Installer logs
23 | pip-log.txt
24 |
25 | # Unit test / coverage reports
26 | .coverage
27 | .tox
28 | nosetests.xml
29 |
30 | # Translations
31 | *.mo
32 |
33 | # Mr Developer
34 | .mr.developer.cfg
35 | .project
36 | .pydevproject
37 |
38 | iterate.dat
39 | *.npz
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Sum of Functions Optimizer (SFO)
2 | ================================
3 |
4 | SFO is a function optimizer for the case where the target function breaks into a sum over minibatches, or a sum over contributing functions. It combines the benefits of both quasi-Newton and stochastic gradient descent techniques, and will likely converge faster and to a better function value than either. It does not require tuning of hyperparameters. It is described in more detail in the paper:
5 | > Jascha Sohl-Dickstein, Ben Poole, and Surya Ganguli
6 | > Fast large-scale optimization by unifying stochastic gradient and quasi-Newton methods
7 | > International Conference on Machine Learning (2014)
8 | > arXiv preprint arXiv:1311.2115 (2013)
9 | > http://arxiv.org/abs/1311.2115
10 |
11 | This repository provides easy to use Python and MATLAB implementations of SFO, as well as functions to exactly reproduce the figures in the paper.
12 |
13 | ## Use SFO
14 |
15 | Simple example code which trains an autoencoder is in **sfo_demo.py** and **sfo_demo.m**, and is reproduced at the end of this README.
16 |
17 | ### Python package
18 |
19 | To use SFO, you should first import SFO,
20 | `from sfo import SFO`
21 | then initialize it,
22 | `optimizer = SFO(f_df, theta_init, subfunction_references)`
23 | then call the optimizer, specifying the number of optimization passes to perform,
24 | `theta = optimizer.optimize(num_passes=1)`.
25 |
26 | The three required initialization parameters are:
27 | - *f_df* - Returns the function value and gradient for a single subfunction
28 | call. Should have the form
29 | `f, dfdtheta = f_df(theta, subfunction_references[idx])`,
30 | where *idx* is the index of a single subfunction.
31 | - *theta_init* - The initial parameters to be used for optimization. *theta_init* can
32 | be either a NumPy array, an array of NumPy arrays, a dictionary of NumPy
33 | arrays, or a nested combination thereof. The gradient returned by *f_df*
34 | should have the same form as *theta_init*.
35 | - *subfunction_references* - A list containing an identifying element for
36 | each subfunction. The elements in this list could be, e.g., NumPy
37 | matrices containing minibatches, or indices identifying the
38 | subfunction, or filenames from which target data should be read.
39 | **If each subfunction corresponds to a minibatch, then the number of
40 | subfunctions should be approximately \[number subfunctions\] = sqrt(\[dataset size\])/10**.
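
The three arguments described above can be illustrated with a small, self-contained sketch. The linear least-squares model, the variable names, and the dataset sizes below are assumptions chosen for illustration and are not part of the SFO package; the sketch only shows an `f_df` whose gradient has the same dictionary form as `theta_init`, and a `subfunction_references` list sized by the sqrt(\[dataset size\])/10 rule of thumb.

```python
import numpy as np

def f_df(theta, ref):
    """Return (value, gradient) for one subfunction: mean squared error of a linear model."""
    X, y = ref  # ref is a single element of subfunction_references
    resid = X @ theta['w'] + theta['b'] - y
    f = np.mean(resid**2)
    # the gradient mirrors the structure of theta (same keys, same shapes)
    dfdtheta = {
        'w': 2.0 * X.T @ resid / len(y),
        'b': np.array([2.0 * np.mean(resid)]),
    }
    return f, dfdtheta

# hypothetical toy dataset
rng = np.random.default_rng(0)
D, M = 10000, 5
X = rng.standard_normal((D, M))
y = X @ np.arange(1.0, M + 1) + 0.1 * rng.standard_normal(D)

# rule of thumb from above: roughly sqrt(D)/10 subfunctions
N = max(1, int(np.sqrt(D) / 10))
subfunction_references = [(X[i::N], y[i::N]) for i in range(N)]
theta_init = {'w': np.zeros(M), 'b': np.zeros(1)}

# evaluate a single subfunction, as the optimizer would
f0, g0 = f_df(theta_init, subfunction_references[0])
```

With these pieces defined, the optimizer would be constructed as `SFO(f_df, theta_init, subfunction_references)`, as described above.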
41 |
42 | More detailed documentation, and additional options, can be found in **sfo.py**.
43 |
44 | ### MATLAB package
45 |
46 | To use SFO you must first initialize the optimizer,
47 | `optimizer = sfo(@f_df, theta_init, subfunction_references, [varargin]);`
48 | then call the optimizer, specifying the number of optimization passes to perform,
49 | `theta = optimizer.optimize(20);`.
50 |
51 | The initialization parameters are:
52 | - *f_df* - Returns the function value and gradient for a single subfunction
53 | call. Should have the form
54 | `[f, dfdtheta] = f_df(theta, subfunction_references{idx}, varargin{:})`,
55 | where *idx* is the index of a single subfunction.
56 | - *theta_init* - The initial parameters to be used for optimization. *theta_init* can
57 | be either a vector, a matrix, or a cell array with a vector or
58 | matrix in every cell. The gradient returned by *f_df* should have the
59 | same form as *theta_init*.
60 | - *subfunction_references* - A cell array containing an identifying element
61 | for each subfunction. The elements in this list could be, e.g.,
62 | matrices containing minibatches, or indices identifying the
63 | subfunction, or filenames from which target data should be read.
64 | **If each subfunction corresponds to a minibatch, then the number of
65 | subfunctions should be approximately \[number subfunctions\] = sqrt(\[dataset size\])/10**.
66 | - *[varargin]* - Any additional parameters, which will be passed through to *f_df* each time
67 | it is called.
68 |
69 | Slightly more documentation can be found in **sfo.m**.
70 |
71 | ### Special situations
72 |
73 | Email jaschasd@google.com with questions if you don't find your answer here.
74 |
75 | #### Reducing overhead
76 |
77 | If too much time is spent inside SFO relative to inside the objective function, then reduce the number of subfunctions by increasing the minibatch size or merging subfunctions.
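
One way to merge subfunctions is sketched below. It assumes each subfunction reference is a `[features, datapoints]` NumPy array, as in the example code in this README; `merge_minibatches` is a hypothetical helper and not part of the SFO package.

```python
import numpy as np

def merge_minibatches(sub_refs, group_size):
    """Concatenate every `group_size` consecutive minibatches along the datapoint axis."""
    return [np.concatenate(sub_refs[i:i + group_size], axis=1)
            for i in range(0, len(sub_refs), group_size)]

v = np.random.randn(20, 1000)
sub_refs = [v[:, i::50] for i in range(50)]  # 50 minibatches of 20 datapoints each
merged = merge_minibatches(sub_refs, 5)      # 10 minibatches of 100 datapoints each
```

The merged list would then be passed to a newly constructed optimizer in place of the original list. If the objective is normalized by the number of datapoints in the minibatch, as in the autoencoder example, its scale is unchanged by merging.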
78 |
79 | #### Replacing minibatches / using SFO in the infinite data limit
80 |
81 | In the Python version of the code, a subfunction or minibatch can be replaced by calling the function `replace_subfunction`. See the documentation in **sfo.py** for more details. Note that replacing a subfunction without calling `replace_subfunction` will cause the optimizer to fail, since SFO relies on subfunctions returning consistent gradients.
82 |
83 | #### Using with dropout
84 |
85 | Stochastic gradients will break SFO, because it uses the change in each minibatch/subfunction gradient between evaluations to estimate the Hessian matrix. The benefits of noise regularization can be achieved without making the gradients stochastic by using frozen noise. That is, in the case of dropout, assign a random dropout mask to each datapoint, and use that same mask every time the datapoint is evaluated. This makes the gradients consistent across multiple evaluations of the minibatch.
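
A minimal sketch of frozen noise follows. The model (a linear map reconstructing the input from a dropout-corrupted copy), the `keep_prob` value, and all names are assumptions made for illustration. The point is only that each mask is drawn once, stored next to its minibatch in the subfunction reference, and reused on every evaluation.

```python
import numpy as np

rng = np.random.default_rng(0)
M, D, N, keep_prob = 20, 1000, 10, 0.5
v = rng.standard_normal((M, D))

# Draw one mask per datapoint once, and store it alongside its minibatch.
sub_refs = []
for i in range(N):
    batch = v[:, i::N]
    mask = rng.random(batch.shape) < keep_prob
    sub_refs.append((batch, mask))

def f_df(theta, ref):
    """Squared reconstruction error of a linear map applied to a corrupted input."""
    batch, mask = ref
    x = batch * mask / keep_prob  # identical corruption on every evaluation
    err = theta @ x - batch
    f = np.sum(err**2) / batch.shape[1]
    dfdtheta = 2.0 * err @ x.T / batch.shape[1]
    return f, dfdtheta
```

Because the mask is part of the subfunction reference rather than being resampled inside `f_df`, two evaluations of the same subfunction at the same `theta` return identical values and gradients.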
86 |
87 | ## Reproduce figures from paper
88 | To reproduce the figures from the paper, run **generate\_figures/figure\_\*.py**. Several of the figures rely on a subdirectory **figure_data/** with training data. This can be downloaded from https://www.dropbox.com/sh/h9z4djlgl2tagmu/GlVAJyErf8 .
89 |
90 | ## Example code
91 |
92 | The following code blocks train an autoencoder using SFO in Python and MATLAB respectively. Identical code is in **sfo_demo.py** and **sfo_demo.m**.
93 |
94 | ### Python example code
95 |
96 | ```python
97 | import matplotlib.pyplot as plt
98 | import numpy as np
99 | from numpy.random import randn
100 | from sfo import SFO
101 |
102 | # define an objective function and gradient
103 | def f_df(theta, v):
104 |     """
105 |     Calculate reconstruction error and gradient for an autoencoder with sigmoid
106 |     nonlinearity.
107 |     v contains the training data, and will be different for each subfunction.
108 |     """
109 |     h = 1./(1. + np.exp(-(np.dot(theta['W'], v) + theta['b_h'])))
110 |     v_hat = np.dot(theta['W'].T, h) + theta['b_v']
111 |     f = np.sum((v_hat - v)**2) / v.shape[1]
112 |     dv_hat = 2.*(v_hat - v) / v.shape[1]
113 |     db_v = np.sum(dv_hat, axis=1).reshape((-1,1))
114 |     dW = np.dot(h, dv_hat.T)
115 |     dh = np.dot(theta['W'], dv_hat)
116 |     db_h = np.sum(dh*h*(1.-h), axis=1).reshape((-1,1))
117 |     dW += np.dot(dh*h*(1.-h), v.T)
118 |     dfdtheta = {'W':dW, 'b_h':db_h, 'b_v':db_v}
119 |     return f, dfdtheta
120 |
121 | # set model and training data parameters
122 | M = 20 # number visible units
123 | J = 10 # number hidden units
124 | D = 100000 # full data batch size
125 | N = int(np.sqrt(D)/10.) # number minibatches
126 | # generate random training data
127 | v = randn(M,D)
128 |
129 | # create the array of subfunction specific arguments
130 | sub_refs = []
131 | for i in range(N):
132 |     # extract a single minibatch of training data.
133 |     sub_refs.append(v[:,i::N])
134 |
135 | # initialize parameters
136 | theta_init = {'W':randn(J,M), 'b_h':randn(J,1), 'b_v':randn(M,1)}
137 | # initialize the optimizer
138 | optimizer = SFO(f_df, theta_init, sub_refs)
139 | # run the optimizer for 1 pass through the data
140 | theta = optimizer.optimize(num_passes=1)
141 | # continue running the optimizer for another 20 passes through the data
142 | theta = optimizer.optimize(num_passes=20)
143 |
144 | # plot the convergence trace
145 | plt.plot(np.array(optimizer.hist_f_flat))
146 | plt.xlabel('Iteration')
147 | plt.ylabel('Minibatch Function Value')
148 | plt.title('Convergence Trace')
149 |
150 | # test the gradient of f_df
151 | optimizer.check_grad()
152 | ```
153 |
154 | ### MATLAB example code
155 |
156 | ```MATLAB
157 | % set model and training data parameters
158 | M = 20; % number visible units
159 | J = 10; % number hidden units
160 | D = 100000; % full data batch size
161 | N = floor(sqrt(D)/10.); % number minibatches
162 | % generate random training data
163 | v = randn(M,D);
164 |
165 | % create the cell array of subfunction specific arguments
166 | sub_refs = cell(N,1);
167 | for i = 1:N
168 | % extract a single minibatch of training data.
169 | sub_refs{i} = v(:,i:N:end);
170 | end
171 |
172 | % initialize parameters
173 | % Parameters can be stored as a vector, a matrix, or a cell array with a
174 | % vector or matrix in each cell. Here the parameters are
175 | % {[weight matrix], [hidden bias], [visible bias]}.
176 | theta_init = {randn(J,M), randn(J,1), randn(M,1)};
177 | % initialize the optimizer
178 | optimizer = sfo(@f_df_autoencoder, theta_init, sub_refs);
179 | % run the optimizer for half a pass through the data
180 | theta = optimizer.optimize(0.5);
181 | % run the optimizer for another 20 passes through the data, continuing from
182 | % the theta value where the prior call to optimize() ended
183 | theta = optimizer.optimize(20);
184 |
185 | % plot the convergence trace
186 | plot(optimizer.hist_f_flat);
187 | xlabel('Iteration');
188 | ylabel('Minibatch Function Value');
189 | title('Convergence Trace');
190 |
191 | % test the gradient of f_df
192 | optimizer.check_grad();
193 | ```
194 |
195 | The subfunction/minibatch objective function and gradient for the MATLAB code is defined as follows,
196 | ```MATLAB
197 | function [f, dfdtheta] = f_df_autoencoder(theta, v)
198 | % [f, dfdtheta] = f_df_autoencoder(theta, v)
199 | % Calculate L2 reconstruction error and gradient for an autoencoder
200 | % with sigmoid nonlinearity.
201 | % Parameters:
202 | % theta - A cell array containing
203 | % {[weight matrix], [hidden bias], [visible bias]}.
204 | % v - A [# visible, # datapoints] matrix containing training data.
205 | % v will be different for each subfunction.
206 | % Returns:
207 | % f - The L2 reconstruction error for data v and parameters theta.
208 | % dfdtheta - A cell array containing the gradient of f with respect to
209 | % each of the parameters in theta.
210 |
211 | W = theta{1};
212 | b_h = theta{2};
213 | b_v = theta{3};
214 |
215 | h = 1./(1 + exp(-bsxfun(@plus, W * v, b_h)));
216 | v_hat = bsxfun(@plus, W' * h, b_v);
217 | f = sum(sum((v_hat - v).^2)) / size(v, 2);
218 | dv_hat = 2*(v_hat - v) / size(v, 2);
219 | db_v = sum(dv_hat, 2);
220 | dW = h * dv_hat';
221 | dh = W * dv_hat;
222 | db_h = sum(dh.*h.*(1-h), 2);
223 | dW = dW + dh.*h.*(1-h) * v';
224 | % give the gradients the same order as the parameters
225 | dfdtheta = {dW, db_h, db_v};
226 | end
227 | ```
228 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sohl-Dickstein/Sum-of-Functions-Optimizer/36c66dfef68b344eca0ca9bb56d0a7bed2075036/__init__.py
--------------------------------------------------------------------------------
/generate_figures/README.md:
--------------------------------------------------------------------------------
1 | Reproducible Science
2 | ================================
3 |
4 | The code in this directory reproduces all of the figures from the paper
5 | > Jascha Sohl-Dickstein, Ben Poole, and Surya Ganguli
6 | > Fast large-scale optimization by unifying stochastic gradient and quasi-Newton methods
7 | > International Conference on Machine Learning (2014)
8 | > arXiv preprint arXiv:1311.2115 (2013)
9 | > http://arxiv.org/abs/1311.2115
10 |
11 | - **figure\_compare_optimizers.py** produces the convergence comparison of different optimizers in Figure 3.
12 | - **figure\_compare_sfo_N.py** produces the convergence comparison for SFO with different numbers of subfunctions or minibatches in Figure 2(c).
13 | - **figure\_compare_sfo_variations.py** produces the convergence comparison for SFO with different design choices in supplemental Figure C.1.
14 | - **figure\_overhead.py** produces the computational overhead analysis in Figure 2(a,b).
15 | - **figure\_cartoon.py** produces the cartoon illustration of the SFO algorithm in Figure 1.
16 |
17 | To include a new objective function in the convergence comparison figure:
18 |
19 | 1. Add a class to **models.py** which provides the objective function and gradient, and initialization, for the new objective. The *toy* class is a good template to modify.
20 | 2. Add the new objective class to *models_to_train* in **figure_compare_optimizers.py**.
21 | 3. (optional) Modify plot characteristics in *make_plot_single_model* in **convergence_utils.py**.
22 | 4. Run **figure_compare_optimizers.py**.
23 |
24 | To include a new optimizer in the convergence comparison figure:
25 |
26 | 1. Add a function implementing the optimizer to **optimization_wrapper.py**. The *SGD* function is a good template to modify.
27 | 2. Add the new optimizer to *optimizers_to_use* in **figure_compare_optimizers.py**.
28 | 3. (optional) Modify plot characteristics in *make_plot_single_model* in **convergence_utils.py**. (for instance, to select only the best performing hyperparameters, add the new optimizer to the list of *best_traces* function calls)
29 | 4. Run **figure_compare_optimizers.py**.
30 |
31 | For more documentation on SFO in general, see the README.md in the parent directory.
32 |
33 | Several of the figures rely on a subdirectory **figure_data/** with training data. This can be downloaded from https://www.dropbox.com/sh/h9z4djlgl2tagmu/GlVAJyErf8 .
34 |
--------------------------------------------------------------------------------
/generate_figures/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sohl-Dickstein/Sum-of-Functions-Optimizer/36c66dfef68b344eca0ca9bb56d0a7bed2075036/generate_figures/__init__.py
--------------------------------------------------------------------------------
/generate_figures/adagrad.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2013 Jascha Sohl-Dickstein, Ben Poole
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | http://www.apache.org/licenses/LICENSE-2.0
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 |
13 | This is an implementation of the ADAGrad algorithm:
14 | Duchi, J., Hazan, E., & Singer, Y. (2010).
15 | Adaptive subgradient methods for online learning and stochastic optimization.
16 | Journal of Machine Learning Research
17 | http://www.eecs.berkeley.edu/Pubs/TechRpts/2010/EECS-2010-24.pdf
18 | """
19 |
20 | import numpy as np
21 |
22 |
23 | class ADAGrad(object):
24 |
25 |     def __init__(self, f_df, theta, subfunction_references, reps=1, learning_rate=0.1, args=(), kwargs=None):
26 |
27 |         self.reps = reps  # number of gradient evaluations averaged per step
28 |         self.learning_rate = learning_rate
29 |
30 |         self.N = len(subfunction_references)
31 |         self.sub_ref = subfunction_references
32 |         self.f_df = f_df
33 |         self.args = args
34 |         self.kwargs = {} if kwargs is None else kwargs  # avoid a shared mutable default
35 |
36 |         self.num_steps = 0
37 |         self.theta = theta.copy().reshape((-1,1))
38 |         self.grad_history = np.zeros_like(self.theta)  # running sum of squared gradients
39 |         self.M = self.theta.shape[0]
40 |
41 |         self.f = np.ones((self.N))*np.nan  # most recent value of each subfunction
42 |
43 |
44 |     def optimize(self, num_passes=10, num_steps=None):
45 |         if num_steps is None:
46 |             num_steps = num_passes*self.N
47 |         for i in range(num_steps):
48 |             if not self.optimization_step():
49 |                 break
50 |         #print 'L ', self.L
51 |         return self.theta
52 |
53 |     def optimization_step(self):
54 |         idx = np.random.randint(self.N)  # pick a subfunction uniformly at random
55 |         gradii = np.zeros_like(self.theta)
56 |         lossii = 0.
57 |         for i in range(self.reps):
58 |             lossi, gradi = self.f_df(self.theta, (self.sub_ref[idx], ), *self.args, **self.kwargs)
59 |             lossii += lossi / self.reps
60 |             gradii += gradi.reshape(gradii.shape) / self.reps
61 |
62 |         self.num_steps += 1
63 |         learning_rates = self.learning_rate / (np.sqrt(1./self.num_steps + self.grad_history))
64 |         learning_rates[np.isinf(learning_rates)] = self.learning_rate
65 |         self.theta -= learning_rates * gradii
66 |         self.grad_history += gradii**2
67 |         self.f[idx] = lossii
68 |
69 |         if not np.isfinite(lossii):
70 |             print("Non-finite subfunction. Ending run.")
71 |             return False
72 |         return True
73 |
74 |
--------------------------------------------------------------------------------
/generate_figures/convergence_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains various shared functions used to generate the convergence
3 | figures in the SFO paper.
4 |
5 | Copyright 2014 Jascha Sohl-Dickstein
6 | Licensed under the Apache License, Version 2.0 (the "License");
7 | you may not use this file except in compliance with the License.
8 | You may obtain a copy of the License at
9 | http://www.apache.org/licenses/LICENSE-2.0
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 |
18 | import matplotlib
19 | matplotlib.use('Agg') # no displayed figures -- need to call before loading pylab
20 | import matplotlib.pyplot as plt
21 |
22 | import datetime
23 | import glob
24 | import numpy as np
25 | import re
26 | import warnings
27 |
28 | from collections import defaultdict
29 | from itertools import cycle
30 |
31 |
32 | def sorted_nicely(strings):
33 | "Sort strings the way humans are said to expect."
34 | return sorted(strings, key=natural_sort_key)
35 |
36 | def natural_sort_key(key):
37 | import re
38 | return [int(t) if t.isdigit() else t for t in re.split(r'(\d+)', key)]
39 |
40 | def best_traces(prop, hist_f, styledict, color, neighbor_dist = 1):
41 | """
42 | Prunes the history to only the best trace in prop, and its
43 | neighbors within neighbor_dist in the sorted order. It sets
44 | only the best trace to a dark line, and makes them all color.
45 | """
46 |
47 | if len(prop)>0:
48 | minnm = prop[0]
49 | prop = sorted_nicely(prop)
50 | # get the best one
51 | minf = np.inf
52 | for nm in prop:
53 | ff = np.asarray(hist_f[nm])
54 | minf2 = ff[-1] #np.min(ff)
55 | print(nm, minf2)
56 | if minf2 < minf:
57 | minnm = nm
58 | minf = minf2
59 | ii = prop.index(minnm)
60 | for i in range(len(prop)):
61 | if np.abs(i-ii) > neighbor_dist:
62 | del hist_f[prop[i]]
63 | for jj in range(1,neighbor_dist+1):
64 | try:
65 | styledict[prop[ii-jj]] = {'color':color, 'ls':':'}
66 | except:
67 | print("failure around", prop[0])
68 | try:
69 | styledict[prop[ii+jj]] = {'color':color, 'ls':'-.'}
70 | except:
71 | print("failure around", prop[-1])
72 | styledict[prop[ii]] = {'color':color, 'ls':'-', 'linewidth':4}
73 |
74 | def make_plot_single_model(hist_f, hist_x_projection, hist_events, model_name,
75 | num_subfunctions, full_objective_period, display_events=False,
76 | display_trajectory=False, figsize=(4.5,3.5), external_legend=True, name_prefix=""):
77 | """
78 | Plot the different optimizers against each other for a single
79 | model.
80 | """
81 |
82 | # xlabel was getting cut off, this seems to fix?
83 | from matplotlib import rcParams
84 | rcParams.update({'figure.autolayout': True})
85 |
86 | # set the title
87 | title = model_name
88 | if model_name == 'protein' or model_name == 'protein logistic regression':
89 | title = 'Logistic Regression, Protein Dataset'
90 | elif model_name == 'Hopfield':
91 | title = 'Ising / Hopfield with MPF Objective'
92 | elif 'hard' in model_name:
93 | title = 'Multi-Layer Perceptron, Rectified Linear'
94 | elif 'soft' in model_name:
95 | title = 'Multi-Layer Perceptron, Sigmoid'
96 | elif model_name == 'sumf':
97 | title = 'Sum of Norms'
98 | elif model_name == 'Pylearn_conv':
99 | title = 'Convolutional Network, CIFAR-10'
100 | elif model_name == 'GLM':
101 | title = 'GLM, soft rectifying nonlinearity'
102 | model_name = 'GLM_soft'
103 |
104 | # set up linestyle cycler
105 | styles = [
106 | {'color':'r', 'ls':'--'},
107 | {'color':'g', 'ls':'-'},
108 | {'color':'b', 'ls':'-.'},
109 | {'color':'k', 'ls':':'},
110 | {'color':'y', 'ls':'--'},
111 | {'color':'r', 'ls':'-'},
112 | {'color':'g', 'ls':'-.'},
113 | {'color':'b', 'ls':':'},
114 | {'color':'k', 'ls':'--'},
115 | {'color':'y', 'ls':'-'},
116 | {'color':'r', 'ls':'-.'},
117 | {'color':'g', 'ls':':'},
118 | {'color':'b', 'ls':'--'},
119 | {'color':'k', 'ls':'-'},
120 | {'color':'y', 'ls':'-.'},
121 | ]
122 | stylecycler = cycle(styles)
123 | styledict = defaultdict(lambda: next(stylecycler))
124 |
125 | zorder = dict()
126 | sorted_nm = sorted(hist_f.keys(), key=lambda nm: np.asarray(hist_f[nm])[-1], reverse=True)
127 | for ii, nm in enumerate(sorted_nm):
128 | zorder[nm] = ii
129 |
130 | ## override the default cycled styles for specific optimizers
131 | # LBFGS
132 | prop = [nm for nm in hist_f.keys() if 'LBFGS' in nm]
133 | for nm in prop:
134 | styledict[nm] = {'color':'r', 'ls':'-', 'linewidth':3}
135 | if 'batch' in nm:
136 | styledict[nm]['ls'] = ':'
137 | # SAG
138 | prop = [nm for nm in hist_f.keys() if 'SAG' in nm]
139 | best_traces(prop, hist_f, styledict, 'g')
140 | # SGD
141 | prop = [nm for nm in hist_f.keys() if 'SGD' in nm and 'momentum' not in nm]
142 | best_traces(prop, hist_f, styledict, 'c')
143 | # SGD_momentum
144 | prop = [nm for nm in hist_f.keys() if 'SGD' in nm and 'momentum' in nm]
145 | best_traces(prop, hist_f, styledict, 'm', neighbor_dist=0)
146 | # ADA
147 | prop = [nm for nm in hist_f.keys() if 'ADA' in nm]
148 | best_traces(prop, hist_f, styledict, 'b')
149 | # SFO
150 | prop = [nm for nm in hist_f.keys() if nm == 'SFO' or nm == 'SFO standard']
151 | for nm in prop:
152 | styledict[nm] = {'color':'k', 'ls':'-', 'linewidth':4}
153 | # SFO number minibatches
154 | prop = [nm for nm in hist_f.keys() if 'SFO' in nm and 'N=' in nm]
155 | nprop = len(prop)
156 | for ii, nm in enumerate(sorted_nicely(prop)):
157 | #styledict[nm] = {'color':'k', 'dashes':(7,nprop/(ii+1.),), 'linewidth':5.*(ii+1.)/nprop}
158 | styledict[nm] = {'color':(1. - ii/(nprop-1.), 0., 0.), 'ls':'-', 'linewidth':4.*(ii+1.)/nprop}
159 | zorder[nm] = ii
160 | # SFO number history terms
161 | prop = [nm for nm in hist_f.keys() if 'SFO' in nm and 'L=' in nm]
162 | nprop = len(prop)
163 | for ii, nm in enumerate(sorted_nicely(prop)):
164 | #styledict[nm] = {'color':'k', 'dashes':(7,nprop/(ii+1.),), 'linewidth':5.*(ii+1.)/nprop}
165 | styledict[nm] = {'color':(1. - ii/(nprop-1.), 0., 0.), 'ls':'-', 'linewidth':4.*(ii+1.)/nprop}
166 | zorder[nm] = ii
167 |
168 | # plot the learning trace, and save pdf
169 | fig = plt.figure(figsize=figsize)
170 |
171 | lines = []
172 | labels = []
173 |
174 | fewest_passes = 0.
175 |
176 | for nm in sorted_nicely(hist_f.keys()):
177 | ff = np.asarray(hist_f[nm])
178 | ff[ff>1.5*ff[0]] = np.nan
179 | xx = np.arange(1, len(ff)+1).astype(float)*full_objective_period/num_subfunctions
180 | if max(np.max(xx), fewest_passes) > 2*min(np.max(xx), fewest_passes) or nm == 'LBFGS':
181 | # ignore cases that were terminated early eg because of bad learning rate
182 | # also, sometimes LBFGS terminates early, so don't judge based on that
183 | fewest_passes = max(np.max(xx), fewest_passes)
184 | else:
185 | fewest_passes = min(np.max(xx), fewest_passes)
186 | #assert(fewest_passes > 4.)
187 | line = plt.semilogy( xx, ff, label=nm, zorder=zorder[nm], **styledict[nm] )
188 | lines.append(line[0])
189 | labels.append(nm)
190 | if display_events:
191 | # add special events
192 | for jj, events in enumerate(hist_events[nm]):
193 | st = {'s':100}
194 | if events.get('natural gradient subspace update', False):
195 | st['marker'] = '>'
196 | st['c'] = 'r'
197 | elif events.get('collapse subspace', False):
198 | st['marker'] = '*'
199 | st['c'] = 'y'
200 | elif events.get('step_failure', False):
201 | st['marker'] = '<'
202 | st['c'] = 'c'
203 | else:
204 | continue
205 | plt.scatter(xx[jj], ff[jj], **st)
206 |
207 |
208 | plt.ylabel( 'Full Batch Objective' )
209 | plt.xlabel( 'Effective Passes Through Data' )
210 | plt.title(title)
211 | plt.grid()
212 | plt.axes().set_axisbelow(True)
213 |
214 | ax = plt.axis()
215 | plt.axis([0, fewest_passes, ax[2], ax[3]])
216 | ax = plt.axis()
217 | if "Autoencoder" in title:
218 | plt.yticks(np.arange(10, 46, 5.0), ["%d"%tt for tt in np.arange(10, 46, 5.0)])
219 | plt.axis([ax[0], ax[1], 13, 30])
220 | # elif "ICA" in title:
221 | # plt.axis([ax[0], ax[1], 0.23, 140])
222 | elif "hard" in title:
223 | plt.axis([ax[0], ax[1], 1e-13, 1e1])
224 | elif "Perceptron" in title:
225 | plt.axis([ax[0], ax[1], 1e-7, 1e1])
226 | elif "ICA" in title:
227 | plt.axis([ax[0], ax[1], 1e0, 1e3])
228 | try:
229 | plt.tight_layout()
230 | except:
231 | warnings.warn('tight_layout failed. try running with an Agg backend.')
232 |
233 | # update the labels to prettier text
234 | labels = map(lambda x: str.replace(x, "SGD_momentum ", r"SGD+mom "), labels)
235 | labels = map(lambda x: str.replace(x, "SGD ", r"SGD eta="), labels)
236 | labels = map(lambda x: str.replace(x, "ADAGrad ", r"ADAGrad eta="), labels)
237 | labels = map(lambda x: str.replace(x, "SAG ", r"SAG L="), labels)
238 | labels = map(lambda x: str.replace(x, "L=", r"$L=$"), labels)
239 | labels = map(lambda x: str.replace(x, "eta=", r"$\eta=$"), labels)
240 | labels = map(lambda x: str.replace(x, "mu=", r"$\mu=$"), labels)
241 | def number_shaver(ch, regx = re.compile('(?1.5*ff[0]] = np.nan
279 | xx = np.arange(1, len(ff)+1).astype(float)*full_objective_period/num_subfunctions
280 | plt.semilogy( xx, ff-minf, label=nm, zorder=zorder[nm], **styledict[nm] )
281 | plt.ylabel( 'Full Batch Objective - Minimum' )
282 | plt.xlabel( 'Effective Passes Through Data' )
283 | if not external_legend:
284 | plt.legend( loc='best' )
285 | plt.title(title)
286 | plt.grid()
287 | plt.axes().set_axisbelow(True)
288 | ax = plt.axis()
289 | plt.axis([0, fewest_passes, ax[2], ax[3]])
290 | try:
291 | plt.tight_layout()
292 | except:
293 | warnings.warn('tight_layout failed. try running with an Agg backend.')
294 | fig.savefig(('figure_' + name_prefix + model_name + '_diff.pdf').replace(' ', '-'))
295 |
296 | if display_trajectory:
297 | # SAG
298 | prop = [nm for nm in hist_f.keys() if 'SAG' in nm]
299 | best_traces(prop, hist_f, styledict, 'g', neighbor_dist=0)
300 | # SGD
301 | prop = [nm for nm in hist_f.keys() if 'SGD' in nm]
302 | best_traces(prop, hist_f, styledict, 'c', neighbor_dist=0)
303 | # ADA
304 | prop = [nm for nm in hist_f.keys() if 'ADA' in nm]
305 | best_traces(prop, hist_f, styledict, 'b', neighbor_dist=0)
306 |
307 | # plot the learning trajectory in low-d projections, and save pdf
308 | # make the line styles appropriate
309 | for nm in hist_f.keys():
310 | styledict[nm].pop('ls', None)
311 | styledict[nm].pop('linewidth', None)
312 | #styledict[nm]['linestyle'] = 'None'
313 | styledict[nm]['marker'] = '.' # ','
314 | styledict[nm]['alpha'] = 0.5 # ','
315 | fig = plt.figure(figsize=figsize)
316 | nproj = 3 # could make larger
317 | for i1 in range(nproj):
318 | for i2 in range(nproj):
319 | plt.subplot(nproj, nproj, i1 + nproj*i2 + 1)
320 | for nm in sorted_nicely(hist_f.keys()):
321 | xp = np.asarray(hist_x_projection[nm])
322 | plt.plot( xp[:,i1], xp[:,i2], label=nm, zorder=zorder[nm], **styledict[nm] )
323 | if display_events:
324 | # add special events
325 | for jj, events in enumerate(hist_events[nm]):
326 | st = {'s':100}
327 | if events.get('natural gradient subspace update', False):
328 | st['marker'] = '>'
329 | st['c'] = 'r'
330 | elif events.get('collapse subspace', False):
331 | st['marker'] = '*'
332 | st['c'] = 'y'
333 | elif events.get('step_failure', False):
334 | st['marker'] = '<'
335 | st['c'] = 'c'
336 | else:
337 | continue
338 | plt.scatter(xp[jj,i1], xp[jj,i2], **st)
339 | if not external_legend:
340 | plt.legend( loc='best' )
341 | plt.suptitle(title)
342 | try:
343 | plt.tight_layout()
344 | except:
345 | warnings.warn('tight_layout failed. try running with an Agg backend.')
346 | fig.savefig(('figure_' + name_prefix + model_name + '_trajectory.pdf').replace(' ', '-'))
347 |
348 |
349 | def make_plots(history_nested, *args, **kwargs):
350 | for model_name in history_nested:
351 | history = history_nested[model_name]
352 | fig = make_plot_single_model(history['f'], history['x_projection'], history['events'], model_name, *args, **kwargs)
353 |
354 | def load_results(fnames=None, base_fname='figure_data_'):
355 |     """
356 |     Load the function value traces during optimization for the
357 |     set of models and optimizers provided by fnames. Find all
358 |     files with matching filenames in the current directory if
359 |     fnames not passed in.
360 |
361 |     Output is a dictionary of dictionaries, where the inner dictionary
362 |     contains different optimizers for each model, and the outer dictionary
363 |     contains different models.
364 |
365 |     Note that the files loaded can be as granular as a single optimizer
366 |     for a single model for each file.
367 |     """
368 |
369 |     if fnames is None:
370 |         fnames = glob.glob(base_fname + '*.npz')
371 |
372 |     num_subfunctions = None
373 |     full_objective_period = None
374 |
375 |     history_nested = {}
376 |     for fn in fnames:
377 |         data = np.load(fn)
378 |         if num_subfunctions is None:
379 |             num_subfunctions = data['num_subfunctions']
380 |             full_objective_period = data['full_objective_period']
381 |         if not (num_subfunctions == data['num_subfunctions'] and full_objective_period == data['full_objective_period']):
382 |             print("****************")
383 |             print("WARNING: mixing data with different numbers of subfunctions or delays between evaluating the full objective")
384 |             print("make sure you are doing this intentionally (eg, for the convergence vs., number subfunctions plot)")
385 |             print("****************")
386 |         model_name = data['model_name'].tostring()
387 |         print("loading", model_name)
388 |         if model_name not in history_nested:
389 |             history_nested[model_name] = data['history'][()].copy()
390 |         else:
391 |             print("updating")
392 |             for subkey in history_nested[model_name].keys():
393 |                 print(subkey)
394 |                 history_nested[model_name][subkey].update(data['history'][()].copy()[subkey])
395 |         data.close()
396 |
397 |     return history_nested, num_subfunctions, full_objective_period
398 |
399 | def save_results(trainer, base_fname='figure_data_', store_x=True):
400 | """
401 | Save the function trace for different optimizers for a
402 | given model to a .npz file.
403 | """
404 | if not store_x:
405 | # delete the saved final x value so we don't run out of memory
406 | trainer.history['x'] = defaultdict(list)
407 |
408 | fname = base_fname + trainer.model.name + ".npz"
409 | np.savez(fname, history=trainer.history, model_name=trainer.model.name,
410 | num_subfunctions = trainer.num_subfunctions,
411 | full_objective_period=trainer.full_objective_period)
412 |
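A minimal sketch of the save/load round trip performed by save_results and load_results above, assuming a recent NumPy on Python 3, where np.load needs allow_pickle=True to restore the pickled history dictionary. The helper name roundtrip_history and the temporary directory are illustrative and not part of the repository.

```python
import os
import tempfile

import numpy as np


def roundtrip_history(history, model_name):
    """Save a nested history dict to a .npz file and read it back.

    Hypothetical helper for illustration only; it mirrors the
    np.savez / data['history'][()] pattern used in utils.py.
    """
    fname = os.path.join(tempfile.mkdtemp(),
                         'figure_data_' + model_name + '.npz')
    # A dict is stored as a 0-d object array, which NumPy pickles.
    np.savez(fname, history=history, model_name=model_name)

    data = np.load(fname, allow_pickle=True)
    try:
        # Indexing a 0-d array with () returns the wrapped Python object.
        loaded = data['history'][()]
        name = data['model_name'].item()
    finally:
        data.close()
    return loaded, name
```

Indexing with `[()]` is what turns the 0-d object array back into the original dictionary, which is why load_results can call `.copy()` and `.update()` on the result.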
--------------------------------------------------------------------------------
/generate_figures/dropout/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sohl-Dickstein/Sum-of-Functions-Optimizer/36c66dfef68b344eca0ca9bb56d0a7bed2075036/generate_figures/dropout/.DS_Store
--------------------------------------------------------------------------------
/generate_figures/dropout/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | *.txt
3 | *.png
4 | *.swp
5 | *.pyc
6 |
--------------------------------------------------------------------------------
/generate_figures/dropout/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2012 Misha Denil
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | this software and associated documentation files (the "Software"), to deal in
5 | the Software without restriction, including without limitation the rights to
6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
7 | of the Software, and to permit persons to whom the Software is furnished to do
8 | so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 |
21 |
--------------------------------------------------------------------------------
/generate_figures/dropout/README:
--------------------------------------------------------------------------------
1 | Theano implementation of dropout. See: http://arxiv.org/abs/1207.0580
2 |
3 | Run with:
4 |
5 | ./mlp.py dropout
6 |
7 | for dropout, or
8 |
9 | ./mlp.py backprop
10 |
11 | for regular backprop with no dropout.
12 |
13 | Use:
14 |
15 | ./plot_results.sh results.png
16 |
17 | to visualize the results.
18 |
19 | Based on code from:
20 | - http://deeplearning.net/tutorial/mlp.html
21 | - http://deeplearning.net/tutorial/logreg.html
22 |
23 | Use the data here to make the units of the results comparable to Hinton's paper:
24 | - http://www.cs.ubc.ca/~mdenil/hidden/mnist_batches.npz
25 |
26 |
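The dropout operation this README refers to can be sketched in plain NumPy. This is an illustrative analogue of the masking done in mlp.py, not the Theano code itself; the function name dropout and its arguments are assumptions made for the example.

```python
import numpy as np


def dropout(activations, p, rng):
    """Zero each unit independently with probability p.

    Illustrative NumPy version: the mask holds 1 for units that are
    kept, so it is drawn with success probability 1 - p.
    """
    mask = rng.binomial(n=1, p=1.0 - p, size=activations.shape)
    # Cast the integer mask so the product keeps the input dtype.
    return activations * mask.astype(activations.dtype)
```

With p=0.5 roughly half of the units are zeroed on each call, which is why the non-dropout path in mlp.py reuses the same weights scaled by 0.5.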
--------------------------------------------------------------------------------
/generate_figures/dropout/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sohl-Dickstein/Sum-of-Functions-Optimizer/36c66dfef68b344eca0ca9bb56d0a7bed2075036/generate_figures/dropout/__init__.py
--------------------------------------------------------------------------------
/generate_figures/dropout/deepae.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sys
4 | import time
5 |
6 | import theano
7 | import theano.tensor as T
8 | from theano.ifelse import ifelse
9 | import theano.printing
10 | import theano.tensor.shared_randomstreams
11 |
12 | from logistic_sgd import LogisticRegression
13 | from load_data import load_mnist
14 |
15 | class HiddenLayer(object):
16 | def __init__(self, rng, input, n_in, n_out,
17 | activation, W=None, b=None,
18 | use_bias=False, sparse_init=15):
19 |
20 | self.input = input
21 | self.activation = activation
22 |
23 | if W is None:
24 | if sparse_init is not None:
25 | W_values = np.zeros((n_in, n_out), dtype=theano.config.floatX)
26 | for i in xrange(n_out):
27 | for j in xrange(sparse_init):
28 | idx = rng.randint(0, n_in)
29 | # don't worry about getting exactly sparse_init nonzeros
30 | #while W_values[idx, i] != 0.:
31 | # idx = rng.randint(0, n_in)
32 | W_values[idx, i] = rng.randn()/np.sqrt(sparse_init) * 1.2
33 | else:
34 | W_values = 1. * np.asarray(rng.standard_normal(
35 | size=(n_in, n_out)), dtype=theano.config.floatX) / np.float(n_in)**0.5
36 | W = theano.shared(value=W_values, name='W')
37 | if b is None:
38 | b_values = np.zeros((n_out,), dtype=theano.config.floatX)
39 | b = theano.shared(value=b_values, name='b')
40 | self.W = W
41 | self.b = b
42 |
43 | if use_bias:
44 | lin_output = T.dot(input, self.W) + self.b
45 | else:
46 | lin_output = T.dot(input, self.W)
47 |
48 | self.output = (lin_output if activation is None else activation(lin_output))
49 | # parameters of the model
50 | if use_bias:
51 | self.params = [self.W, self.b]
52 | else:
53 | self.params = [self.W]
54 |
55 |
56 | class DeepAE(object):
57 | """
58 | Deep Autoencoder.
59 | """
60 | def __init__(self,
61 | rng,
62 | input,
63 | layer_sizes,
64 | use_bias=True, objective='crossent'):
65 | # Activation functions are all sigmoid except
66 | # "coding" layer that is linear.
67 | # TODO(ben): Make sure layer after coding layer is sigmoid.
68 | layer_acts = [T.nnet.sigmoid for l in layer_sizes[1:-1]]
69 | layer_acts = layer_acts + [None, T.nnet.sigmoid] + layer_acts
70 | # use for untied weights
71 | #layer_acts = layer_acts + [None, None] + layer_acts
72 | layer_sizes = layer_sizes + layer_sizes[:-1][::-1]
73 | print 'Layer sizes:',layer_sizes
74 | print 'Layer acts:',layer_acts
75 |
76 | # Set up all the hidden layers
77 | weight_matrix_sizes = zip(layer_sizes, layer_sizes[1:])
78 | self.layers = []
79 | next_layer_input = input
80 | idx = 0
81 | for n_in, n_out in weight_matrix_sizes:
82 | print (n_in,n_out), layer_acts[idx]
83 | #if idx == 1:
84 | # W_in = next_layer.W.T
85 | #else:
86 | # W_in = None
87 | next_layer = HiddenLayer(rng=rng,
88 | input=next_layer_input,
89 | activation=layer_acts[idx],
90 | n_in=n_in, n_out=n_out,
91 | use_bias=use_bias)
92 | self.layers.append(next_layer)
93 | next_layer_input = next_layer.output
94 | idx += 1
95 | xpred = next_layer_input
96 | # Compute cost function, making sure to sum across data examples
97 | # so that we can properly average across minibatches.
98 | if objective == 'crossent':
99 | self.cost = (-input * T.log(xpred) - (1 - input) * T.log(1 - xpred)).sum(axis=1).sum()
100 | else:
101 | self.cost = ((input-xpred)**2).sum(axis=1).sum()
102 | # Grab all the parameters together.
103 | #XXX
104 | self.params = [ param for layer in self.layers for param in layer.params ]
105 | #self.params = [self.layers[0].W]
106 |
107 | def convert_variable(x):
108 | if x.ndim == 1:
109 | return T.vector(x.name, dtype=x.dtype)
110 | else:
111 | return T.matrix(x.name, dtype=x.dtype)
112 |
113 | def build_f_df(layer_sizes, dropout=False, **kwargs):
114 | print '... building the model'
115 | x = T.matrix('x') # the data is presented as rasterized images
116 | y = T.ivector('y') # the labels are presented as 1D vector of
117 | # [int] labels
118 | # Make sure initialization is repeatable
119 | rng = np.random.RandomState(1234)
120 | # construct the DeepAE class
121 | dae = DeepAE(rng=rng, input=x,
122 | layer_sizes=layer_sizes, **kwargs)
123 | # Build the expression for the cost function.
124 | cost = dae.cost
125 | gparams = []
126 | for param in dae.params:
127 | gparam = T.grad(cost, param)
128 | gparams.append(gparam)
129 | symbolic_params = [convert_variable(param) for param in dae.params]
130 | givens = dict(zip(dae.params, symbolic_params))
131 | f_df = theano.function(inputs=symbolic_params + [x], outputs=[cost] + gparams, givens=givens)
132 | return f_df, dae
133 |
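A small NumPy sketch of two pieces of the DeepAE constructor above: how the encoder layer sizes are mirrored to form the decoder, and the summed cross-entropy reconstruction cost. The function names are hypothetical and the code is an analogue of the symbolic Theano expressions, not a replacement for them.

```python
import numpy as np


def mirrored_layer_sizes(layer_sizes):
    """Append the reversed encoder sizes (minus the code layer)."""
    return layer_sizes + layer_sizes[:-1][::-1]


def summed_cross_entropy(x, xpred):
    """Cross-entropy summed over features and over examples.

    Summing (rather than averaging) matches the cost in DeepAE, so the
    caller can normalize across minibatches afterwards.
    """
    per_element = -x * np.log(xpred) - (1 - x) * np.log(1 - xpred)
    return float(per_element.sum(axis=1).sum())
```

For example, `mirrored_layer_sizes([784, 500, 30])` gives a 784-500-30-500-784 autoencoder, with the 30-unit layer as the linear coding layer.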
--------------------------------------------------------------------------------
/generate_figures/dropout/load_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cPickle
3 | import gzip
4 | import os
5 | import sys
6 |
7 | import theano
8 | import theano.tensor as T
9 |
10 | def _shared_dataset(data_xy):
11 | """ Function that loads the dataset into shared variables
12 |
13 | The reason we store our dataset in shared variables is to allow
14 | Theano to copy it into the GPU memory (when code is run on GPU).
15 | Since copying data into the GPU is slow, copying a minibatch every time
16 | it is needed (the default behaviour if the data is not in a shared
17 | variable) would lead to a large decrease in performance.
18 | """
19 | data_x, data_y = data_xy
20 | shared_x = theano.shared(np.asarray(data_x,
21 | dtype=theano.config.floatX))
22 | shared_y = theano.shared(np.asarray(data_y,
23 | dtype=theano.config.floatX))
24 | # When storing data on the GPU it has to be stored as floats
25 | # therefore we will store the labels as ``floatX`` as well
26 | # (``shared_y`` does exactly that). But during our computations
27 | # we need them as ints (we use labels as index, and if they are
28 | # floats it doesn't make sense) therefore instead of returning
29 | # ``shared_y`` we will have to cast it to int. This little hack
30 | # lets us get around this issue
31 | return shared_x, T.cast(shared_y, 'int32')
32 |
33 | def load_mnist(path):
34 | mnist = np.load(path)
35 | train_set_x = mnist['train_data']
36 | train_set_y = mnist['train_labels']
37 | test_set_x = mnist['test_data']
38 | test_set_y = mnist['test_labels']
39 |
40 | train_set_x, train_set_y = _shared_dataset((train_set_x, train_set_y))
41 | test_set_x, test_set_y = _shared_dataset((test_set_x, test_set_y))
42 | valid_set_x, valid_set_y = test_set_x, test_set_y
43 |
44 | rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y),
45 | (test_set_x, test_set_y)]
46 | return rval
47 |
48 |
49 | def load_umontreal_data(dataset):
50 | ''' Loads the dataset
51 |
52 | :type dataset: string
53 | :param dataset: the path to the dataset (here MNIST)
54 | '''
55 |
56 | #############
57 | # LOAD DATA #
58 | #############
59 |
60 | # Download the MNIST dataset if it is not present
61 | data_dir, data_file = os.path.split(dataset)
62 | if (not os.path.isfile(dataset)) and data_file == 'mnist.pkl.gz':
63 | import urllib
64 | origin = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'
65 | print 'Downloading data from %s' % origin
66 | urllib.urlretrieve(origin, dataset)
67 |
68 | print '... loading data'
69 |
70 | # Load the dataset
71 | f = gzip.open(dataset, 'rb')
72 | train_set, valid_set, test_set = cPickle.load(f)
73 | f.close()
74 | #train_set, valid_set, test_set format: tuple(input, target)
75 | #input is an np.ndarray of 2 dimensions (a matrix)
76 | #whose rows each correspond to an example. target is an
77 | #np.ndarray of 1 dimension (a vector) that has the same length as
78 | #the number of rows in the input. It gives the
79 | #target for the example with the same index in the input.
80 |
81 | def _shared_dataset(data_xy):
82 | """ Function that loads the dataset into shared variables
83 |
84 | The reason we store our dataset in shared variables is to allow
85 | Theano to copy it into the GPU memory (when code is run on GPU).
86 | Since copying data into the GPU is slow, copying a minibatch every time
87 | it is needed (the default behaviour if the data is not in a shared
88 | variable) would lead to a large decrease in performance.
89 | """
90 | data_x, data_y = data_xy
91 | shared_x = theano.shared(np.asarray(data_x,
92 | dtype=theano.config.floatX))
93 | shared_y = theano.shared(np.asarray(data_y,
94 | dtype=theano.config.floatX))
95 | # When storing data on the GPU it has to be stored as floats
96 | # therefore we will store the labels as ``floatX`` as well
97 | # (``shared_y`` does exactly that). But during our computations
98 | # we need them as ints (we use labels as index, and if they are
99 | # floats it doesn't make sense) therefore instead of returning
100 | # ``shared_y`` we will have to cast it to int. This little hack
101 | # lets us get around this issue
102 | return shared_x, T.cast(shared_y, 'int32')
103 |
104 | test_set_x, test_set_y = _shared_dataset(test_set)
105 | valid_set_x, valid_set_y = _shared_dataset(valid_set)
106 | train_set_x, train_set_y = _shared_dataset(train_set)
107 |
108 | rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y),
109 | (test_set_x, test_set_y)]
110 | return rval
111 |
112 |
--------------------------------------------------------------------------------
/generate_figures/dropout/logistic_sgd.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import theano
4 | import theano.tensor as T
5 |
6 | class LogisticRegression(object):
7 | """Multi-class Logistic Regression Class
8 |
9 | The logistic regression is fully described by a weight matrix :math:`W`
10 | and bias vector :math:`b`. Classification is done by projecting data
11 | points onto a set of hyperplanes, the distance to which is used to
12 | determine a class membership probability.
13 | """
14 |
15 | def __init__(self, input, n_in, n_out, W=None, b=None):
16 | """ Initialize the parameters of the logistic regression
17 |
18 | :type input: theano.tensor.TensorType
19 | :param input: symbolic variable that describes the input of the
20 | architecture (one minibatch)
21 |
22 | :type n_in: int
23 | :param n_in: number of input units, the dimension of the space in
24 | which the datapoints lie
25 |
26 | :type n_out: int
27 | :param n_out: number of output units, the dimension of the space in
28 | which the labels lie
29 |
30 | """
31 |
32 | # initialize with 0 the weights W as a matrix of shape (n_in, n_out)
33 | if W is None:
34 | self.W = theano.shared(
35 | value=np.zeros((n_in, n_out), dtype=theano.config.floatX),
36 | name='W')
37 | else:
38 | self.W = W
39 |
40 | # initialize the biases b as a vector of n_out 0s
41 | if b is None:
42 | self.b = theano.shared(
43 | value=np.zeros((n_out,), dtype=theano.config.floatX),
44 | name='b')
45 | else:
46 | self.b = b
47 |
48 | # compute vector of class-membership probabilities in symbolic form
49 | self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b)
50 |
51 | # compute prediction as class whose probability is maximal in
52 | # symbolic form
53 | self.y_pred = T.argmax(self.p_y_given_x, axis=1)
54 |
55 | # parameters of the model
56 | self.params = [self.W, self.b]
57 |
58 | def negative_log_likelihood(self, y):
59 | """Return the mean of the negative log-likelihood of the prediction
60 | of this model under a given target distribution.
61 |
62 | .. math::
63 |
64 | \frac{1}{|\mathcal{D}|} \mathcal{L} (\theta=\{W,b\}, \mathcal{D}) =
65 | \frac{1}{|\mathcal{D}|} \sum_{i=0}^{|\mathcal{D}|} \log(P(Y=y^{(i)}|x^{(i)}, W,b)) \\
66 | \ell (\theta=\{W,b\}, \mathcal{D})
67 |
68 | :type y: theano.tensor.TensorType
69 | :param y: corresponds to a vector that gives for each example the
70 | correct label
71 |
72 | Note: we use the mean instead of the sum so that
73 | the learning rate is less dependent on the batch size
74 | """
75 | # y.shape[0] is (symbolically) the number of rows in y, i.e.,
76 | # number of examples (call it n) in the minibatch
77 | # T.arange(y.shape[0]) is a symbolic vector which will contain
78 | # [0,1,2,... n-1] T.log(self.p_y_given_x) is a matrix of
79 | # Log-Probabilities (call it LP) with one row per example and
80 | # one column per class LP[T.arange(y.shape[0]),y] is a vector
81 | # v containing [LP[0,y[0]], LP[1,y[1]], LP[2,y[2]], ...,
82 | # LP[n-1,y[n-1]]] and T.mean(LP[T.arange(y.shape[0]),y]) is
83 | # the mean (across minibatch examples) of the elements in v,
84 | # i.e., the mean log-likelihood across the minibatch.
85 | return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y])
86 |
87 | def errors(self, y):
88 | """Return the number of errors in the minibatch, i.e. the
89 | zero-one loss summed over the minibatch
90 |
91 | :type y: theano.tensor.TensorType
92 | :param y: corresponds to a vector that gives for each example the
93 | correct label
94 | """
95 |
96 | # check if y has same dimension of y_pred
97 | if y.ndim != self.y_pred.ndim:
98 | raise TypeError('y should have the same shape as self.y_pred',
99 | ('y', y.type, 'y_pred', self.y_pred.type))
100 | # check if y is of the correct datatype
101 | if y.dtype.startswith('int'):
102 | # the T.neq operator returns a vector of 0s and 1s, where 1
103 | # represents a mistake in prediction
104 | return T.sum(T.neq(self.y_pred, y))
105 | else:
106 | raise NotImplementedError()
107 |
108 |
109 |
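The indexing trick described in the negative_log_likelihood comments above can be checked with a NumPy analogue. This is a sketch under the assumption that NumPy fancy indexing behaves like the Theano expression for this case; the standalone functions below are illustrative and are not the class methods themselves.

```python
import numpy as np


def softmax(z):
    """Row-wise softmax, shifted by the row maximum for stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def negative_log_likelihood(p_y_given_x, y):
    """Mean negative log-probability of the correct labels.

    LP[np.arange(n), y] picks LP[i, y[i]] for every example i.
    """
    n = y.shape[0]
    log_probs = np.log(p_y_given_x)
    return -np.mean(log_probs[np.arange(n), y])
```

With all-zero weights, as in the default initialization, every class is equally likely, so the loss starts at log(n_out) regardless of the labels.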
--------------------------------------------------------------------------------
/generate_figures/dropout/mlp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sys
4 | import time
5 |
6 | import theano
7 | import theano.tensor as T
8 | from theano.ifelse import ifelse
9 | import theano.printing
10 | import theano.tensor.shared_randomstreams
11 |
12 | from logistic_sgd import LogisticRegression
13 | from load_data import load_mnist
14 |
15 | class HiddenLayer(object):
16 | def __init__(self, rng, input, n_in, n_out,
17 | activation, W=None, b=None,
18 | use_bias=False):
19 |
20 | self.input = input
21 | self.activation = activation
22 |
23 | if W is None:
24 | #W_values = np.asarray(0.01 * rng.standard_normal(
25 | # size=(n_in, n_out)), dtype=theano.config.floatX)
26 | W_values = 2. * np.asarray(rng.standard_normal(
27 | size=(n_in, n_out)), dtype=theano.config.floatX) / np.float(n_in)**0.5
28 | W = theano.shared(value=W_values, name='W')
29 |
30 | if b is None:
31 | b_values = np.zeros((n_out,), dtype=theano.config.floatX)
32 | b = theano.shared(value=b_values, name='b')
33 |
34 | self.W = W
35 | self.b = b
36 |
37 | if use_bias:
38 | lin_output = T.dot(input, self.W) + self.b
39 | else:
40 | lin_output = T.dot(input, self.W)
41 |
42 | self.output = (lin_output if activation is None else activation(lin_output))
43 |
44 | # parameters of the model
45 | if use_bias:
46 | self.params = [self.W, self.b]
47 | else:
48 | self.params = [self.W]
49 |
50 |
51 | def _dropout_from_layer(rng, layer, p):
52 | """p is the probability of dropping a unit
53 | """
54 | srng = theano.tensor.shared_randomstreams.RandomStreams(
55 | rng.randint(999999))
56 | # p=1-p because 1's indicate keep and p is prob of dropping
57 | mask = srng.binomial(n=1, p=1-p, size=layer.shape)
58 | # The cast is important because
59 | # int * float32 = float64 which pulls things off the gpu
60 | output = layer * T.cast(mask, theano.config.floatX)
61 | return output
62 |
63 | class DropoutHiddenLayer(HiddenLayer):
64 | def __init__(self, rng, input, n_in, n_out,
65 | activation, use_bias, W=None, b=None):
66 | super(DropoutHiddenLayer, self).__init__(
67 | rng=rng, input=input, n_in=n_in, n_out=n_out, W=W, b=b,
68 | activation=activation, use_bias=use_bias)
69 |
70 | self.output = _dropout_from_layer(rng, self.output, p=0.5)
71 |
72 |
73 | class MLP(object):
74 | """A multilayer perceptron with all the trappings required to do dropout
75 | training.
76 |
77 | """
78 | def __init__(self,
79 | rng,
80 | input,
81 | layer_sizes,
82 | use_bias=True,
83 | rectifier=None):
84 |
85 | if rectifier == 'soft':
86 | rectified_linear_activation = lambda x: T.nnet.softplus(x)
87 | elif rectifier == 'hard':
88 | rectified_linear_activation = lambda x: T.maximum(0.0, x)
89 |
90 | # Set up all the hidden layers
91 | weight_matrix_sizes = zip(layer_sizes, layer_sizes[1:])
92 | self.layers = []
93 | self.dropout_layers = []
94 | next_layer_input = input
95 | # dropout the input with prob 0.2
96 | next_dropout_layer_input = _dropout_from_layer(rng, input, p=0.2)
97 | for n_in, n_out in weight_matrix_sizes[:-1]:
98 | next_dropout_layer = DropoutHiddenLayer(rng=rng,
99 | input=next_dropout_layer_input,
100 | activation=rectified_linear_activation,
101 | n_in=n_in, n_out=n_out, use_bias=use_bias)
102 | self.dropout_layers.append(next_dropout_layer)
103 | next_dropout_layer_input = next_dropout_layer.output
104 |
105 | # Reuse the parameters from the dropout layer here, in a different
106 | # path through the graph.
107 | next_layer = HiddenLayer(rng=rng,
108 | input=next_layer_input,
109 | activation=rectified_linear_activation,
110 | W=next_dropout_layer.W * 0.5,
111 | b=next_dropout_layer.b,
112 | n_in=n_in, n_out=n_out,
113 | use_bias=use_bias)
114 | self.layers.append(next_layer)
115 | next_layer_input = next_layer.output
116 |
117 | # Set up the output layer
118 | n_in, n_out = weight_matrix_sizes[-1]
119 | dropout_output_layer = LogisticRegression(
120 | input=next_dropout_layer_input,
121 | n_in=n_in, n_out=n_out)
122 | self.dropout_layers.append(dropout_output_layer)
123 |
124 | # Again, reuse parameters in the dropout output.
125 | output_layer = LogisticRegression(
126 | input=next_layer_input,
127 | W=dropout_output_layer.W * 0.5,
128 | b=dropout_output_layer.b,
129 | n_in=n_in, n_out=n_out)
130 | self.layers.append(output_layer)
131 |
132 | # Use the negative log likelihood of the logistic regression layer as
133 | # the objective.
134 | self.dropout_negative_log_likelihood = self.dropout_layers[-1].negative_log_likelihood
135 | self.dropout_errors = self.dropout_layers[-1].errors
136 |
137 | self.negative_log_likelihood = self.layers[-1].negative_log_likelihood
138 | self.errors = self.layers[-1].errors
139 |
140 | # Grab all the parameters together.
141 | self.params = [ param for layer in self.dropout_layers for param in layer.params ]
142 |
143 | def convert_variable(x):
144 | if x.ndim == 1:
145 | return T.vector(x.name, dtype=x.dtype)
146 | else:
147 | return T.matrix(x.name, dtype=x.dtype)
148 |
149 | def f_df_split(fnc, *args, **kwargs):
150 | result = fnc(*args, **kwargs)
151 | return result[0], result[1:]
152 |
153 | def build_f_df(layer_sizes, dropout=False, use_bias=False, rectifier=None):
154 | print '... building the model'
155 | x = T.matrix('x') # the data is presented as rasterized images
156 | y = T.ivector('y') # the labels are presented as 1D vector of
157 | # [int] labels
158 | # Make sure initialization is repeatable
159 | rng = np.random.RandomState(1234)
160 | # construct the MLP class
161 | classifier = MLP(rng=rng, input=x,
162 | layer_sizes=layer_sizes, use_bias=use_bias, rectifier=rectifier)
163 | # Build the expression for the cost function.
164 | cost = classifier.negative_log_likelihood(y)
165 | dropout_cost = classifier.dropout_negative_log_likelihood(y)
166 | # Compile theano function for testing.
167 | # test_model = theano.function(inputs=[index],
168 | # outputs=classifier.errors(y),
169 | # givens={
170 | # x: test_set_x[index * batch_size:(index + 1) * batch_size],
171 | # y: test_set_y[index * batch_size:(index + 1) * batch_size]})
172 | # Compute gradients of the model wrt parameters
173 | gparams = []
174 | for param in classifier.params:
175 | # Use the right cost function here to train with or without dropout.
176 | gparam = T.grad(dropout_cost if dropout else cost, param)
177 | gparams.append(gparam)
178 | symbolic_params = [convert_variable(param) for param in classifier.params]
179 | givens = dict(zip(classifier.params, symbolic_params))
180 | f_df = theano.function(inputs=symbolic_params + [x, y], outputs=[cost] + gparams, givens=givens)
181 | return f_df, classifier
182 |
183 |
184 | def test_mlp(
185 | initial_learning_rate,
186 | learning_rate_decay,
187 | squared_filter_length_limit,
188 | n_epochs,
189 | batch_size,
190 | dropout,
191 | results_file_name,
192 | layer_sizes,
193 | dataset,
194 | use_bias):
195 | """
196 | The dataset is the one from the mlp demo on deeplearning.net. This training
197 | function is lifted from there almost exactly.
198 |
199 | :type dataset: string
200 | :param dataset: the path of the MNIST dataset file from
201 | http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz
202 |
203 |
204 | """
205 | datasets = load_mnist(dataset)
206 | train_set_x, train_set_y = datasets[0]
207 | valid_set_x, valid_set_y = datasets[1]
208 | test_set_x, test_set_y = datasets[2]
209 |
210 | # compute number of minibatches for training, validation and testing
211 | n_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size
212 | n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] / batch_size
213 | n_test_batches = test_set_x.get_value(borrow=True).shape[0] / batch_size
214 |
215 | ######################
216 | # BUILD ACTUAL MODEL #
217 | ######################
218 |
219 | print '... building the model'
220 |
221 | # allocate symbolic variables for the data
222 | index = T.lscalar() # index to a [mini]batch
223 | epoch = T.scalar()
224 | x = T.matrix('x') # the data is presented as rasterized images
225 | y = T.ivector('y') # the labels are presented as 1D vector of
226 | # [int] labels
227 | learning_rate = theano.shared(np.asarray(initial_learning_rate,
228 | dtype=theano.config.floatX))
229 |
230 | rng = np.random.RandomState(1234)
231 |
232 | # construct the MLP class
233 | classifier = MLP(rng=rng, input=x,
234 | layer_sizes=layer_sizes, use_bias=use_bias)
235 |
236 | # Build the expression for the cost function.
237 | cost = classifier.negative_log_likelihood(y)
238 | dropout_cost = classifier.dropout_negative_log_likelihood(y)
239 |
240 | # Compile theano function for testing.
241 | test_model = theano.function(inputs=[index],
242 | outputs=classifier.errors(y),
243 | givens={
244 | x: test_set_x[index * batch_size:(index + 1) * batch_size],
245 | y: test_set_y[index * batch_size:(index + 1) * batch_size]})
246 | #theano.printing.pydotprint(test_model, outfile="test_file.png",
247 | # var_with_name_simple=True)
248 |
249 | # Compile theano function for validation.
250 | validate_model = theano.function(inputs=[index],
251 | outputs=classifier.errors(y),
252 | givens={
253 | x: valid_set_x[index * batch_size:(index + 1) * batch_size],
254 | y: valid_set_y[index * batch_size:(index + 1) * batch_size]})
255 | #theano.printing.pydotprint(validate_model, outfile="validate_file.png",
256 | # var_with_name_simple=True)
257 |
258 | # Compute gradients of the model wrt parameters
259 | gparams = []
260 | for param in classifier.params:
261 | # Use the right cost function here to train with or without dropout.
262 | gparam = T.grad(dropout_cost if dropout else cost, param)
263 | gparams.append(gparam)
264 |
265 | # ... and allocate memory for momentum'd versions of the gradient
266 | gparams_mom = []
267 | for param in classifier.params:
268 | gparam_mom = theano.shared(np.zeros(param.get_value(borrow=True).shape,
269 | dtype=theano.config.floatX))
270 | gparams_mom.append(gparam_mom)
271 |
272 | # Compute momentum for the current epoch
273 | mom = ifelse(epoch < 500,
274 | 0.5*(1. - epoch/500.) + 0.99*(epoch/500.),
275 | 0.99)
276 |
277 | # Update the step direction using momentum
278 | updates = {}
279 | for gparam_mom, gparam in zip(gparams_mom, gparams):
280 | updates[gparam_mom] = mom * gparam_mom + (1. - mom) * gparam
281 |
282 | # ... and take a step along that direction
283 | for param, gparam_mom in zip(classifier.params, gparams_mom):
284 | stepped_param = param - (1.-mom) * learning_rate * gparam_mom
285 |
286 | # This is a silly hack to constrain the norms of the rows of the weight
287 | # matrices. This just checks if there are two dimensions to the
288 | # parameter and constrains it if so... maybe this is a bit silly but it
289 | # should work for now.
290 | if param.get_value(borrow=True).ndim == 2:
291 | squared_norms = T.sum(stepped_param**2, axis=1).reshape((stepped_param.shape[0],1))
292 | scale = T.clip(T.sqrt(squared_filter_length_limit / squared_norms), 0., 1.)
293 | updates[param] = stepped_param * scale
294 | else:
295 | updates[param] = stepped_param
296 |
297 |
298 | # Compile theano function for training. This returns the training cost and
299 | # updates the model parameters.
300 | output = dropout_cost if dropout else cost
301 | train_model = theano.function(inputs=[epoch, index], outputs=output,
302 | updates=updates,
303 | givens={
304 | x: train_set_x[index * batch_size:(index + 1) * batch_size],
305 | y: train_set_y[index * batch_size:(index + 1) * batch_size]})
306 | #theano.printing.pydotprint(train_model, outfile="train_file.png",
307 | # var_with_name_simple=True)
308 |
309 | # Theano function to decay the learning rate, this is separate from the
310 | # training function because we only want to do this once each epoch instead
311 | # of after each minibatch.
312 | decay_learning_rate = theano.function(inputs=[], outputs=learning_rate,
313 | updates={learning_rate: learning_rate * learning_rate_decay})
314 |
315 | ###############
316 | # TRAIN MODEL #
317 | ###############
318 | print '... training'
319 |
320 | best_params = None
321 | best_validation_errors = np.inf
322 | best_iter = 0
323 | test_score = 0.
324 | epoch_counter = 0
325 | start_time = time.clock()
326 |
327 | results_file = open(results_file_name, 'wb')
328 |
329 | while epoch_counter < n_epochs:
330 | # Train this epoch
331 | epoch_counter = epoch_counter + 1
332 | for minibatch_index in xrange(n_train_batches):
333 | minibatch_avg_cost = train_model(epoch_counter, minibatch_index)
334 |
335 | # Compute loss on validation set
336 | validation_losses = [validate_model(i) for i in xrange(n_valid_batches)]
337 | this_validation_errors = np.sum(validation_losses)
338 |
339 | # Report and save progress.
340 | print "epoch {}, test error {}, learning_rate={}{}".format(
341 | epoch_counter, this_validation_errors,
342 | learning_rate.get_value(borrow=True),
343 | " **" if this_validation_errors < best_validation_errors else "")
344 |
345 | best_validation_errors = min(best_validation_errors,
346 | this_validation_errors)
347 | results_file.write("{0}\n".format(this_validation_errors))
348 | results_file.flush()
349 |
350 | new_learning_rate = decay_learning_rate()
351 |
352 | end_time = time.clock()
353 | print(('Optimization complete. Best validation score of %f %% '
354 | 'obtained at iteration %i, with test performance %f %%') %
355 | (best_validation_errors * 100., best_iter, test_score * 100.))
356 | print >> sys.stderr, ('The code for file ' +
357 | os.path.split(__file__)[1] +
358 | ' ran for %.2fm' % ((end_time - start_time) / 60.))
359 |
360 |
361 | if __name__ == '__main__':
362 | import sys
363 |
364 | initial_learning_rate = 1.0
365 | learning_rate_decay = 0.998
366 | squared_filter_length_limit = 15.0
367 | n_epochs = 3000
368 | batch_size = 100
369 | layer_sizes = [ 28*28, 1200, 1200, 10 ]
370 | dataset = 'data/mnist_batches.npz'
371 | #dataset = 'data/mnist.pkl.gz'
372 |
373 | build_f_df(layer_sizes, use_bias=False)
374 |
375 | if len(sys.argv) < 2:
376 | print "Usage: {0} [dropout|backprop]".format(sys.argv[0])
377 | exit(1)
378 |
379 | elif sys.argv[1] == "dropout":
380 | dropout = True
381 | results_file_name = "results_dropout.txt"
382 |
383 | elif sys.argv[1] == "backprop":
384 | dropout = False
385 | results_file_name = "results_backprop.txt"
386 |
387 | else:
388 | print "I don't know how to '{0}'".format(sys.argv[1])
389 | exit(1)
390 |
391 | test_mlp(initial_learning_rate=initial_learning_rate,
392 | learning_rate_decay=learning_rate_decay,
393 | squared_filter_length_limit=squared_filter_length_limit,
394 | n_epochs=n_epochs,
395 | batch_size=batch_size,
396 | layer_sizes=layer_sizes,
397 | dropout=dropout,
398 | dataset=dataset,
399 | results_file_name=results_file_name,
400 | use_bias=False)
401 |
402 |
--------------------------------------------------------------------------------
/generate_figures/dropout/plot_results.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ -z "$1" ]; then
4 | TERM="dumb"
5 | PLOT_FILE=/dev/stdout
6 | else
7 | TERM="png"
8 | PLOT_FILE="$1"
9 | fi
10 |
11 | SMOOTHING=0.0
12 |
13 | (
14 | gnuplot 2>/dev/null < 1e5:
78 | return "$10^{%d}$"%int(np.log10(tt))
79 | return "$%d$"%tt
80 |
81 | def plot_shared(M, M_arr, N_arr, T_arr, v_fixed, v_change):
82 | figsize=(2.5,2.,)
83 |
84 | idx = np.flatnonzero(M_arr == M)
85 | ord = np.argsort(N_arr[idx])
86 | idx = idx[ord]
87 | if idx.shape[0] > 1:
88 | plt.figure(figsize=figsize)
89 | plt.plot(N_arr[idx], T_arr[idx], 'x--')
90 | plt.grid()
91 | ax = plt.axis()
92 | plt.axis([0, ax[1], 0, ax[3]])
93 | nn = np.linspace(0, np.max(T_arr[idx]), 5)
94 | nnt = ["$%d$"%tt for tt in nn]
95 | nnt[1] = ''
96 | nnt[3] = ''
97 | plt.yticks(nn, nnt)
98 | nn = np.linspace(0, np.max(N_arr[idx]), 5)
99 | nnt = [convert_num(tt) for tt in nn]
100 | nnt[1] = ''
101 | nnt[2] = ''
102 | nnt[3] = ''
103 | plt.xticks(nn, nnt)
104 | plt.axes().set_axisbelow(True)
105 | plt.xlabel('$%s$'%v_change)
106 | plt.ylabel('Overhead (s)')
107 | if M > 10**3:
108 | plt.title('Fixed $%s=10^{%d}$'%(v_fixed, int(np.log10(M))))
109 | else:
110 | plt.title('Fixed $%s=%d$'%(v_fixed, M))
111 | try:
112 | plt.tight_layout()
113 | except:
114 | warnings.warn('tight_layout failed. try running with an Agg backend.')
115 | plt.savefig('figure_overhead_fixed%s.pdf'%(v_fixed))
116 |
117 |
118 | def make_plots(M_arr, N_arr, T_arr):
119 | for M in np.unique(M_arr):
120 | plot_shared(M, M_arr, N_arr, T_arr, 'M', 'N')
121 | for N in np.unique(N_arr):
122 | plot_shared(N, N_arr, M_arr, T_arr, 'N', 'M')
123 |
124 |
125 | if __name__ == '__main__':
126 | M_arr, N_arr, T_arr = explore_MN()
127 | make_plots(M_arr, N_arr, T_arr)
128 |
129 | """
130 | import figure_overhead
131 | reload(figure_overhead)
132 |
133 | M_arr, N_arr, T_arr = figure_overhead.explore_MN()
134 | figure_overhead.make_plots(M_arr, N_arr, T_arr)
135 |
136 |
137 |
138 | """
139 |
--------------------------------------------------------------------------------
/generate_figures/figures_cae.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | """
4 | cae.py
5 |
6 | A pythonic library for Contractive Auto-Encoders. This is
7 | for people who want to give CAEs a quick try and for people
8 | who want to understand how they are implemented. For this
9 | purpose we tried to make the code as simple and clean as possible.
10 | The only dependency is numpy, which is used to perform all
11 | expensive operations. The code is quite fast, however much better
12 | performance can be achieved using the Theano version of this code.
13 |
14 | Created by Yann N. Dauphin, Salah Rifai on 2012-01-17.
15 | Copyright (c) 2012 Yann N. Dauphin, Salah Rifai. All rights reserved.
16 | """
17 |
18 | import numpy as np
19 |
20 |
21 | class CAE(object):
22 | """
23 | A Contractive Auto-Encoder (CAE) with sigmoid input units and sigmoid
24 | hidden units.
25 | """
26 | def __init__(self,
27 | n_hiddens=1024,
28 | jacobi_penalty=0.1,
29 | W=None,
30 | hidbias=None,
31 | visbias=None):
32 | """
33 | Initialize a CAE.
34 |
35 | Parameters
36 | ----------
37 | n_hiddens : int, optional
38 | Number of binary hidden units
39 | jacobi_penalty : float, optional
40 | Scalar by which to multiply the gradients coming from the jacobian
41 | penalty.
42 | W : array-like, shape (n_inputs, n_hiddens), optional
43 | Weight matrix, where n_inputs is the number of input
44 | units and n_hiddens is the number of hidden units.
45 | hidbias : array-like, shape (n_hiddens,), optional
46 | Biases of the hidden units
47 | visbias : array-like, shape (n_inputs,), optional
48 | Biases of the input units
49 | """
50 | self.n_hiddens = n_hiddens
51 | self.jacobi_penalty = jacobi_penalty
52 | self.W = W
53 | self.hidbias = hidbias
54 | self.visbias = visbias
55 |
56 |
57 | def _sigmoid(self, x):
58 | """
59 | Implements the logistic function.
60 |
61 | Parameters
62 | ----------
63 | x: array-like, shape (M, N)
64 |
65 | Returns
66 | -------
67 | x_new: array-like, shape (M, N)
68 | """
69 | return 1. / (1. + np.exp(-np.maximum(np.minimum(x, 100), -100)))
70 |
71 | def encode(self, x):
72 | """
73 | Computes the hidden code for the input {\bf x}.
74 |
75 | Parameters
76 | ----------
77 | x: array-like, shape (n_examples, n_inputs)
78 |
79 | Returns
80 | -------
81 | h: array-like, shape (n_examples, n_hiddens)
82 | """
83 | return self._sigmoid(np.dot(x, self.W) + self.hidbias)
84 |
85 | def decode(self, h):
86 | """
87 | Compute the reconstruction from the hidden code {\bf h}.
88 |
89 | Parameters
90 | ----------
91 | h: array-like, shape (n_examples, n_hiddens)
92 |
93 | Returns
94 | -------
95 | x: array-like, shape (n_examples, n_inputs)
96 | """
97 | return self._sigmoid(np.dot(h, self.W.T) + self.visbias)
98 |
99 | def reconstruct(self, x):
100 | """
101 | Compute the reconstruction of the input {\bf x}.
102 |
103 | Parameters
104 | ----------
105 | x: array-like, shape (n_examples, n_inputs)
106 |
107 | Returns
108 | -------
109 | x_new: array-like, shape (n_examples, n_inputs)
110 | """
111 | return self.decode(self.encode(x))
112 |
113 | def loss(self, x, h=None, r=None):
114 | """
115 | Computes the error of the model with respect
116 | to the total cost.
117 |
118 | Parameters
119 | ----------
120 | x: array-like, shape (n_examples, n_inputs)
121 | h: array-like, shape (n_examples, n_hiddens), optional
122 | r: array-like, shape (n_examples, n_inputs), optional
123 |
124 | Returns
125 | -------
126 | loss: array-like, shape (n_examples,)
127 | """
128 | if h is None:
129 | h = self.encode(x)
130 | if r is None:
131 | r = self.decode(h)
132 |
133 | def _reconstruction_loss(h, r):
134 | """
135 | Computes the error of the model with respect
136 | to the reconstruction (L2) cost.
137 | """
138 | return 1/2. * ((r-x)**2).sum()/x.shape[0]
139 |
140 | def _jacobi_loss(h):
141 | """
142 | Computes the error of the model with respect
143 | the Frobenius norm of the jacobian.
144 | """
145 | return ((h *(1-h))**2 * (self.W**2).sum(0)).sum()/x.shape[0]
146 | recon_loss = _reconstruction_loss(h, r)
147 | jacobi_loss = _jacobi_loss(h)
148 | return (recon_loss + self.jacobi_penalty * jacobi_loss)
149 |
150 | def get_params(self):
151 | return dict(W=self.W, hidbias=self.hidbias, visbias=self.visbias)
152 |
153 | def set_params(self, theta):
154 | self.W = theta['W']
155 | self.hidbias = theta['hidbias']
156 | self.visbias = theta['visbias']
157 |
158 | def f_df(self, theta, x):
159 | """
160 | Compute objective and gradient of the CAE objective using the
161 | examples {\bf x}.
162 |
163 | Parameters
164 | ----------
165 | x: array-like, shape (n_examples, n_inputs)
166 |
167 | Returns
168 | -------
169 | loss: float
170 | Mean loss over the examples, followed by a dict of gradients keyed by W, hidbias and visbias.
171 | """
172 | self.set_params(theta)
173 | h = self.encode(x)
174 | r = self.decode(h)
175 | def _contraction_jacobian():
176 | """
177 | Compute the gradient of the contraction cost w.r.t parameters.
178 | """
179 | a = 2*(h * (1 - h))**2
180 | d = ((1 - 2 * h) * a * (self.W**2).sum(0)[None, :])
181 | b = np.dot(x.T / x.shape[0], d)
182 | c = a.mean(0) * self.W
183 | return (b + c), d.mean(0)
184 |
185 | def _reconstruction_jacobian():
186 | """
187 | Compute the gradient of the reconstruction cost w.r.t parameters.
188 | """
189 | dr = (r - x) / x.shape[0]
190 | dr *= r * (1-r)
191 | dd = np.dot(dr.T, h)
192 | dh = np.dot(dr, self.W) * h * (1. - h)
193 | de = np.dot(x.T, dh)
194 | return (dd + de), dr.sum(0), dh.sum(0)
195 |
196 | W_rec, c_rec, b_rec = _reconstruction_jacobian()
197 | W_con, b_con = _contraction_jacobian()
198 | dW = W_rec + self.jacobi_penalty * W_con
199 | dhidbias = b_rec + self.jacobi_penalty * b_con
200 | dvisbias = c_rec
201 | return self.loss(x, h, r), dict(W=dW, hidbias=dhidbias, visbias=dvisbias)
202 |
203 | def init_weights(self, n_input, dtype=np.float32):
204 | self.W = np.asarray(np.random.uniform(
205 | #low=-4*np.sqrt(6./(n_input+self.n_hiddens)),
206 | #high=4*np.sqrt(6./(n_input+self.n_hiddens)),
207 | low=-1./np.sqrt(self.n_hiddens),
208 | high=1./np.sqrt(self.n_hiddens),
209 | size=(n_input, self.n_hiddens)), dtype=dtype)
210 | self.hidbias = np.zeros(self.n_hiddens, dtype=dtype)
211 | self.visbias = np.zeros(n_input, dtype=dtype)
212 | return self.get_params()
213 |
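The contraction term in `CAE.loss` above uses a closed form: for a sigmoid encoder, the squared Frobenius norm of the Jacobian dh/dx for one example is the sum over hidden units j of (h_j (1 - h_j))^2 times the sum over inputs i of W_ij^2. The sketch below checks that identity against a finite-difference Jacobian for a single example. It is a minimal illustration in pure Python, not part of the repository; the helper names `encode`, `penalty_closed_form` and `penalty_finite_difference` and the example values are assumptions made for this sketch.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def encode(x, W, b):
    # x: list of n_inputs values; W: n_inputs rows of n_hiddens values; b: n_hiddens biases
    return [sigmoid(sum(x[i] * W[i][j] for i in range(len(x))) + b[j])
            for j in range(len(b))]


def penalty_closed_form(x, W, b):
    # sum_j (h_j * (1 - h_j))**2 * sum_i W_ij**2, as in CAE._jacobi_loss for one example
    h = encode(x, W, b)
    col_sq = [sum(W[i][j] ** 2 for i in range(len(W))) for j in range(len(b))]
    return sum((h[j] * (1.0 - h[j])) ** 2 * col_sq[j] for j in range(len(b)))


def penalty_finite_difference(x, W, b, eps=1e-6):
    # Squared Frobenius norm of dh/dx, estimated with central differences
    total = 0.0
    for i in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[i] += eps
        x_minus[i] -= eps
        h_plus = encode(x_plus, W, b)
        h_minus = encode(x_minus, W, b)
        for j in range(len(b)):
            total += ((h_plus[j] - h_minus[j]) / (2.0 * eps)) ** 2
    return total


# Hypothetical example values, chosen only for illustration
x = [0.2, 0.7, 0.1]
W = [[0.5, -0.3], [0.1, 0.8], [-0.6, 0.4]]
b = [0.0, 0.1]
closed = penalty_closed_form(x, W, b)
numeric = penalty_finite_difference(x, W, b)
```

The two values should agree to within finite-difference error, which is why the class can evaluate the penalty without ever forming the Jacobian explicitly.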
--------------------------------------------------------------------------------
/generate_figures/models.py:
--------------------------------------------------------------------------------
1 | """
2 | Model classes for each of the demo cases. Each class contains
3 | an objective function f_df, initial parameters theta_init, a
4 | reference for each subfunction subfunction_references, and a
5 | set of full_objective_references that are evaluated every update step
6 | to make the plots of objective function vs. learning iteration.
7 |
8 | This is designed to be called by figure_convergence.py.
9 |
10 |
11 | Copyright 2014 Jascha Sohl-Dickstein
12 | Licensed under the Apache License, Version 2.0 (the "License");
13 | you may not use this file except in compliance with the License.
14 | You may obtain a copy of the License at
15 | http://www.apache.org/licenses/LICENSE-2.0
16 | Unless required by applicable law or agreed to in writing, software
17 | distributed under the License is distributed on an "AS IS" BASIS,
18 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 | See the License for the specific language governing permissions and
20 | limitations under the License.
21 | """
22 |
23 | import numpy as np
24 | import scipy.special
25 | import warnings
26 | import random
27 | from figures_cae import CAE
28 | from os.path import join
29 |
30 | try:
31 | import pyGLM.simGLM as glm
32 | import pyGLM.gabor as gabor
33 | except:
34 | pass
35 |
36 | # numpy < 1.7 does not have np.random.choice
37 | def my_random_choice(n, k, replace):
38 | perm = np.random.permutation(n)
39 | return perm[:k]
40 | if hasattr(np.random, 'choice'):
41 | random_choice = np.random.choice
42 | else:
43 | random_choice = my_random_choice
44 |
45 |
46 | class toy:
47 | """
48 | Toy problem. Sum of squared errors from random means, raised
49 | to random powers.
50 | """
51 |
52 | def __init__(self, num_subfunctions=100, num_dims=10):
53 | self.name = '||x-u||^a, a in U[1.5,4.5]'
54 |
55 | # create the array of subfunction identifiers
56 | self.subfunction_references = []
57 | N = num_subfunctions
58 | for i in range(N):
59 | npow = np.random.rand()*3. + 1.5
60 | mn = np.random.randn(num_dims,1)
61 | self.subfunction_references.append([npow,mn])
62 | self.full_objective_references = self.subfunction_references
63 |
64 | ## initialize parameters
65 | self.theta_init = np.random.randn(num_dims,1)
66 |
67 | def f_df(self, x, args):
68 | npow = args[0]/2.
69 | mn = args[1]
70 | f = np.sum(((x-mn)**2)**npow)
71 | df = npow*((x-mn)**2)**(npow-1.)*2*(x-mn)
72 | scl = 1. / np.prod(x.shape)
73 | return f*scl, df*scl
74 |
75 |
76 | class Hopfield:
77 | def __init__(self, num_subfunctions=100, reg=1., scale_by_N=True):
78 | """
79 | Train a Hopfield network/Ising model using MPF.
80 |
81 | Adapted from code by Chris Hillar, Kilian Koepsell, Jascha Sohl-Dickstein, 2011
82 |
83 | TODO insert Hopfield and MPF references.
84 | """
85 | self.name = 'Hopfield'
86 | self.reg = reg/num_subfunctions
87 |
88 | # Load data
89 | X, _ = load_mnist()
90 |
91 | # binarize data
92 | X = (np.sign(X-0.5)+1)/2
93 |
94 | # only keep units which actually change state
95 | gd = ((np.sum(X,axis=1) > 0) & (np.sum(1-X,axis=1) > 0))
96 | X = X[gd,:]
97 | # TODO -- discard units with correlation of exactly 1?
98 |
99 | # break the data up into minibatches
100 | self.subfunction_references = []
101 | for mb in range(num_subfunctions):
102 | self.subfunction_references.append(X[:, mb::num_subfunctions].T)
103 | #self.full_objective_references = (X[:, random_choice(X.shape[1], 10000, replace=False)].copy().T,)
104 | self.full_objective_references = self.subfunction_references
105 |
106 | if scale_by_N:
107 | self.scl = float(num_subfunctions) / float(X.shape[1])
108 | else:
109 | self.scl = 100. / float(X.shape[1])
110 |
111 | # parameter initialization
112 | self.theta_init = np.random.randn(X.shape[0], X.shape[0])/np.sqrt(X.shape[0])/10.
113 | self.theta_init = (self.theta_init + self.theta_init.T)/2.
114 |
115 |
116 | def f_df(self, J, X):
117 | J = (J + J.T)/2.
118 | X = np.atleast_2d(X)
119 | S = 2 * X - 1
120 | Kfull = np.exp(-S * np.dot(X, J.T) + .5 * np.diag(J)[None, :])
121 | dJ = -np.dot(X.T, Kfull * S) + .5 * np.diag(Kfull.sum(0))
122 | dJ = (dJ + dJ.T)/2.
123 | dJ = np.nan_to_num(dJ)
124 |
125 | K = Kfull.sum()
126 | if not np.isfinite(K):
127 | K = 1e50
128 |
129 | K *= self.scl
130 | dJ *= self.scl
131 |
132 | K += self.reg * np.sum(J**2)
133 | dJ += 2. * self.reg * J
134 | return K, dJ
135 |
136 |
137 | class logistic:
138 | """
139 | logistic regression on "protein" dataset
140 | """
141 |
142 | def __init__(self, num_subfunctions=100, scale_by_N=True):
143 | self.name = 'protein logistic regression'
144 |
145 | try:
146 | data = np.loadtxt('figure_data/bio_train.dat')
147 | except:
148 | raise Exception("Missing data. Download from and place in figure_data subdirectory.")
149 |
150 | target = data[:,[2]]
151 | feat = data[:,3:]
152 | feat = self.whiten(feat)
153 | feat = np.hstack((feat, np.ones((data.shape[0],1))))
154 |
155 | # create the array of subfunction identifiers
156 | self.subfunction_references = []
157 | N = num_subfunctions
158 | nper = float(feat.shape[0])/float(N)
159 | if scale_by_N:
160 | lam = 1./nper**2
161 | scl = 1./nper
162 | else:
163 | default_N = 100.
164 | nnp = float(feat.shape[0])/default_N
165 | lam = (1./nnp**2) * (default_N / float(N))
166 | scl = 1./nnp
167 | for i in range(N):
168 | i_s = int(np.floor(i*nper))
169 | i_f = int(np.floor((i+1)*nper))
170 | if i == N-1:
171 | # don't drop any at the end
172 | i_f = target.shape[0]
173 | l_targ = target[i_s:i_f,:]
174 | l_feat = feat[i_s:i_f,:]
175 | self.subfunction_references.append([l_targ, l_feat, lam, scl, i])
176 | self.full_objective_references = self.subfunction_references
177 |
178 | self.theta_init = np.random.randn(feat.shape[1],1)/np.sqrt(feat.shape[1])/10. # parameters initialization
179 |
180 | # remove first order dependencies from X, and scale to unit norm
181 | def whiten(self, X,fudge=1E-10):
182 | max_whiten_lines = 10000
183 |
184 | # zero mean
185 | X -= np.mean(X, axis=0).reshape((1,-1))
186 |
187 | # the matrix X should be observations-by-components
188 | # get the covariance matrix
189 | Xsu = X[:max_whiten_lines,:]
190 | Xcov = np.dot(Xsu.T,Xsu)/Xsu.shape[0]
191 | # eigenvalue decomposition of the covariance matrix
192 | try:
193 | d,V = np.linalg.eigh(Xcov)
194 | except:
195 | print("could not get eigenvectors and eigenvalues using numpy.linalg.eigh of ", Xcov.shape, Xcov)
196 | d,V = np.linalg.eig(Xcov)
197 | d = np.nan_to_num(d+fudge)
198 | d[d==0] = 1
199 | V = np.nan_to_num(V)
200 |
201 | # a fudge factor can be used so that eigenvectors associated with
202 | # small eigenvalues do not get overamplified.
203 | # TODO(jascha) D could be a vector not a matrix
204 | D = np.diag(1./np.sqrt(d))
205 |
206 | D = np.nan_to_num(D)
207 |
208 | # whitening matrix
209 | W = np.dot(np.dot(V,D),V.T)
210 | # multiply by the whitening matrix
211 | Y = np.dot(X,W)
212 | return Y
213 |
214 | def sigmoid(self, u):
215 | return 1./(1.+np.exp(-u))
216 | def f_df(self, x, args):
217 | target = args[0]
218 | feat = args[1]
219 | lam = args[2]
220 | scl = args[3]
221 |
222 | feat = feat*(2*target - 1)
223 | ip = -np.dot(feat, x.reshape((-1,1)))
224 | et = np.exp(ip)
225 | logL = np.log(1. + et)
226 | etrat = et/(1.+et)
227 | bd = np.nonzero(ip[:,0]>50)[0]
228 | logL[bd,:] = ip[bd,:]
229 | etrat[bd,:] = 1.
230 | logL = np.sum(logL)
231 | dlogL = -np.dot(feat.T, etrat)
232 |
233 | logL *= scl
234 | dlogL *= scl
235 |
236 | reg = lam*np.sum(x**2)
237 | dreg = 2.*lam*x
238 |
239 | return logL + reg, dlogL+dreg
240 |
241 |
242 | class ICA:
243 | """
244 | ICA with Student's t-experts on MNIST images.
245 | """
246 |
247 | def __init__(self, num_subfunctions=100):
248 | self.name = 'ICA'
249 |
250 | # Load data
251 | #X = load_cifar10_imagesonly()
252 | X, _ = load_mnist()
253 |
254 | # do PCA to eliminate degenerate dimensions and whiten
255 | C = np.dot(X, X.T) / X.shape[1]
256 | w, V = np.linalg.eigh(C)
257 | # take only the non-negligible eigenvalues
258 | mw = np.max(np.real(w))
259 | max_ratio = 1e4
260 | gd = np.nonzero(np.real(w) > mw/max_ratio)[0]
261 | # # whiten
262 | # P = V[gd,:]*(np.real(w[gd])**(-0.5)).reshape((-1,1))
263 | # don't whiten -- make the problem harder
264 | P = V[gd,:]
265 | X = np.dot(P, X)
266 | X /= np.std(X)
267 |
268 | # break the data up into minibatches
269 | self.subfunction_references = []
270 | for mb in range(num_subfunctions):
271 | self.subfunction_references.append(X[:, mb::num_subfunctions])
272 | # compute the full objective on all the data
273 | self.full_objective_references = self.subfunction_references
274 | # # the subset of the data used to compute the full objective function value
275 | # idx = random_choice(X.shape[1], 10000, replace=False)
276 | # self.full_objective_references = (X[:,idx].copy(),)
277 |
278 | ## initialize parameters
279 | num_dims = X.shape[0]
280 | self.theta_init = {'W':np.random.randn(num_dims, num_dims)/np.sqrt(num_dims),
281 | 'logalpha':np.random.randn(num_dims,1).ravel()}
282 |
283 | # rescale the objective and gradient so that the same hyperparameter ranges work for
284 | # ICA as for the other objectives
285 | self.scale = 1. / self.subfunction_references[0].shape[1] / 100.
286 |
287 |
288 | def f_df(self, params, X):
289 | """
290 | ICA objective function and gradient, using a Student's t-distribution prior.
291 | The energy function has form:
292 | E = \sum_i \alpha_i \log( 1 + (\sum_j W_{ij} x_j )^2 )
293 |
294 | params is a dictionary containing the filters W and logalpha, where alpha = exp(logalpha) + 0.5.
295 | X is the training data, with each column corresponding to a sample.
296 |
297 | L is the average negative log likelihood.
298 | dL is its gradient.
299 | """
300 | W = params['W']
301 | logalpha = params['logalpha'].reshape((-1,1))
302 | alpha = np.exp(logalpha)+0.5
303 |
304 | ## calculate the energy
305 | ff = np.dot(W, X)
306 | ff2 = ff**2
307 | off2 = 1 + ff2
308 | lff2 = np.log( off2 )
309 | alff2 = lff2 * alpha.reshape((-1,1))
310 | E = np.sum(alff2)
311 |
312 | ## calculate the energy gradient
313 | # could just sum instead of using the rscl 1s vector. this is functionality
314 | # left over from MPF MATLAB code. May want it again in a future project though.
315 | rscl = np.ones((X.shape[1],1))
316 | lt = (ff/off2) * alpha.reshape((-1,1))
317 | dEdW = 2 * np.dot(lt * rscl.T, X.T)
318 | dEdalpha = np.dot(lff2, rscl)
319 | dEdlogalpha = (alpha-0.5) * dEdalpha
320 |
321 | ## calculate log Z
322 | nu = alpha * 2. - 1.
323 | #logZ = -np.log(scipy.special.gamma((nu + 1.) / 2.)) + 0.5 * np.log(np.pi) + \
324 | # np.log(scipy.special.gamma(nu/2.))
325 | logZ = -scipy.special.gammaln((nu + 1.) / 2.) + 0.5 * np.log(np.pi) + \
326 | scipy.special.gammaln((nu/2.))
327 | logZ = np.sum(logZ)
328 | ## DEBUG slogdet has memory leak!
329 | ## eg, call "a = np.linalg.slogdet(np.random.randn(5000,5000))"
330 | ## repeatedly, and watch memory usage. So, we do this with an
331 | ## explicit eigendecomposition instead
332 | ## logZ += -np.linalg.slogdet(W)[1]
333 | W2 = np.dot(W.T, W)
334 | W2eig, _ = np.linalg.eig(W2)
335 | logZ += -np.sum(np.log(W2eig))/2.
336 |
337 | ## calculate gradient of log Z
338 | # log determinant contribution
339 | dlogZdW = -np.linalg.inv(W).T
340 | if np.min(nu) < 0:
341 | dlogZdnu = np.zeros(nu.shape)
342 | warnings.warn('not a normalizable distribution!')
343 | E = np.inf
344 | else:
345 | dlogZdnu = -scipy.special.psi((nu + 1) / 2 )/2 + \
346 | scipy.special.psi( nu/2 )/2
347 | dlogZdalpha = 2. * dlogZdnu
348 | dlogZdlogalpha = (alpha-0.5) * dlogZdalpha
349 |
350 | ## full objective and gradient
351 | L = (E + logZ) * self.scale
352 | dLdW = (dEdW + dlogZdW) * self.scale
353 | dLdlogalpha = (dEdlogalpha + dlogZdlogalpha) * self.scale
354 |
355 | ddL = {'W':dLdW, 'logalpha':dLdlogalpha.ravel()}
356 |
357 | if not np.isfinite(L):
358 | warnings.warn('numerical problems')
359 | L = np.inf
360 |
361 | return L, ddL
362 |
363 |
364 | class DeepAE:
365 | """
366 | Deep Autoencoder from Hinton, G. E. and Salakhutdinov, R. R. (2006)
367 | """
368 | def __init__(self, num_subfunctions=50, num_dims=10, objective='l2'):
369 | # don't introduce a Theano dependency until we have to
370 | from utils import _tonp
371 |
372 | self.name = 'DeepAE'
373 | layer_sizes = [ 28*28, 1000, 500, 250, 30]
374 | #layer_sizes = [ 28*28, 20]
375 | # Load data
376 | X, y = load_mnist()
377 | # break the data up into minibatches
378 | self.subfunction_references = []
379 | for mb in range(num_subfunctions):
380 | self.subfunction_references.append([X[:, mb::num_subfunctions], y[mb::num_subfunctions]])
381 | # evaluate on subset of training data
382 | self.n_full = 10000
383 | idx = random_choice(X.shape[1], self.n_full, replace=False)
384 | ##use all the training data for a smoother plot
385 | #idx = np.array(range(X.shape[1]))
386 | self.full_objective_references = [[X[:,idx].copy(), y[idx].copy()]]
387 | from dropout.deepae import build_f_df # here so theano not required for import
388 | self.theano_f_df, self.model = build_f_df(layer_sizes, use_bias=True, objective=objective)
389 | crossent_params = False
390 | if crossent_params:
391 | history = np.load('/home/poole/Sum-of-Functions-Optimizer/sfo_output.npz')
392 | out = dict(history=history['arr_0'])
393 | params = out['history'].item()['x']['SFO']
394 | self.theta_init = params
395 | else:
396 | self.theta_init = [param.get_value() for param in self.model.params]
397 |
398 | def f_df(self, theta, args, gpu_batch_size=128):
399 | X = args[0].T
400 | y = args[1]
401 | rem = np.mod(len(X), gpu_batch_size)
402 | n_batches = (len(X) - rem) // gpu_batch_size
403 | splits = np.split(np.arange(len(X) - rem), n_batches)
404 | if rem > 0:
405 | splits.append(np.arange(len(X)-rem, len(X)))
406 | sum_results = None
407 | for split in splits:
408 | theano_args = theta + [X[split]]
409 | # Convert to float32 so that this works on GPU
410 | theano_args = [arg.astype(np.float32) for arg in theano_args]
411 | results = self.theano_f_df(*theano_args)
412 | results = [_tonp(result) for result in results]
413 | if sum_results is None:
414 | sum_results = results
415 | else:
416 | sum_results = [cur_res + new_res for cur_res, new_res in zip(sum_results, results)]
417 | # Divide by number of datapoints.
418 | sum_results = [result/len(X) for result in sum_results]
419 | return sum_results[0], sum_results[1:]
420 |
421 | class MLP:
422 | """
423 | Multi-layer-perceptron
424 | """
425 | def __init__(self, num_subfunctions=100, num_dims=10, rectifier='soft'):
426 | self.name = 'MLP'
427 | #layer_sizes = [ 28*28, 1200, 10 ]
428 | #layer_sizes = [ 28*28, 1200, 10 ]
429 | #layer_sizes = [ 28*28, 500, 120, num_dims ]
430 | #layer_sizes = [ 28*28, 120, 12, num_dims ]
431 | layer_sizes = [ 28*28, 1200, 1200, num_dims ]
432 | # Load data
433 | X, y = load_mnist()
434 | # break the data up into minibatches
435 | self.subfunction_references = []
436 | for mb in range(num_subfunctions):
437 | self.subfunction_references.append([X[:, mb::num_subfunctions], y[mb::num_subfunctions]])
438 | # evaluate on subset of training data
439 | idx = random_choice(X.shape[1], 5000, replace=False)
440 | #use all the training data for a smoother plot
441 | #idx = np.array(range(X.shape[1]))
442 | self.full_objective_references = [[X[:,idx].copy(), y[idx].copy()]]
443 | from dropout.mlp import build_f_df # here so theano not required for import
444 | self.theano_f_df, self.model = build_f_df(layer_sizes, rectifier=rectifier,
445 | use_bias=True)
446 | self.theta_init = [param.get_value() for param in self.model.params]
447 | def f_df(self, theta, args):
448 | X = args[0].T
449 | y = args[1]
450 | theano_args = theta + [X, y]
451 | results = self.theano_f_df(*theano_args)
452 | return results[0], results[1:]
453 |
454 | class MLP_hard(MLP):
455 | """
456 | Multi-layer-perceptron with rectified-linear nonlinearity
457 | """
458 | def __init__(self, num_subfunctions=100, num_dims=10):
459 | MLP.__init__(self, num_subfunctions=num_subfunctions, num_dims=num_dims, rectifier='hard')
460 | self.name += ' hard'
461 | class MLP_soft(MLP):
462 | """
463 | Multi-layer-perceptron with sigmoid nonlinearity
464 | """
465 | def __init__(self, num_subfunctions=100, num_dims=10):
466 | MLP.__init__(self, num_subfunctions=num_subfunctions, num_dims=num_dims, rectifier='soft')
467 | self.name += ' soft'
468 |
469 | class ContractiveAutoencoder:
470 | """
471 | Contractive autoencoder on MNIST dataset.
472 | """
473 | def __init__(self, num_subfunctions=100, num_dims=10):
474 | self.name = 'Contractive Autoencoder'
475 |
476 | # Load data
477 | X, y = load_mnist()
478 | # break the data up into minibatches
479 | self.subfunction_references = []
480 | for mb in range(num_subfunctions):
481 | self.subfunction_references.append(X[:, mb::num_subfunctions])
482 | #self.full_objective_references = (X[:,np.random.choice(X.shape[1], 1000, replace=False)].copy(),)
483 | self.full_objective_references = (X[:, random_choice(X.shape[1], 1000, replace=False)].copy(),)
484 | #self.full_objective_references = (X.copy(),)
485 | # Initialize CAE model
486 | self.cae = CAE(n_hiddens=256, jacobi_penalty=1.0)
487 | # Initialize parameters
488 | self.theta_init = self.cae.init_weights(X.shape[0], dtype=np.float64)
489 |
490 | def f_df(self, theta, X):
491 | return self.cae.f_df(theta, X.T)
492 |
493 | class PylearnModel:
494 | def __init__(self, filename, load_fnc, num_subfunctions=100, num_dims=10):
495 | # Import here so we don't depend on pylearn elsewhere
496 | from nnet.model_gradient import load_model
497 | self.name = 'Pylearn'
498 | X, y = load_fnc()
499 | batch_size = X.shape[0] // num_subfunctions
500 | mg = load_model(filename, batch_size=batch_size)
501 | self.mg = mg
502 | self.model = mg.model
503 | self._f_df = mg.f_df
504 |
505 |
506 | # break the data up into minibatches
507 | self.subfunction_references = []
508 | for mb in range(num_subfunctions):
509 | self.subfunction_references.append([X[ mb::num_subfunctions,...], y[mb::num_subfunctions,...]])
510 | # evaluate on subset of training data
511 | idx = random_choice(X.shape[0], 1000, replace=False)
512 | #use all the training data for a smoother plot
513 | #idx = np.array(range(X.shape[1]))
514 | self.full_objective_references = [[X[idx,...].copy(), y[idx,...].copy()]]
515 | self.theta_init = [param.get_value() for param in self.mg.params]
516 |
517 | def f_df(self, thetas, args):
518 | thetas32 = [theta.astype(np.float32) for theta in thetas]
519 | return self._f_df(thetas32, args)
520 |
521 | class MNISTConvNet(PylearnModel):
522 | def __init__(self, num_subfunctions=100, num_dims=10):
523 | fn = 'nnet/mnist.yaml'
524 | #super(ConvNet, self).__init__(fn, num_subfunctions, num_dims)
525 | PylearnModel.__init__(self,fn, load_mnist, num_subfunctions, num_dims)
526 | self.name += '_conv'
527 |
528 | class CIFARConvNet(PylearnModel):
529 | def __init__(self, num_subfunctions=100, num_dims=10):
530 | fn = 'nnet/conv.yaml'
531 | #super(ConvNet, self).__init__(fn, num_subfunctions, num_dims)
532 | PylearnModel.__init__(self,fn, load_cifar, num_subfunctions, num_dims)
533 | self.name += '_conv'
534 |
535 |
536 | class GLM:
537 | """
538 | Train a GLM on simulated data.
539 | """
540 |
541 | def __init__(self, num_subfunctions_ratio=0.05, baseDir='/home/nirum/data/retina/glm-feb-19/'):
542 | #def __init__(self, num_subfunctions_ratio=0.05, baseDir='/home/nirum/data/retina/glm-feb-19/small'):
543 | self.name = 'GLM'
544 |
545 | print('Initializing parameters...')
546 |
547 | ## FOR CUSTOM SIZES
548 | #self.params = glm.setParameters(m=5000, dh=10, ds=50) # small
549 | #self.params = glm.setParameters(m=5000, dh=10, ds=49) # small
550 | #self.params = glm.setParameters(m=20000, dh=100, ds=500) # large
551 | #self.params = glm.setParameters(m=20000, dh=25, ds=256, n=5) # huge
552 | #self.params = glm.setParameters(m=1e5, dh=10, ds=256, n=50) # Jascha huge
553 | #self.params = glm.setParameters(m=1e5, n=100, ds=256, dh=10) # shared
554 |
555 | ## FROM EXTERNAL DATA FILES
556 | # load sizes of external data
557 | shapes = np.load(join(baseDir, 'metadata.npz'))
558 |
559 | # set up GLM parameters
560 | self.params = glm.setParameters(dh=40, ds=shapes['stimSlicedShape'][1], n=shapes['rateShape'][1])
561 | #self.params = glm.setParameters(dh=40, ds=shapes['stimSlicedShape'][1], n=2)
562 |
563 | ## initialize parameters
564 | print('Generating model...')
565 | #self.theta_true = glm.generateModel(self.params)
566 | self.theta_init = glm.generateModel(self.params)
567 | for key in self.theta_init.keys():
568 | self.theta_init[key] /= 1e3
569 |
570 | #print('Simulating model to generate data...')
571 | #self.data = glm.generateData(self.theta_true, self.params)
572 |
573 | # load external data files as memmapped arrays
574 | print('Loading external data...')
575 | self.data = glm.loadExternalData('stimulus_sliced.dat', 'rates.dat', shapes, baseDir=baseDir)
576 |
577 | # all the data
578 | batch_size = self.data['x'].shape[0]
579 | #batch_size = 10000
580 |
581 | print('ready to go!')
582 |
583 | #trueMinimum, trueGrad = self.f_df(self.theta_true, (0, batch_size))
584 | #print('Minimum for true parameters %g'%(trueMinimum))
585 |
586 | #print('Norm of the gradient at the minimum:')
587 | #for key in trueGrad.keys():
588 | #print('grad[' + key + ']: %g'%(np.linalg.norm(trueGrad[key])))
589 |
590 | # print out some network information
591 | #glm.visualizeNetwork(self.theta_true, self.params, self.data)
592 |
593 | # break the data up into minibatches
594 |
595 | self.N = int(np.ceil(np.sqrt(batch_size)*num_subfunctions_ratio))
596 | self.subfunction_references = []
597 | samples_per_subfunction = int(np.floor(batch_size/self.N))
598 | for mb in range(self.N):
599 | print(mb, self.N)
600 | start_idx = mb*samples_per_subfunction
601 | end_idx = (mb+1)*samples_per_subfunction
602 | self.subfunction_references.append((start_idx, end_idx,))
603 |
604 | self.full_objective_references = self.subfunction_references
605 | print('initialized ...')
606 | #self.full_objective_references = random.sample(self.subfunction_references, int(num_subfunctions/10))
607 |
608 |
609 | def f_df(self, theta, idx_range):
610 | """
611 | objective assuming Poisson noise
612 |
613 | function [fval grad] = objPoissonGLM(theta, datapath)
614 | Computes the Poisson log-likelihood objective and gradient
615 | for the generalized linear model (GLM)
616 | """
617 |
618 | data_subf = dict()
619 | for key in self.data.keys():
620 | data_subf[key] = np.array(self.data[key][idx_range[0]:idx_range[1]])
621 |
622 | return glm.f_df(theta, data_subf, self.params)
623 |
624 |
625 |
626 | # def load_fingerprints(nsamples=None):
627 | # """ load fingerprint image
628 | # 1 <= m <= 10 (subjects), 1 <= n <= 8 (fingers)
629 | # """
630 | # from PIL import Image
631 |
632 | # img_list = []
633 | # for m in range(1,11):
634 | # for n in range(1,9):
635 | # fname = os.path.join('figure_data/DB2_B', '1%02d_%d.tif'%(m,n))
636 | # arr = np.array(Image.open(fname)) # uint8
637 | # img = arr.astype(np.double)
638 | # img = img[100:-64,28:-28]
639 | # return img[::2,::2]
640 |
641 |
642 |
643 |
644 |
645 |
646 | def load_cifar10_imagesonly(nsamples=None):
647 | X = np.load('figure_data/cifar10_images.npy')
648 | if nsamples == None:
649 | nsamples = X.shape[0]
650 | perm = random_choice(X.shape[0], nsamples, replace=False)
651 | X = X[perm, :]
652 | return X.T
653 |
654 | def load_cifar(nsamples=None):
655 | try:
656 | from pylearn2.utils import serial
657 | except:
658 | raise Exception("pylearn2 must be installed.")
659 | try:
660 | dset = serial.load('figure_data/train.pkl')
661 | except:
662 | raise Exception("Missing data. Download CIFAR!")
663 | X = dset.X
664 | y = dset.y
665 | # Convert to one hot
666 | one_hot = np.zeros((y.shape[0], 10), dtype='float32')
667 | for i in range(y.shape[0]):
668 | one_hot[i, y[i]] = 1
669 |
670 | if nsamples == None:
671 | nsamples = X.shape[0]
672 | perm = random_choice(X.shape[0], nsamples, replace=False)
673 | X = X[perm, :]
674 | X = X.reshape(-1, 3,32,32)
675 | X = X.astype(np.float32)
676 | y = one_hot[perm, :]
677 | return X,y
678 |
679 |
680 | def load_mnist(nsamples=None):
681 | """
682 | Load the MNIST dataset.
683 | """
684 |
685 | try:
686 | data = np.load('figure_data/mnist_train.npz')
687 | except:
688 | raise Exception("Missing data. Download mnist_train.npz from and place in figure_data subdirectory.")
689 | X = data['X']
690 | y = data['y']
691 | if nsamples == None:
692 | nsamples = X.shape[0]
693 | perm = random_choice(X.shape[0], nsamples, replace=False)
694 | X = X[perm, :]
695 | X = X.T
696 | y = y[perm]
697 | return X, y
698 |
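The minibatch split used by the GLM model above can be sketched in isolation. This is a minimal standalone version, assuming only the rule visible in the code (number of subfunctions = ceil(sqrt(number of samples) * ratio), equal-sized index ranges); the function name `split_into_subfunctions` is hypothetical and not part of the repository.

```python
import math

def split_into_subfunctions(num_samples, num_subfunctions_ratio=1.0):
    """Return a list of (start_idx, end_idx) pairs, one per subfunction."""
    # Same rule as in the model: roughly sqrt(dataset size) subfunctions.
    n_sub = int(math.ceil(math.sqrt(num_samples) * num_subfunctions_ratio))
    # Every subfunction gets the same number of samples (rounded down).
    per_sub = num_samples // n_sub
    return [(mb * per_sub, (mb + 1) * per_sub) for mb in range(n_sub)]

refs = split_into_subfunctions(1000)
# Rounding down means the last few samples are not assigned to any subfunction.
unused = 1000 - refs[-1][1]
print(len(refs), refs[0], refs[-1], unused)
```

Because the per-subfunction size is rounded down, up to N-1 trailing samples are left out of every subfunction. That is harmless for large datasets but worth knowing when the dataset is small.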
--------------------------------------------------------------------------------
/generate_figures/nnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sohl-Dickstein/Sum-of-Functions-Optimizer/36c66dfef68b344eca0ca9bb56d0a7bed2075036/generate_figures/nnet/__init__.py
--------------------------------------------------------------------------------
/generate_figures/nnet/conv.yaml:
--------------------------------------------------------------------------------
1 | {
2 | model: !obj:pylearn2.models.mlp.MLP {
3 | batch_size: $batch_size,
4 | layers: [
5 | !obj:pylearn2.models.mlp.ConvRectifiedLinear {
6 | layer_name: 'h0',
7 | #pad: 4,
8 | #tied_b: 1,
9 | W_lr_scale: .05,
10 | b_lr_scale: .05,
11 | output_channels: 48,
12 | #num_pieces: 2,
13 | kernel_shape: [8, 8],
14 | pool_shape: [4, 4],
15 | pool_stride: [2, 2],
16 | irange: .05,
17 | #max_kernel_norm: .9,
18 | },
19 | !obj:pylearn2.models.mlp.ConvRectifiedLinear {
20 | layer_name: 'h1',
21 | #pad: 3,
22 | #tied_b: 1,
23 | W_lr_scale: .05,
24 | b_lr_scale: .05,
25 | output_channels: 128,
26 | #num_pieces: 2,
27 | kernel_shape: [8, 8],
28 | pool_shape: [4, 4],
29 | pool_stride: [2, 2],
30 | irange: .05,
31 | # max_kernel_norm: 1.9365,
32 | },
33 | #!obj:pylearn2.models.maxout.MaxoutConvC01B {
34 | # pad: 3,
35 | # layer_name: 'h2',
36 | # tied_b: 1,
37 | # W_lr_scale: .05,
38 | # b_lr_scale: .05,
39 | # num_channels: 128,
40 | # num_pieces: 2,
41 | # kernel_shape: [5, 5],
42 | # pool_shape: [2, 2],
43 | # pool_stride: [2, 2],
44 | # irange: .005,
45 | # max_kernel_norm: 1.9365,
46 | #},
47 | !obj:pylearn2.models.mlp.RectifiedLinear {
48 | layer_name: 'h3',
49 | irange: .005,
50 | #num_units: 240,
51 | #num_pieces: 5,
52 | dim: 240,
53 | #max_col_norm: 1.9
54 | },
55 | !obj:pylearn2.models.mlp.Softmax {
56 | #max_col_norm: 1.9365,
57 | layer_name: 'y',
58 | n_classes: 10,
59 | irange: .05
60 | }
61 | ],
62 | input_space: !obj:pylearn2.space.Conv2DSpace {
63 | shape: [32, 32],
64 | num_channels: 3,
65 | axes: ['c', 0, 1, 'b'],
66 | },
67 | },
68 | algorithm: !obj:pylearn2.training_algorithms.sgd.SGD {
69 | learning_rate: .1,
70 | learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum {
71 | init_momentum: 0.5,
72 | },
73 | #cost: !obj:pylearn2.costs.mlp.dropout.Dropout {
74 | # input_include_probs: { 'h0' : .8 },
75 | # input_scales: { 'h0': 1. }
76 | #},
77 | termination_criterion: !obj:pylearn2.termination_criteria.MonitorBased {
78 | channel_name: "valid_y_misclass",
79 | prop_decrease: 0.,
80 | N: 100
81 | },
82 | },
83 | save_freq: 1
84 | }
85 |
--------------------------------------------------------------------------------
/generate_figures/nnet/mnist.yaml:
--------------------------------------------------------------------------------
1 | !obj:pylearn2.train.Train {
2 | dataset: &train !obj:pylearn2.datasets.mnist.MNIST {
3 | which_set: 'train',
4 | one_hot: 1,
5 | axes: ['c', 0, 1, 'b'],
6 | start: 0,
7 | stop: 50000
8 | },
9 | model: !obj:pylearn2.models.mlp.MLP {
10 | batch_size: 128,
11 | layers: [
12 | !obj:pylearn2.models.maxout.MaxoutConvC01B {
13 | layer_name: 'h0',
14 | pad: 0,
15 | num_channels: 48,
16 | num_pieces: 2,
17 | kernel_shape: [8, 8],
18 | pool_shape: [4, 4],
19 | pool_stride: [2, 2],
20 | irange: .005,
21 | max_kernel_norm: .9,
22 | },
23 | !obj:pylearn2.models.maxout.MaxoutConvC01B {
24 | layer_name: 'h1',
25 | pad: 3,
26 | num_channels: 48,
27 | num_pieces: 2,
28 | kernel_shape: [8, 8],
29 | pool_shape: [4, 4],
30 | pool_stride: [2, 2],
31 | irange: .005,
32 | max_kernel_norm: 1.9365,
33 | },
34 | !obj:pylearn2.models.maxout.MaxoutConvC01B {
35 | pad: 3,
36 | layer_name: 'h2',
37 | num_channels: 24,
38 | num_pieces: 4,
39 | kernel_shape: [5, 5],
40 | pool_shape: [2, 2],
41 | pool_stride: [2, 2],
42 | irange: .005,
43 | max_kernel_norm: 1.9365,
44 | },
45 | !obj:pylearn2.models.mlp.Softmax {
46 | max_col_norm: 1.9365,
47 | layer_name: 'y',
48 | n_classes: 10,
49 | irange: .005
50 | }
51 | ],
52 | input_space: !obj:pylearn2.space.Conv2DSpace {
53 | shape: [28, 28],
54 | num_channels: 1,
55 | axes: ['c', 0, 1, 'b'],
56 | },
57 | },
58 | algorithm: !obj:pylearn2.training_algorithms.sgd.SGD {
59 | learning_rate: .05,
60 | learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum {
61 | init_momentum: .5,
62 | },
63 | monitoring_dataset:
64 | {
65 | 'valid' : !obj:pylearn2.datasets.mnist.MNIST {
66 | axes: ['c', 0, 1, 'b'],
67 | which_set: 'train',
68 | one_hot: 1,
69 | start: 50000,
70 | stop: 60000
71 | },
72 | },
73 | cost: !obj:pylearn2.costs.mlp.dropout.Dropout {
74 | input_include_probs: { 'h0' : .8 },
75 | input_scales: { 'h0': 1. }
76 | },
77 | termination_criterion: !obj:pylearn2.termination_criteria.MonitorBased {
78 | channel_name: "valid_y_misclass",
79 | prop_decrease: 0.,
80 | N: 100
81 | },
82 | update_callbacks: !obj:pylearn2.training_algorithms.sgd.ExponentialDecay {
83 | decay_factor: 1.00004,
84 | min_lr: .000001
85 | }
86 | },
87 | extensions: [
88 | !obj:pylearn2.train_extensions.best_params.MonitorBasedSaveBest {
89 | channel_name: 'valid_y_misclass',
90 | save_path: "${PYLEARN2_TRAIN_DIR}mnist_best.pkl"
91 | },
92 | !obj:pylearn2.training_algorithms.learning_rule.MomentumAdjustor {
93 | start: 1,
94 | saturate: 250,
95 | final_momentum: .7
96 | }
97 | ]
98 | }
99 |
--------------------------------------------------------------------------------
/generate_figures/nnet/model_gradient.py:
--------------------------------------------------------------------------------
1 | import theano
2 | import theano.tensor as T
3 | import numpy as np
4 | from collections import OrderedDict
5 | from itertools import izip
6 | from string import Template
7 |
8 | from pylearn2.utils import serial, safe_zip
9 | from pylearn2.utils.data_specs import DataSpecsMapping
10 | from pylearn2.config.yaml_parse import load_path, load
11 | from theano.sandbox.cuda import CudaNdarray
12 |
13 | def _tonp(x):
14 | if type(x) not in [CudaNdarray, np.array, np.ndarray]:
15 | x = x.eval()
16 | return np.array(x)
17 |
18 | def load_model(filename, batch_size=100):
19 | out=Template(open(filename, 'r').read()).substitute({'batch_size':batch_size})
20 | stuff = load(out)
21 | model = stuff['model']
22 | #model.batch_size = batch_size
23 | #model.set_batch_size(batch_size)
24 | cost = stuff['algorithm'].cost
25 | if cost is None:
26 | cost = model.get_default_cost()
27 | mg = ModelGradient(model, cost,batch_size=batch_size)
28 | return mg
29 |
30 |
31 | class ModelGradient:
32 | def __init__(self, model, cost=None, batch_size=100):
33 | self.model = model
34 | self.model.set_batch_size(batch_size)
35 | self.model._test_batch_size = batch_size
36 | print 'it should really be ', batch_size
37 | self.cost = cost
38 | self.batch_size = batch_size
39 | self.setup()
40 |
41 | def setup(self):
42 | self.X = T.matrix('X')
43 | self.Y = T.matrix('Y')
44 |
45 | # Taken from pylearn2/training_algorithms/sgd.py
46 |
47 |
48 | data_specs = self.cost.get_data_specs(self.model)
49 | mapping = DataSpecsMapping(data_specs)
50 | space_tuple = mapping.flatten(data_specs[0], return_tuple=True)
51 | source_tuple = mapping.flatten(data_specs[1], return_tuple=True)
52 |
53 | # Build a flat tuple of Theano Variables, one for each space.
54 | # We want that so that if the same space/source is specified
55 | # more than once in data_specs, only one Theano Variable
56 | # is generated for it, and the corresponding value is passed
57 | # only once to the compiled Theano function.
58 | theano_args = []
59 | for space, source in safe_zip(space_tuple, source_tuple):
60 | name = '%s[%s]' % (self.__class__.__name__, source)
61 | arg = space.make_theano_batch(name=name, batch_size = self.batch_size)
62 | theano_args.append(arg)
63 | print 'BATCH SIZE=',self.batch_size
64 | theano_args = tuple(theano_args)
65 |
66 | # Methods of `self.cost` need args to be passed in a format compatible
67 | # with data_specs
68 | nested_args = mapping.nest(theano_args)
69 | print self.cost
70 | fixed_var_descr = self.cost.get_fixed_var_descr(self.model, nested_args)
71 | print self.cost
72 | self.on_load_batch = fixed_var_descr.on_load_batch
73 | params = list(self.model.get_params())
74 | self.X = nested_args[0]
75 | self.Y = nested_args[1]
76 | init_grads, updates = self.cost.get_gradients(self.model, nested_args)
77 |
78 | params = self.model.get_params()
79 | # We need to replace parameters with purely symbolic variables in case some are shared
80 | # Create gradient and cost functions
81 | self.params = params
82 | symbolic_params = [self._convert_variable(param) for param in params]
83 | givens = dict(zip(params, symbolic_params))
84 | costfn = self.model.cost_from_X((self.X, self.Y))
85 | gradfns = [init_grads[param] for param in params]
86 | #self.symbolic_params = symbolic_params
87 | #self._loss = theano.function(symbolic_para[self.X, self.Y], self.model.cost_from_X((self.X, self.Y)))#, givens=givens)
88 | #1/0
89 | print 'Compiling function...'
90 | self.theano_f_df = theano.function(inputs=symbolic_params + [self.X, self.Y], outputs=[costfn] + gradfns, givens=givens)
91 | print 'done'
92 | # self.grads = theano.function(symbolic_params + [self.X, self.Y], [init_grads[param] for param in params], givens=givens)
93 | # self._loss = theano.function(symbolic_params + [self.X, self.Y], self.model.cost(self.X, self.Y), givens=givens)
94 | # Maps params -> their derivative
95 |
96 | def f_df(self, theta, args):
97 | X = args[0]
98 | y = args[1]
99 | X = np.transpose(X,(1,2,3,0))
100 | y = y
101 | nsamples = X.shape[-1]
102 | nbatches = nsamples / self.batch_size
103 | # let's hope it's actually divisible
104 | #XXX(ben): fix this
105 | cost = 0.0
106 | thetas = None
107 | idxs = np.array_split(np.arange(nsamples), nbatches)
108 | for idx in idxs:
109 | theano_args = theta + [X[...,idx], y[idx,...]]
110 | results = self.theano_f_df(*theano_args)
111 | results = [_tonp(result) for result in results]
112 | if thetas is None:
113 | thetas = results[1:]
114 | else:
115 | thetas = [np.array(t) + np.array(result) for t,result in zip(thetas,results[1:])]
116 | cost += results[0]/nbatches
117 |
118 | thetas = [np.array(theta)/nbatches for theta in thetas]
119 |
120 | # if X.shape[1] != self.batch_size:
121 |
122 |
123 | # print X.shape, y.shape
124 | # theano_args = theta + [X,y]
125 | # results = self.theano_f_df(*theano_args)
126 | # return results[0], results[1:]
127 | return cost, thetas
128 |
129 | def _convert_variable(self, x):
130 | return T.TensorType(x.dtype, x.broadcastable)(x.name) #'int32', broadcastable=())('myvar')
131 |
132 |
133 |
134 | if __name__ == '__main__':
135 | # Load train obj
136 | f = np.load('test.npz')
137 | model = f['model'].item()
138 | cost = f['cost'].item()
139 | m = ModelGradient(model, cost)
140 | p = model.weights.shape.eval()[0]
141 | X = np.random.randn(100,p).astype(np.float32)
142 | theta = m.get_params()
143 | m.f_df(theta, X)
144 | 1/0
145 |
146 | # Load cost function
147 | W = np.random.randn(10,10)
148 | W = theano.shared(W)
149 |
150 | cost = (W**2).sum()
151 | grad = T.grad(cost, W)
152 |
153 | params = [W]
154 | 1/0
155 |
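`load_model` fills the `$batch_size` placeholder in the YAML text with `string.Template.substitute` before parsing it. The sketch below uses a made-up one-line config string, not one of the repository's YAML files, to show how that call behaves when the text also contains a placeholder that is not supplied, such as `${PYLEARN2_TRAIN_DIR}`.

```python
from string import Template

# Hypothetical config fragment containing two placeholders.
text = 'batch_size: $batch_size, save_path: "${PYLEARN2_TRAIN_DIR}best.pkl"'
tmpl = Template(text)

# substitute() raises KeyError unless every placeholder is supplied.
try:
    tmpl.substitute({'batch_size': 100})
    strict_ok = True
except KeyError:
    strict_ok = False

# safe_substitute() fills what it can and leaves the rest untouched.
filled = tmpl.safe_substitute({'batch_size': 100})
print(strict_ok)
print(filled)
```

With `substitute`, a config file has to contain only the placeholders that are passed in. `safe_substitute` is the more forgiving choice when other `$`-style variables need to survive for a later processing stage.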
--------------------------------------------------------------------------------
/generate_figures/optimization_wrapper.py:
--------------------------------------------------------------------------------
1 | """
2 | Trains the model using a variety of optimization algorithms.
3 | This class also wraps the objective and gradient of the model,
4 | so that it can store a history of the objective during
5 | optimization.
6 |
7 | This is slower than directly calling the optimizers, because
8 | it periodically evaluates (and stores) the FULL objective rather
9 | than always evaluating only a single subfunction per update step.
10 |
11 | Designed to be used by the figure_compare_*.py scripts.
12 |
13 | Copyright 2014 Jascha Sohl-Dickstein
14 | Licensed under the Apache License, Version 2.0 (the "License");
15 | you may not use this file except in compliance with the License.
16 | You may obtain a copy of the License at
17 | http://www.apache.org/licenses/LICENSE-2.0
18 | Unless required by applicable law or agreed to in writing, software
19 | distributed under the License is distributed on an "AS IS" BASIS,
20 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21 | See the License for the specific language governing permissions and
22 | limitations under the License.
23 | """
24 |
25 | # allow SFO to be imported despite being in the parent directory
26 | import sys
27 | sys.path.append("..")
28 | sys.path.append(".")
29 |
30 | from sfo import SFO
31 | from sag import SAG
32 | from adagrad import ADAGrad
33 | from collections import defaultdict
34 | from itertools import product
35 | import numpy as np
36 | from scipy.optimize import fmin_l_bfgs_b
37 |
38 | # numpy < 1.7 does not have np.random.choice
39 | def my_random_choice(n, k, replace):
40 | perm = np.random.permutation(n)
41 | return perm[:k]
42 | if hasattr(np.random, 'choice'):
43 | random_choice = np.random.choice
44 | else:
45 | random_choice = my_random_choice
46 |
47 |
48 | class train:
49 | """
50 | Trains the model using a variety of optimization algorithms.
51 | This class also wraps the objective and gradient of the model,
52 | so that it can evaluate and store the full objective for each
53 | step in the optimization.
54 |
55 | This is WAY SLOWER than just calling the optimizers, because
56 | several times per pass it evaluates the FULL objective in addition
57 | to the single subfunction requested by the optimizer.
58 |
59 | Designed to be used by the figure_compare_*.py scripts.
60 | """
61 |
62 | def __init__(self, model, calculate_full_objective=True, num_projection_dims=5, full_objective_per_pass=4):
63 | """
64 | Set up the history storage and the parameter flattening map.
65 |
66 | model must provide f_df, theta_init, subfunction_references and
67 | full_objective_references.
68 | calculate_full_objective - if True, record the objective summed over
69 | model.full_objective_references; otherwise record the subfunction value.
70 | num_projection_dims - number of random directions onto which the
71 | parameters are projected when recording the optimization trajectory.
72 | full_objective_per_pass - how many times per pass through the
73 | subfunctions the full objective is recorded.
74 | """
75 |
76 | self.model = model
77 | self.history = {'f':defaultdict(list), 'x_projection':defaultdict(list), 'events':defaultdict(list), 'x':defaultdict(list)}
78 |
79 | # we use SFO to flatten/unflatten parameters for the other optimizers
80 | self.x_map = SFO(self.model.f_df, self.model.theta_init, self.model.subfunction_references)
81 | self.xinit_flat = self.x_map.theta_original_to_flat(self.model.theta_init)
82 | self.calculate_full_objective = calculate_full_objective
83 |
84 | M = self.xinit_flat.shape[0]
85 | self.x_projection_matrix = np.random.randn(num_projection_dims, M)/np.sqrt(M)
86 |
87 | self.num_subfunctions = len(self.model.subfunction_references)
88 | self.full_objective_period = int(self.num_subfunctions/full_objective_per_pass)
89 |
90 |
91 | def f_df_wrapper(self, *args, **kwargs):
92 | """
93 | This (slightly hacky) function stands between the optimizer and the objective function.
94 | It evaluates the objective on the full function once every full_objective_period
95 | subfunction evaluations, and stores the history of the full objective function value.
96 | """
97 |
98 | ## call the true subfunction objective function, passing through all parameters
99 | f, df = self.model.f_df(*args, **kwargs)
100 |
101 | if len(self.history['f'][self.learner_name]) == 0:
102 | # this is the first time step for this learner
103 | self.last_f = np.inf
104 | self.last_idx = -1
105 | self.nsteps_this_learner = 0
106 |
107 | self.nsteps_this_learner += 1
108 | # only record the step once every self.full_objective_period steps
109 | if np.mod(self.nsteps_this_learner, self.full_objective_period) != 1 and self.full_objective_period > 1:
110 | return f, df
111 |
112 | # the full objective function on all subfunctions
113 | if self.calculate_full_objective:
114 | new_f = 0.
115 | for ref in self.model.full_objective_references:
116 | new_f += self.model.f_df(args[0], ref)[0]
117 | else:
118 | new_f = f
119 |
120 | events = dict() # holds anything special about this step
121 | # a unique identifier for the current subfunction
122 | new_idx = id(args[1])
123 | if 'SFO' in self.learner_name:
124 | events = dict(self.optimizer.events)
125 | # append the full objective value, projections, etc to the history
126 | self.history['f'][self.learner_name].append(new_f)
127 | x_proj = np.dot(self.x_projection_matrix, self.x_map.theta_original_to_flat(args[0])).ravel()
128 | self.history['x_projection'][self.learner_name].append(x_proj)
129 | self.history['events'][self.learner_name].append(events)
130 | self.history['x'][self.learner_name] = args[0]
131 | print("full f %g"%(new_f))
132 | # store the prior values
133 | self.last_f = new_f
134 | self.last_idx = new_idx
135 |
136 | return f, df
137 |
138 |
139 | def f_df_wrapper_flattened(self, x_flat, subfunction_references, *args, **kwargs):
140 | """
141 | Calculate the subfunction objective and gradient.
142 | Takes a 1d parameter vector, and returns a 1d gradient, even
143 | if the parameters for f_df are a list or a dictionary.
144 | x_flat should be the flattened version of the parameters.
145 | """
146 |
147 | x = self.x_map.theta_flat_to_original(x_flat)
148 | f = 0.
149 | df = 0.
150 | for sr in subfunction_references:
151 | fl, dfl = self.f_df_wrapper(x, sr, *args, **kwargs)
152 | dfl_flat = self.x_map.theta_original_to_flat(dfl)
153 | f += fl
154 | df += dfl_flat
155 | return f, df.ravel()
156 |
157 |
158 | def SGD(self, num_passes=20):
159 | """ Train model using SGD with various learning rates """
160 |
161 | # get the number of minibatches
162 | N = len(self.model.subfunction_references)
163 | # step through all the hyperparameters. eta is step length.
164 | for eta in 10**np.linspace(-5,2,8):
165 | # label this convergence trace using the optimizer name and hyperparameter
166 | self.learner_name = "SGD %.4f"%eta
167 | print("\n\n" + self.learner_name)
168 |
169 | # initialize the parameters
170 | x = self.xinit_flat.copy()
171 | ## perform stochastic gradient descent
172 | for _ in range(num_passes*N): # number of minibatch evaluations
173 | # choose a minibatch at random
174 | idx = np.random.randint(N)
175 | sr = self.model.subfunction_references[idx]
176 | # evaluate the objective and gradient for that minibatch
177 | fl, dfl = self.f_df_wrapper_flattened(x.reshape((-1,1)), (sr,))
178 | # update the parameters
179 | x -= dfl.reshape(x.shape) * eta
180 | # if the objective has diverged, skip the rest of the run for this hyperparameter
181 | if not np.isfinite(fl):
182 | print("Non-finite subfunction.")
183 | break
184 |
185 |
186 | def LBFGS(self, num_passes=20):
187 | """ Train model using LBFGS """
188 |
189 | self.learner_name = "LBFGS"
190 | print("\n\n" + self.learner_name)
191 | _, _, _ = fmin_l_bfgs_b(
192 | self.f_df_wrapper_flattened,
193 | self.xinit_flat.copy(),
194 | disp=1,
195 | args=(self.model.subfunction_references, ),
196 | maxfun=num_passes)
197 |
198 |
199 | def SFO(self, num_passes=20, learner_name='SFO', **kwargs):
200 | """ Train model using SFO."""
201 | self.learner_name = learner_name
202 | print("\n\n" + self.learner_name)
203 |
204 | self.optimizer = SFO(self.f_df_wrapper, self.model.theta_init, self.model.subfunction_references, **kwargs)
205 | # # check the gradients
206 | # self.optimizer.check_grad()
207 | x = self.optimizer.optimize(num_passes=num_passes)
208 |
209 |
210 | def SFO_variations(self, num_passes=20):
211 | """
212 | Train model using several variations on the standard SFO algorithm.
213 | """
214 |
215 | np.random.seed(0) # make experiments repeatable
216 | self.learner_name = 'SFO standard'
217 | print("\n\n" + self.learner_name)
218 | self.optimizer = SFO(self.f_df_wrapper, self.model.theta_init, self.model.subfunction_references)
219 | x = self.optimizer.optimize(num_passes=num_passes)
220 |
221 | np.random.seed(0) # make experiments repeatable
222 | self.learner_name = 'SFO all active'
223 | print("\n\n" + self.learner_name)
224 | self.optimizer = SFO(self.f_df_wrapper, self.model.theta_init, self.model.subfunction_references,
225 | init_subf=len(self.model.subfunction_references))
226 | x = self.optimizer.optimize(num_passes=num_passes)
227 |
228 | np.random.seed(0) # make experiments repeatable
229 | self.learner_name = 'SFO rank 1'
230 | print("\n\n" + self.learner_name)
231 | self.optimizer = SFO(self.f_df_wrapper, self.model.theta_init, self.model.subfunction_references,
232 | hessian_algorithm='rank1')
233 | x = self.optimizer.optimize(num_passes=num_passes)
234 |
235 | self.learner_name = 'SFO random'
236 | print("\n\n" + self.learner_name)
237 | self.optimizer = SFO(self.f_df_wrapper, self.model.theta_init, self.model.subfunction_references,
238 | subfunction_selection='random'
239 | )
240 | x = self.optimizer.optimize(num_passes=num_passes)
241 |
242 | self.learner_name = 'SFO cyclic'
243 | print("\n\n" + self.learner_name)
244 | self.optimizer = SFO(self.f_df_wrapper, self.model.theta_init, self.model.subfunction_references,
245 | subfunction_selection='cyclic'
246 | )
247 | x = self.optimizer.optimize(num_passes=num_passes)
248 |
249 |
250 | def SAG(self, num_passes=20):
251 | """ Train model using SAG for various values of the Lipschitz constant L (L_freq is left at 0, so L is not adapted) """
252 |
253 | # larger L is easier, so start large
254 | for L in 10**(-np.linspace(-3, 3, 7)):
255 | self.learner_name = "SAG %.4f"%L
256 | #learner_name = "SAG (diverges)"
257 | print("\n\n" + self.learner_name)
258 | self.optimizer = SAG(self.f_df_wrapper_flattened, self.xinit_flat.copy(), self.model.subfunction_references, L=L)
259 | x = self.optimizer.optimize(num_passes=num_passes)
260 | print(np.mean(self.optimizer.f), "average value at last evaluation")
261 |
262 |
263 | def LBFGS_minibatch(self, num_passes=20, data_fraction=0.1, num_steps=10):
264 | """ Perform LBFGS on minibatches of size data_fraction of the full dataset, and with num_steps LBFGS steps per minibatch."""
265 |
266 | self.learner_name = "LBFGS minibatch"
267 |
268 |
269 | x = self.xinit_flat.copy()
270 | for epoch in range(num_passes):
271 | idx = random_choice(len(self.model.subfunction_references),
272 | int(data_fraction*len(self.model.subfunction_references)),
273 | replace=False)
274 | sr = []
275 | for ii in idx:
276 | sr.append(self.model.subfunction_references[ii])
277 | x, _, _ = fmin_l_bfgs_b(
278 | self.f_df_wrapper_flattened,
279 | x,
280 | args=(sr, ),
281 | disp=1,
282 | maxfun=num_steps)
283 |
284 |
285 | def SGD_momentum(self, num_passes=20):
286 | """ Train model using SGD with various learning rates and momentums"""
287 |
288 | learning_rates = 10**np.linspace(-5,2,8)
289 | momentums = np.array([0.5, 0.9, 0.95, 0.99])
290 | params = product(learning_rates, momentums)
291 | N = len(self.model.subfunction_references)
292 | for eta, momentum in params:
293 | self.learner_name = "SGD_momentum eta=%.5f, mu=%.2f" % (eta, momentum)
294 | print("\n\n" + self.learner_name)
295 | f = np.ones((N))*np.nan
296 | x = self.xinit_flat.copy()
297 | # Previous step
298 | inc = 0.0
299 | for epoch in range(num_passes):
300 | for minibatch in range(N):
301 | idx = np.random.randint(N)
302 | sr = self.model.subfunction_references[idx]
303 | fl, dfl = self.f_df_wrapper_flattened(x.reshape((-1,1)), (sr,))
304 | inc = momentum * inc - eta * dfl.reshape(x.shape)
305 | x += inc
306 | f[idx] = fl
307 | if not np.isfinite(fl):
308 | print("Non-finite subfunction. Ending run.")
309 | break
310 | if not np.isfinite(fl):
311 | print("Non-finite subfunction. Ending run.")
312 | break
313 | print(np.mean(f[np.isfinite(f)]), "average finite value at last evaluation")
314 |
315 |
316 | def ADA(self, num_passes=20):
317 | """ Train model using AdaGrad with various learning rates """
318 |
319 | for eta in 10**np.linspace(-3,1,5):
320 | self.learner_name = "ADAGrad %.4f"%eta
321 | print("\n\n" + self.learner_name)
322 | self.optimizer = ADAGrad(self.f_df_wrapper_flattened, self.xinit_flat.copy(), self.model.subfunction_references, learning_rate=eta)
323 | x = self.optimizer.optimize(num_passes=num_passes)
324 | print(np.mean(self.optimizer.f), "average value at last evaluation")
325 |
326 |
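The recording logic in `f_df_wrapper` can be illustrated without SFO or a real model. The sketch below is a simplified stand-in, not the class above: `RecordingWrapper` and the toy quadratic `f_df` are hypothetical names, and the parameter is a single float. It keeps the same rule of recording the full objective on the first call and then once every `period` subfunction evaluations.

```python
class RecordingWrapper(object):
    """Pass calls through to f_df and periodically record the full objective."""

    def __init__(self, f_df, references, period):
        self.f_df = f_df
        self.references = references
        self.period = period
        self.nsteps = 0
        self.history = []

    def __call__(self, theta, ref):
        # Always return the value and gradient of the requested subfunction.
        f, df = self.f_df(theta, ref)
        self.nsteps += 1
        # Skip the expensive full evaluation except once per period.
        if self.period > 1 and self.nsteps % self.period != 1:
            return f, df
        full_f = sum(self.f_df(theta, r)[0] for r in self.references)
        self.history.append(full_f)
        return f, df

def f_df(theta, target):
    # Toy subfunction: 0.5*(theta - target)^2 and its derivative.
    return 0.5 * (theta - target) ** 2, theta - target

targets = [1.0, 2.0, 3.0, 4.0]
wrapped = RecordingWrapper(f_df, targets, period=2)
for step in range(8):
    wrapped(0.0, targets[step % 4])
print(len(wrapped.history), wrapped.history[0])
```

Since the optimizer only ever sees the pass-through return value, the wrapper changes the cost of a run but not its trajectory.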
--------------------------------------------------------------------------------
/generate_figures/sag.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2014 Jascha Sohl-Dickstein
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | http://www.apache.org/licenses/LICENSE-2.0
7 | Unless required by applicable law or agreed to in writing, software
8 | distributed under the License is distributed on an "AS IS" BASIS,
9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and
11 | limitations under the License.
12 |
13 | This is an implementation of the Stochastic Average Gradient (SAG) algorithm:
14 | Le Roux, Nicolas, Mark Schmidt, and Francis Bach.
15 | "A Stochastic Gradient Method with an Exponential Convergence Rate for Finite Training Sets."
16 | Advances in Neural Information Processing Systems 25. 2012.
17 | http://books.nips.cc/papers/files/nips25/NIPS2012_1246.pdf
18 | """
19 |
20 | from numpy import *
21 | import scipy.linalg
22 |
23 | class SAG(object):
24 |
25 | def __init__(self, f_df, theta, subfunction_references, args=(), kwargs={}, L=1., L_freq=0):
26 | """
27 | L is the Lipschitz constant. Smaller corresponds to faster (but
28 | noisier) learning. If L_freq is greater than 0, then every L_freq
29 | steps L will be adjusted as described in the SAG paper.
30 | """
31 | self.L = L
32 | self.L_freq = L_freq # how often to perform the extra function evaluations in order to test L
33 |
34 | self.N = len(subfunction_references)
35 | self.sub_ref = subfunction_references
36 | self.M = theta.shape[0]
37 | self.f_df = f_df
38 | self.args = args
39 | self.kwargs = kwargs
40 |
41 | self.num_steps = 0
42 | self.theta = theta.copy().reshape((-1,1))
43 |
44 | self.f = ones((self.N))*nan
45 | self.df = zeros((self.M,self.N))
46 |
47 | self.grad_sum = sum( self.df, axis=1 ).reshape((-1,1))
48 |
49 | def optimize(self, num_passes = 10, num_steps = None):
50 | if num_steps==None:
51 | num_steps = num_passes*self.N
52 | for i in range(num_steps):
53 | if not self.optimization_step():
54 | break
55 | return self.theta
56 |
57 | def optimization_step(self):
58 | # choose a subfunction at random
59 | ind = int(floor(self.N*random.rand()))
60 |
61 | # calculate the objective function and gradient for the subfunction
62 | fl, dfl = self.f_df(self.theta, (self.sub_ref[ind], ), *self.args, **self.kwargs)
63 | dfl = dfl.reshape(-1,1)
64 | # store them
65 | self.f[ind] = fl
66 | self.grad_sum += (dfl - self.df[:,[ind]]) # TODO this may slowly accumulate errors. occasionally do the full sum?
67 | self.df[:,ind] = dfl.flat
68 |
69 | # only adjust the learning rate with frequency L_freq, to reduce computational load
70 | if self.L_freq > 0 and mod(self.num_steps, self.L_freq) == 0:
71 | # adapt the learning rate
72 | self.L *= 2**(-1. / self.N)
73 | theta_shift = self.theta - dfl / self.L
74 | # evaluate the objective function at theta_shift
75 | fl_shift, _ = self.f_df(theta_shift, (self.sub_ref[ind], ), *self.args, **self.kwargs)
76 | # test whether the change in objective satisfies the Lipschitz inequality, otherwise increase the constant
77 | if fl_shift - fl > -sum((dfl)**2) / (2.*self.L):
78 | self.L *= 2.
79 |
80 | self.num_steps += 1
81 |
82 | # take a gradient descent step
83 | div = min([self.num_steps, self.N])
84 | delta_theta = -self.grad_sum / self.L / div
85 | self.theta += delta_theta
86 |
87 | if not isfinite(fl):
88 | print("Non-finite subfunction. Ending run.")
89 | return False
90 | return True
91 |
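The core SAG update in `optimization_step` (keep the most recent gradient of each subfunction, maintain their running sum, and step along the average) can be sketched on a toy problem. This is a simplified pure-Python illustration with a scalar parameter and no adaptation of L. The names `sag_minimize` and `f_df`, and the conservative value `L=16`, are choices made for this example only.

```python
import random

def sag_minimize(f_df, theta, refs, L=1.0, num_passes=10, seed=0):
    """Minimal scalar SAG loop; returns the final parameter value."""
    rng = random.Random(seed)
    n = len(refs)
    stored = [0.0] * n   # most recent gradient seen for each subfunction
    grad_sum = 0.0       # running sum of the stored gradients
    for step in range(1, num_passes * n + 1):
        i = rng.randrange(n)
        _, g = f_df(theta, refs[i])
        # Swap the old gradient of subfunction i for the new one in the sum.
        grad_sum += g - stored[i]
        stored[i] = g
        # Step along the averaged gradient; divide by the step count early on.
        theta -= grad_sum / L / min(step, n)
    return theta

def f_df(theta, target):
    # Toy subfunction: 0.5*(theta - target)^2 and its derivative.
    return 0.5 * (theta - target) ** 2, theta - target

targets = [1.0, 2.0, 3.0, 4.0]
theta = sag_minimize(f_df, 0.0, targets, L=16.0, num_passes=200)
print(theta)  # close to the mean of the targets, 2.5
```

Updating the sum incrementally makes each step cost one subfunction gradient, at the price of storing one gradient per subfunction. In the class above that storage is the M-by-N array `self.df`.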
--------------------------------------------------------------------------------
/generate_figures/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from theano.sandbox.cuda import CudaNdarraySharedVariable as CudaNdarray
3 |
4 | #XXX: Shamefully stolen from pyautodiff
5 | # https://github.com/jaberg/pyautodiff/blob/master/autodiff/fmin_scipy.py
6 |
7 | def _tonp(x):
8 | try:
9 | x = x.eval()
10 | except:
11 | pass
12 | #if not np.isscalar(x) and type(x) not in [CudaNdarray, np.array, np.ndarray]:
13 | #x = x.eval()
14 | return np.array(x)
15 | #return np.array(x)
16 | #elif type(x) in [np.array, np.ndarray]:
17 | # return x
18 | #else:
19 | # return np.array(x.eval())
20 |
21 | def vector_from_args(raw_args):
22 | args = [_tonp(a) for a in raw_args]
23 | args_sizes = [w.size for w in args]
24 | x_size = sum(args_sizes)
25 | x = np.empty(x_size, dtype=args[0].dtype) # has to be float64 for fmin_l_bfgs_b
26 | i = 0
27 | for w in args:
28 | x[i: i + w.size] = w.flatten()
29 | i += w.size
30 | return x
31 |
32 | def args_from_vector(x, orig_args):
33 | #if type(orig_args[0]) != np.ndarray:
34 | # orig_args = [a.eval() for a in orig_args]
35 | # unpack x_opt -> args-like structure `args_opt`
36 | rval = []
37 | i = 0
38 | for w in orig_args:
39 | size = _tonp(w.size)
40 | rval.append(x[i: i + size].reshape(_tonp(w.shape)).astype(w.dtype))
41 | i += size
42 | return rval
43 |
--------------------------------------------------------------------------------
/sfo.m:
--------------------------------------------------------------------------------
1 | % Implements the Sum of Functions Optimizer (SFO), as described in the paper:
2 | % Jascha Sohl-Dickstein, Ben Poole, and Surya Ganguli
3 | % An adaptive low dimensional quasi-Newton sum of functions optimizer
4 | % arXiv preprint arXiv:1311.2115 (2013)
5 | % http://arxiv.org/abs/1311.2115
6 | %
7 | % Sample code is provided in sfo_demo.m
8 | %
9 | % Useful functions in this class are:
10 | % obj = sfo(f_df, theta, subfunction_references, varargin)
11 | % Initializes the optimizer class.
12 | % Parameters:
13 | % f_df - Returns the function value and gradient for a single subfunction
14 | % call. Should have the form
15 | % [f, dfdtheta] = f_df(theta, subfunction_references{idx},
16 | % varargin{:})
17 | % where idx is the index of a single subfunction.
18 | % theta - The initial parameters to be used for optimization. theta can
19 | % be either a vector, a matrix, or a cell array with a vector or
20 | % matrix in every cell. The gradient returned by f_df should have the
21 | % same form as theta.
22 | % subfunction_references - A cell array containing an identifying element
23 | % for each subfunction. The elements in this list could be, eg,
24 | % matrices containing minibatches, or indices identifying the
25 | % subfunction, or filenames from which target data should be read.
26 | % If each subfunction corresponds to a minibatch, then the number of
27 | % subfunctions should be approximately
28 | % [number subfunctions] = sqrt([dataset size])/10.
29 | % varargin - Any additional parameters will be passed through to f_df
30 | % each time it is called.
31 | % Returns:
32 | % obj - sfo class instance.
33 | % theta = optimize(num_passes)
34 | % Optimize the objective function.
35 | % Parameters:
36 | % num_passes - The number of effective passes through
37 | % subfunction_references to perform.
38 | % Returns:
39 | % theta - The estimated parameter vector after num_passes of optimization.
40 | % check_grad()
41 | % Numerically checks the gradient of f_df for each element in
42 | % subfunction_references.
43 | %
44 | % Copyright 2014 Jascha Sohl-Dickstein
45 | %
46 | % Licensed under the Apache License, Version 2.0 (the "License");
47 | % you may not use this file except in compliance with the License.
48 | % You may obtain a copy of the License at
49 |
50 | % http://www.apache.org/licenses/LICENSE-2.0
51 |
52 | % Unless required by applicable law or agreed to in writing, software
53 | % distributed under the License is distributed on an "AS IS" BASIS,
54 | % WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55 | % See the License for the specific language governing permissions and
56 | % limitations under the License.
57 |
58 | classdef sfo < handle
59 |
60 | properties
61 | display = 2;
62 | f_df;
63 | args;
64 | max_history = 10;
65 | max_gradient_noise = 1;
66 | hess_max_dev = 1e8;
67 | hessian_init = 1e5;
68 | N;
69 | subfunction_references;
70 | % theta, in its original format;
71 | theta_original;
72 | % theta, flattened into a 1d array;
73 | theta;
74 | % theta from the previous learning step -- initialize to theta;
75 | theta_prior_step;
76 | % number of data dimensions;
77 | M;
78 | % the update steps will be rescaled by this;
79 | step_scale = 1;
80 | % 'very small' for various tasks, most importantly identifying when;
81 | % update steps or gradient changes are too small to be used for Hessian
82 | % updates without incurring large numerical errors.;
83 | eps = 1e-12;
84 |
85 | % The shortest step length allowed. Any update
86 | % steps shorter than this will be made this length. Set this so as to
87 | % prevent numerical errors when computing the difference in gradients
88 | % before and after a step.
89 | minimum_step_length = 1e-8;
90 | % The length of the longest allowed update step,
91 | % relative to the average length of prior update steps. Takes effect
92 | % after the first full pass through the data.
93 | max_step_length_ratio = 10;
94 |
95 | % the min & max dimensionality for the subspace;
96 | K_min;
97 | K_max;
98 | % the current dimensionality of the subspace;
99 | K_current = 1;
100 | % obj.P holds the subspace;
101 | P;
102 |
103 | % store the minimum & maximum eigenvalue from each approximate;
104 | % Hessian;
105 | min_eig_sub;
106 | max_eig_sub;
107 |
108 | % store the total time spent in optimization, & the amount of time;
109 | % spent in the objective function;
110 | time_pass = 0.;
111 | time_func = 0.;
112 |
113 | % how many steps since the active set size was increased;
114 | iter_since_active_growth = 0;
115 |
116 | % which subfunctions are active;
117 | init_subf = 2;
118 | active;
119 |
120 | % the total path length traveled during optimization;
121 | total_distance = 0.;
122 | % number of function evaluations for each subfunction;
123 | eval_count;
124 |
125 | % theta projected into current working subspace;
126 | theta_proj;
127 | % holds the last position & the last gradient for all the objective functions;
128 | last_theta;
129 | last_df;
130 | % the history of theta changes for each subfunction;
131 | hist_deltatheta;
132 | % the history of gradient changes for each subfunction;
133 | hist_deltadf;
134 | % the history of function values for each subfunction;
135 | hist_f;
136 | % a flat history of all returned subfunction values for debugging/diagnostics;
137 | hist_f_flat = [];
138 |
139 | % the approximate Hessian for each subfunction is stored;
140 | % as b(:,:,index) * b(:,:,index).'
141 | b;
142 |
143 | % the full Hessian (sum over all the subfunctions);
144 | full_H = 0;
145 |
146 | % parameters that are passed through to f_df
147 | varargin_stored = {};
148 |
149 | % the predicted improvement in the total objective from the current update step
150 | f_predicted_total_improvement = 0;
151 | end
152 |
153 | methods
154 |
155 | function obj = sfo(f_df, theta, subfunction_references, varargin)
156 | % obj = sfo(f_df, theta, subfunction_references, varargin)
157 | % Initializes the optimizer class.
158 | % Parameters:
159 | % f_df - Returns the function value and gradient for a single subfunction
160 | % call. Should have the form
161 | % [f, dfdtheta] = f_df(theta, subfunction_references{idx},
162 | % varargin{:})
163 | % where idx is the index of a single subfunction.
164 | % theta - The initial parameters to be used for optimization. theta can
165 | % be either a vector, a matrix, or a cell array with a vector or
166 | % matrix in every cell. The gradient returned by f_df should have the
167 | % same form as theta.
168 | % subfunction_references - A cell array containing an identifying element
169 | % for each subfunction. The elements in this list could be, eg,
170 | % matrices containing minibatches, or indices identifying the
171 | % subfunction, or filenames from which target data should be read.
172 | % varargin - Any additional parameters will be passed through to f_df
173 | % each time it is called.
174 | % Returns:
175 | % obj - sfo class instance.
176 |
177 | obj.N = length(subfunction_references);
178 | obj.theta_original = theta;
179 | obj.theta = obj.theta_original_to_flat(obj.theta_original);
180 | obj.theta_prior_step = obj.theta;
181 | obj.M = length(obj.theta);
182 | obj.f_df = f_df;
183 | obj.varargin_stored = varargin;
184 | obj.subfunction_references = subfunction_references;
185 |
186 | subspace_dimensionality = 2.*obj.N+2; % 2 to include current location;
187 | % subspace can't be larger than the full space;
188 | subspace_dimensionality = min([subspace_dimensionality, obj.M]);
189 |
190 |
191 | % the min & max dimensionality for the subspace;
192 | obj.K_min = subspace_dimensionality;
193 | obj.K_max = ceil(obj.K_min.*1.5);
194 | obj.K_max = min([obj.K_max, obj.M]);
195 | % obj.P holds the subspace;
196 | obj.P = zeros(obj.M,obj.K_max);
197 |
198 | % store the minimum & maximum eigenvalue from each approximate;
199 | % Hessian;
200 | obj.min_eig_sub = zeros(obj.N,1);
201 | obj.max_eig_sub = zeros(obj.N,1);
202 |
203 | % which subfunctions are active;
204 | obj.active = false(obj.N,1);
205 | obj.init_subf = min(obj.N, obj.init_subf);
206 | inds = randperm(obj.N, obj.init_subf);
207 | obj.active(inds) = true;
208 | obj.min_eig_sub(inds) = obj.hessian_init;
209 | obj.max_eig_sub(inds) = obj.hessian_init;
210 |
211 | % number of function evaluations for each subfunction;
212 | obj.eval_count = zeros(obj.N,1);
213 |
214 | % set the first column of the subspace to be the initial;
215 | % theta;
216 | rr = sqrt(sum(obj.theta.^2));
217 | if rr > 0;
218 | obj.P(:,1) = obj.theta/rr;
219 | else
220 | % initial theta is 0 -- initialize randomly;
221 | obj.P(:,1) = randn(obj.M,1);
222 | obj.P(:,1) = obj.P(:,1) / sqrt(sum(obj.P(:,1).^2));
223 | end
224 |
225 | if obj.M == obj.K_max
226 | % if the subspace spans the full space, then just make;
227 | % P the identity matrix;
228 | if obj.display > 1;
229 | fprintf('subspace spans full space');
230 | end
231 | obj.P = eye(obj.M);
232 | obj.K_current = obj.M+1;
233 | end
234 |
235 |
236 | % theta projected into current working subspace;
237 | obj.theta_proj = obj.P' * obj.theta;
238 | % holds the last position & the last gradient for all the objective functions;
239 | obj.last_theta = obj.theta_proj * ones(1,obj.N);
240 | obj.last_df = zeros(obj.K_max,obj.N);
241 | % the history of theta changes for each subfunction;
242 | obj.hist_deltatheta = zeros(obj.K_max,obj.max_history,obj.N);
243 | % the history of gradient changes for each subfunction;
244 | obj.hist_deltadf = zeros(obj.K_max,obj.max_history,obj.N);
245 | % the history of function values for each subfunction;
246 | obj.hist_f = ones(obj.N, obj.max_history).*nan;
247 |
248 | % the approximate Hessian for each subfunction is stored;
249 | % as obj.b(:,:,index) * obj.b(:,:,index).'
250 | obj.b = zeros(obj.K_max,2.*obj.max_history,obj.N);
251 |
252 | if obj.N < 25 && obj.display > 0
253 | fprintf( '\n\nIn experiments, performance suffered when the data was broken up into fewer\nthan 25 minibatches (and performance saturated after about 50 minibatches).\nSee Figure 2c. You may want to use more than the current %d minibatches.\n\n', obj.N);
254 | end
255 | end
256 |
257 |
258 | function theta = optimize(obj, num_passes)
259 | % theta = optimize(num_passes)
260 | % Optimize the objective function.
261 | % Parameters:
262 | % num_passes - The number of effective passes through
263 | % subfunction_references to perform.
264 | % Returns:
265 | % theta - The estimated parameter vector after num_passes of optimization.
266 |
267 | num_steps = ceil(num_passes.*obj.N);
268 | for i = 1:num_steps
269 | if obj.display > 1
270 | fprintf('pass %g, step %d,', sum(obj.eval_count)/obj.N, i);
271 | end
272 | obj.optimization_step();
273 | if obj.display > 1
274 | fprintf('active %d/%d, sfo time %g s, func time %g s, f %f, %f\n', sum(obj.active), size(obj.active, 1), obj.time_pass - obj.time_func, obj.time_func, obj.hist_f_flat(end), mean(obj.hist_f(obj.eval_count>0,1)));
275 | end
276 | end
277 | if obj.display > 0
278 | fprintf('active %d/%d, pass %g, sfo time %g s, func time %g s, %f\n', sum(obj.active), size(obj.active, 1), sum(obj.eval_count)/obj.N, obj.time_pass - obj.time_func, obj.time_func, mean(obj.hist_f(obj.eval_count>0,1)));
279 | end
280 |
281 | theta = obj.theta_flat_to_original(obj.theta);
282 | end
283 |
284 | function check_grad(obj)
285 | % A diagnostic function to check the gradients for the subfunctions. It
286 | % checks the subfunctions in random order, & the dimensions of each
287 | % subfunction in random order. This way, a representative set of
288 | % gradients can be checked quickly, even for high dimensional objectives.
289 |
290 | % step size to use for gradient check;
291 | small_diff = obj.eps.*1e6;
292 | fprintf('Testing step size %g\n', small_diff);
293 |
294 | for i = randperm(obj.N)
295 | [fl, ~, dfl] = obj.f_df_wrapper(obj.theta, i);
296 | ep = zeros(obj.M,1);
297 | dfl_obs = zeros(obj.M,1);
298 | dfl_err = zeros(obj.M,1);
299 | for j = randperm(obj.M)
300 | ep(j) = small_diff;
301 | fl2 = obj.f_df_wrapper(obj.theta + ep, i);
302 | dfl_obs(j) = (fl2 - fl)/small_diff;
303 | dfl_err(j) = dfl_obs(j) - dfl(j);
304 | if abs(dfl_err(j)) > small_diff * 1e4
305 | fprintf('large diff ');
306 | else
307 | fprintf(' ');
308 | end
309 | fprintf(' gradient subfunction %d, dimension %d, analytic %g, finite diff %g, error %g\n', i, j, dfl(j), dfl_obs(j), dfl_err(j));
310 | ep(j) = 0.;
311 | end
312 | gerr = sqrt(sum((dfl - dfl_obs).^2));
313 | fprintf('subfunction %g, total L2 gradient error %g\n', i, gerr);
314 | end
315 | end
316 |
317 |
318 | function apply_subspace_transformation(obj,T_left,T_right)
319 | % Apply change-of-subspace transformation. This function is called when;
320 | % the subspace is collapsed to project into the new lower dimensional;
321 | % subspace.;
322 | % T_left - The covariant subspace to subspace projection matrix.;
323 | % T_right - The contravariant subspace projection matrix.;
324 |
325 | % (note that currently T_left = T_right always since the subspace is;
326 | % orthogonal. This will change if eg the code is adapted to also;
327 | % incorporate a 'natural gradient' based parameter space transformation.);
328 |
329 | [tt, ss] = size(T_left);
330 |
331 | % project history terms into new subspace;
332 | obj.last_df = (T_right.') * (obj.last_df);
333 | obj.last_theta = (T_left) * (obj.last_theta);
334 | obj.hist_deltadf = obj.reshape_wrapper(T_right.' * obj.reshape_wrapper(obj.hist_deltadf, [ss,-1]), [tt,-1,obj.N]);
335 | obj.hist_deltatheta = obj.reshape_wrapper(T_left * obj.reshape_wrapper(obj.hist_deltatheta, [ss, -1]), [tt,-1,obj.N]);
336 | % project stored hessian for each subfunction in to new subspace;
337 | obj.b = obj.reshape_wrapper(T_right.' * obj.reshape_wrapper(obj.b, [ss,-1]), [tt, 2.*obj.max_history,obj.N]);
338 |
339 | %% To avoid slow accumulation of numerical errors, recompute full_H;
340 | %% & theta_proj when the subspace is collapsed. Should not be a;
341 | %% leading time cost.;
342 | % theta projected into current working subspace;
343 | obj.theta_proj = (obj.P.') * (obj.theta);
344 | % full approximate hessian;
345 | obj.full_H = real(obj.reshape_wrapper(obj.b,[ss,-1]) * obj.reshape_wrapper(obj.b,[ss,-1]).');
346 | end
347 |
348 | function reorthogonalize_subspace(obj)
349 | % check if the subspace has become non-orthogonal
350 | subspace_eigs = eig(obj.P.' * obj.P);
351 | % TODO(jascha) this may be a stricter cutoff than we need
352 | if max(subspace_eigs) <= 1 + obj.eps
353 | return
354 | end
355 |
356 | if obj.display > 2
357 | fprintf('subspace has become non-orthogonal. Performing QR.\n');
358 | end
359 | [Porth, ~] = qr(obj.P(:,1:obj.K_current), 0);
360 | Pl = zeros(obj.K_max, obj.K_max);
361 | Pl(:,1:obj.K_current) = obj.P.' * Porth;
362 | % update the subspace;
363 | obj.P(:,1:obj.K_current) = Porth;
364 | % Pl is the projection matrix from old to new basis. apply it to all the history;
365 | % terms;
366 | obj.apply_subspace_transformation(Pl.', Pl);
367 | end
368 |
369 | function collapse_subspace(obj, xl)
370 | % Collapse the subspace to its smallest dimensionality.;
371 |
372 | % xl is a new direction that may not be in the history yet, so we pass;
373 | % it in explicitly to make sure it's included.;
374 |
375 | if obj.display > 2
376 | fprintf('collapsing subspace\n');
377 | end
378 |
379 | % the projection matrix from old to new subspace;
380 | Pl = zeros(obj.K_max,obj.K_max);
381 |
382 | % yy will hold all the directions to pack into the subspace.;
383 | % initialize it with random noise, so that it still spans K_min;
384 | % dimensions even if not all the subfunctions are active yet;
385 | yy = randn(obj.K_max,obj.K_min);
386 | % the most recent position & gradient for all active subfunctions;
387 | % as well as the current position & gradient (which will not be saved in the history yet);
388 | yz = [obj.last_df(:,obj.active), obj.last_theta(:,obj.active), xl, (obj.P.') * (obj.theta)];
389 | yy(:,1:size(yz, 2)) = yz;
390 | [Pl(:,1:obj.K_min), ~] = qr(yy, 0);
391 |
392 | % update the subspace;
393 | obj.P = (obj.P) * (Pl);
394 |
395 | % Pl is the projection matrix from old to new basis. apply it to all the history;
396 | % terms;
397 | obj.apply_subspace_transformation(Pl.', Pl);
398 |
399 | % update the stored subspace size;
400 | obj.K_current = obj.K_min;
401 |
402 | % re-orthogonalize the subspace if it's accumulated small errors
403 | obj.reorthogonalize_subspace();
404 | end
405 |
406 |
407 | function update_subspace(obj, x_in)
408 | % Update the low dimensional subspace by adding a new direction.;
409 | % x_in - The new vector to incorporate into the subspace.;
410 |
411 | if obj.K_current >= obj.M
412 | % no need to update the subspace if it spans the full space;
413 | return;
414 | end
415 | if sum(~isfinite(x_in)) > 0
416 | % bad vector! bail.;
417 | return;
418 | end
419 | x_in_length = sqrt(sum(x_in.^2));
420 | if x_in_length < obj.eps
421 | % if the new vector is too short, nothing to do;
422 | return;
423 | end
424 | % make x unit length;
425 | xnew = x_in/x_in_length;
426 |
427 | % Find the component of x pointing out of the existing subspace.;
428 | % We need to do this multiple times for numerical stability.;
429 | for i = 1:3
430 | xnew = xnew - obj.P * (obj.P.' * xnew);
431 | ss = sqrt(sum(xnew.^2));
432 | if ss < obj.eps
433 | % it barely points out of the existing subspace;
434 | % no need to add a new direction to the subspace;
435 | return;
436 | end
437 | % make it unit length;
438 | xnew = xnew / ss;
439 | % if it was already largely orthogonal then numerical;
440 | % stability will be good enough;
441 | % TODO replace this with a more principled test;
442 | if ss > 0.1
443 | break;
444 | end
445 | end
446 |
447 | % add a new column to the subspace containing the new direction;
448 | obj.P(:,obj.K_current+1) = xnew;
449 | obj.K_current = obj.K_current + 1;
450 |
451 | if obj.K_current >= obj.K_max
452 | % the subspace has exceeded its maximum allowed size -- collapse it;
453 | % xl may not be in the history yet, so we pass it in explicitly to make;
454 | % sure it's used;
455 | xl = (obj.P.') * (x_in);
456 | obj.collapse_subspace(xl);
457 | end
458 | end
459 |
460 | function full_H_combined = get_full_H_with_diagonal(obj)
461 | % Get the full approximate Hessian, including the diagonal terms.;
462 | % (note that obj.full_H is stored without including the diagonal terms);
463 |
464 | full_H_combined = obj.full_H + eye(obj.K_max).*sum(obj.min_eig_sub(obj.active));
465 | end
466 |
467 | function f_pred = get_predicted_subf(obj, indx, theta_proj)
468 | % Get the predicted value of subfunction indx at theta_proj;
469 | % (where theta_proj is in the subspace);
470 |
471 | dtheta = theta_proj - obj.last_theta(:,indx);
472 | bdtheta = obj.b(:,:,indx).' * dtheta;
473 | Hdtheta = real(obj.b(:,:,indx) * bdtheta);
474 | Hdtheta = Hdtheta + dtheta.*obj.min_eig_sub(indx); % the diagonal contribution
475 | %df_pred = obj.last_df(:,indx) + Hdtheta;
476 | f_pred = obj.hist_f(indx,1) + obj.last_df(:,indx).' * dtheta + 0.5.*(dtheta.') * (Hdtheta);
477 | end
478 |
479 |
480 | function update_history(obj, indx, theta_proj, f, df_proj)
481 | % Update history of position differences & gradient differences;
482 | % for subfunction indx.;
483 |
484 | % there needs to be at least one earlier measurement from this;
485 | % subfunction to compute position & gradient differences.;
486 | if obj.eval_count(indx) > 1
487 | % differences in gradient & position;
488 | ddf = df_proj - obj.last_df(:,indx);
489 | ddt = theta_proj - obj.last_theta(:,indx);
490 | % length of gradient & position change vectors;
491 | lddt = sqrt(sum(ddt.^2));
492 | lddf = sqrt(sum(ddf.^2));
493 |
494 | corr_ddf_ddt = ddf.' * ddt / (lddt*lddf);
495 |
496 | if obj.display > 3 && corr_ddf_ddt < 0
497 | fprintf('Warning! Negative dgradient dtheta inner product. Adding it anyway.');
498 | end
499 | if lddt < obj.eps
500 | if obj.display > 2
501 | fprintf('Largest change in theta too small (%g). Not adding to history.', lddt);
502 | end
503 | elseif lddf < obj.eps
504 | if obj.display > 2
505 | fprintf('Largest change in gradient too small (%g). Not adding to history.', lddf);
506 | end
507 | elseif abs(corr_ddf_ddt) < obj.eps
508 | if obj.display > 2
509 | fprintf('Inner product between dgradient and dtheta too small (%g). Not adding to history.', corr_ddf_ddt);
510 | end
511 | else
512 | if obj.display > 3
513 | fprintf('subf ||dtheta|| %g, subf ||ddf|| %g, corr(ddf,dtheta) %g,', lddt, lddf, sum(ddt.*ddf)/(lddt.*lddf));
514 | end
515 |
516 | % shift the history by one timestep;
517 | obj.hist_deltatheta(:,2:end,indx) = obj.hist_deltatheta(:,1:end-1,indx);
518 | % store the difference in theta since the subfunction was last evaluated;
519 | obj.hist_deltatheta(:,1,indx) = ddt;
520 | % do the same thing for the change in gradient;
521 | obj.hist_deltadf(:,2:end,indx) = obj.hist_deltadf(:,1:end-1,indx);
522 | obj.hist_deltadf(:,1,indx) = ddf;
523 | end
524 | end
525 |
526 | obj.last_theta(:,indx) = theta_proj;
527 | obj.last_df(:,indx) = df_proj;
528 | obj.hist_f(indx,2:end) = obj.hist_f(indx,1:end-1);
529 | obj.hist_f(indx,1) = f;
530 | end
531 |
532 |
533 | function update_hessian(obj,indx)
534 | % Update the Hessian approximation for a single subfunction.;
535 | % indx - The index of the target subfunction for Hessian update.;
536 |
537 | gd = find(sum(obj.hist_deltatheta(:,:,indx).^2, 1)>0);
538 | num_gd = length(gd);
539 | if num_gd == 0
540 | % if no history, initialize with the median eigenvalue from full Hessian;
541 | if obj.display > 2
542 | fprintf(' no history ');
543 | end
544 | obj.b(:,:,indx) = 0.;
545 | H = obj.get_full_H_with_diagonal();
546 | [U, ~] = obj.eigh_wrapper(H);
547 | obj.min_eig_sub(indx) = median(U)/sum(obj.active);
548 | obj.max_eig_sub(indx) = obj.min_eig_sub(indx);
549 | if obj.eval_count(indx) > 2
550 | if obj.display > 2 || sum(obj.eval_count) < 5
551 | fprintf('Subfunction evaluated %d times, but has no stored history.', obj.eval_count(indx));
552 | end
553 | if sum(obj.eval_count) < 5
554 | fprintf('You probably need to initialize SFO with a smaller hessian_init value. Scaling down the Hessian to try to recover. You are better off correcting the hessian_init value though!');
555 | obj.min_eig_sub(indx) = obj.min_eig_sub(indx) / 10.;
556 | end
557 | end
558 | return;
559 | end
560 |
561 | % work in the subspace defined by this subfunction's history;
562 | [P_hist, ~] = qr([obj.hist_deltatheta(:,gd,indx),obj.hist_deltadf(:,gd,indx)], 0);
563 | deltatheta_P = P_hist.' * obj.hist_deltatheta(:,gd,indx);
564 | deltadf_P = P_hist.' * obj.hist_deltadf(:,gd,indx);
565 |
566 | %% get an approximation to the smallest eigenvalue.;
567 | %% This will be used as the diagonal initialization for BFGS.;
568 | % calculate Hessian using pinv & squared equation. just to get;
569 | % smallest eigenvalue.;
570 | % df = H dx;
571 | % df^T df = dx^T H^T H dx = dx^T H^2 dx
572 | pdelthet = pinv(deltatheta_P);
573 | dd = (deltadf_P) * (pdelthet);
574 | H2 = (dd.') * (dd);
575 | [H2w, ~] = obj.eigh_wrapper(H2);
576 | H2w = sqrt(abs(H2w));
577 |
578 | % only the top ~ num_gd eigenvalues are expected to be well defined;
579 | H2w = sort(H2w, 'descend');
580 | H2w = H2w(1:num_gd);
581 |
582 | if min(H2w) == 0 || sum(~isfinite(H2w)) > 0
583 | % there was a failure using this history. either deltadf was;
584 | % degenerate (0 case), or deltatheta was (non-finite case).;
585 | % Initialize using other subfunctions;
586 | H2w(:) = max(obj.min_eig_sub(obj.active));
587 | if obj.display > 3
588 | fprintf('ill-conditioned history');
589 | end
590 | end
591 |
592 | obj.min_eig_sub(indx) = min(H2w);
593 | obj.max_eig_sub(indx) = max(H2w);
594 |
595 | if obj.min_eig_sub(indx) < obj.max_eig_sub(indx)/obj.hess_max_dev
596 | % constrain using allowed ratio;
597 | obj.min_eig_sub(indx) = obj.max_eig_sub(indx)/obj.hess_max_dev;
598 | if obj.display > 3
599 | fprintf('constraining Hessian initialization');
600 | end
601 | end
602 |
603 | %% recalculate Hessian;
604 | % number of history terms;
605 | num_hist = size(deltatheta_P, 2);
606 | % the new hessian will be (b_p) * (b_p.') + eye().*obj.min_eig_sub(indx);
607 | b_p = zeros(size(P_hist, 2), num_hist*2);
608 | % step through the history;
609 | for hist_i = num_hist:-1:1
610 | s = deltatheta_P(:,hist_i);
611 | y = deltadf_P(:,hist_i);
612 |
613 | % for numerical stability;
614 | rscl = sqrt(sum(s.^2));
615 | s = s/rscl;
616 | y = y/rscl;
617 |
618 | % the BFGS step proper
619 | Hs = s.*obj.min_eig_sub(indx) + b_p * ((b_p.') * (s));
620 | term1 = y / sqrt(sum(y.*s));
621 | sHs = sum(s.*Hs);
622 | term2 = sqrt(complex(-1.)) .* Hs / sqrt(sHs);
623 | if sum(~isfinite(term1)) > 0 || sum(~isfinite(term2)) > 0
624 | obj.min_eig_sub(indx) = max(H2w);
625 | if obj.display > 1
626 | fprintf('invalid bfgs history term. should never get here!');
627 | end
628 | continue;
629 | end
630 | b_p(:,2*(hist_i-1)+2) = term1;
631 | b_p(:,2*(hist_i-1)+1) = term2;
632 | end
633 |
634 | H = real((b_p) * (b_p.')) + eye(size(b_p, 1))*obj.min_eig_sub(indx);
635 | % constrain it to be positive definite;
636 | [U, V] = obj.eigh_wrapper(H);
637 | if max(U) <= 0.
638 | % if there aren't any positive eigenvalues, then;
639 | % set them all to be the same conservative diagonal value;
640 | U(:) = obj.max_eig_sub(indx);
641 | if obj.display > 3
642 | fprintf('no positive eigenvalues after BFGS');
643 | end
644 | end
645 | % set any too-small eigenvalues to the median positive;
646 | % eigenvalue;
647 | U_median = median(U(U>0));
648 | U(U<(max(abs(U))/obj.hess_max_dev)) = U_median;
649 |
650 | % the Hessian after it's been forced to be positive definite;
651 | H_posdef = bsxfun(@times, V, U') * V.';
652 |
653 | % now break it apart into matrices b & a diagonal term again;
654 | B_pos = H_posdef - eye(size(b_p, 1))*obj.min_eig_sub(indx);
655 | [U, V] = obj.eigh_wrapper(B_pos);
656 | b_p = bsxfun(@times, V, sqrt(obj.reshape_wrapper(U, [1,-1])));
657 |
658 | obj.b(:,:,indx) = 0.;
659 | obj.b(:,1:size(b_p,2),indx) = (P_hist) * (b_p);
660 | end
661 |
662 |
663 | function theta_flat = theta_original_to_flat(obj, theta_original)
664 | % Convert from the original parameter format into a 1d array.
665 | % The original format can be an array, or a cell array full of
666 | % arrays.
667 |
668 | if iscell(theta_original)
669 | theta_length = 0;
670 | for theta_array = theta_original(:)'
671 | % iterate over cells
672 | theta_length = theta_length + numel(theta_array{1});
673 | end
674 | theta_flat = zeros(theta_length,1);
675 | i_start = 1;
676 | for theta_array = theta_original(:)'
677 | % iterate over cells
678 | i_end = i_start + numel(theta_array{1})-1;
679 | theta_flat(i_start:i_end) = theta_array{1}(:);
680 | i_start = i_end+1;
681 | end
682 | else
683 | theta_flat = theta_original(:);
684 | end
685 | end
686 | function theta_new = theta_flat_to_original(obj, theta_flat)
687 | % Convert from a 1d array into the original parameter format.;
688 |
689 | if iscell(obj.theta_original)
690 | theta_new = cell(size(obj.theta_original));
691 | i_start = 1;
692 | jj = 1;
693 | for theta_array = obj.theta_original(:)'
694 | % iterate over cells
695 | i_end = i_start + numel(theta_array{1})-1;
696 | theta_new{jj} = obj.reshape_wrapper(theta_flat(i_start:i_end), size(theta_array{1}));
697 | i_start = i_end + 1;
698 | jj = jj + 1;
699 | end
700 | else
701 | theta_new = obj.reshape_wrapper(theta_flat, size(obj.theta_original));
702 | end
703 | end
704 |
705 |
706 | function [f, df_proj, df_full] = f_df_wrapper(obj, theta_in, idx)
707 | % A wrapper around the subfunction objective f_df, that handles the transformation;
708 | % into & out of the flattened parameterization used internally by SFO.;
709 |
710 | theta_local = obj.theta_flat_to_original(theta_in);
711 | % evaluate;
712 | t = tic();
713 | [f, df_full] = obj.f_df(theta_local, obj.subfunction_references{idx}, obj.varargin_stored{:});
714 | time_diff = toc(t);
715 | obj.time_func = obj.time_func + time_diff; % time spent in function evaluation;
716 | df_full = obj.theta_original_to_flat(df_full);
717 | % update the subspace with the new gradient direction;
718 | obj.update_subspace(df_full);
719 | % gradient projected into the current subspace;
720 | df_proj = ( obj.P.') * (df_full );
721 | % keep a record of function evaluations;
722 | obj.hist_f_flat = [obj.hist_f_flat, f];
723 | obj.eval_count(idx) = obj.eval_count(idx) + 1;
724 | end
725 |
726 |
727 | function indx = get_target_index(obj)
728 | % Choose which subfunction to update this iteration.
729 |
730 | % if an active subfunction has one evaluation, get a second
731 | % so we can have a Hessian estimate
732 | gd = find((obj.eval_count == 1) & obj.active);
733 | if ~isempty(gd)
734 | indx = gd(randperm(length(gd), 1));
735 | return
736 | end
737 | % If an active subfunction has fewer than two observations, then;
738 | % evaluate it. We want to get to two evaluations per subfunction;
739 | % as quickly as possible so that it's possible to estimate a Hessian;
740 | % for it
741 | gd = find((obj.eval_count < 2) & obj.active);
742 | if ~isempty(gd)
743 | indx = gd(randperm(length(gd), 1));
744 | return
745 | end
746 |
747 | % use the subfunction evaluated farthest;
748 | % either weighted by the total Hessian, or by the Hessian
749 | % just for that subfunction
750 | if randn() < 0
751 | max_dist = -1;
752 | indx = -1;
753 | for i = 1:obj.N
754 | dtheta = obj.theta_proj - obj.last_theta(:,i);
755 | bdtheta = obj.b(:,:,i).' * dtheta;
756 | dist = sum(bdtheta.^2) + sum(dtheta.^2)*obj.min_eig_sub(i);
757 | if (dist > max_dist) && obj.active(i)
758 | max_dist = dist;
759 | indx = i;
760 | end
761 | end
762 | else
763 | % from the current location, weighted by the Hessian;
764 | % difference between current theta & most recent evaluation;
765 | % for all subfunctions;
766 | dtheta = bsxfun(@plus, obj.theta_proj, -obj.last_theta);
767 | % full Hessian;
768 | full_H_combined = obj.get_full_H_with_diagonal();
769 | % squared distance;
770 | distance = sum(dtheta.*(full_H_combined * dtheta), 1);
771 | % sort the distances from largest to smallest;
772 | [~, dist_ord] = sort(-distance);
773 | % & keep only the indices that belong to active subfunctions;
774 | dist_ord = dist_ord(obj.active(dist_ord));
775 | % & choose the active subfunction from farthest away;
776 | indx = dist_ord(1);
777 | if max(distance(obj.active)) < obj.eps && sum(~obj.active)>0 && obj.eval_count(indx)>0
778 | if obj.display > 2
779 | fprintf('all active subfunctions evaluated here. expanding active set.');
780 | end
781 | inactive = find(~obj.active);
782 | indx = inactive(randperm(length(inactive), 1));
783 | obj.active(indx) = true;
784 | end
785 | end
786 |
787 | end
788 |
789 |
790 | function [step_failure, f, df_proj] = handle_step_failure(obj, f, df_proj, indx)
791 | % Check whether an update step failed. Update current position if it did.;
792 |
793 | % check to see whether the step should be a failure;
794 | step_failure = false;
795 | if ~isfinite(f) || sum(~isfinite(df_proj))>0
796 | % step is a failure if function or gradient is non-finite;
797 | step_failure = true;
798 | elseif obj.eval_count(indx) == 1
799 | % the step is a candidate for failure if it's a new subfunction, & it's;
800 | % much larger than expected;
801 | if max(obj.eval_count) > 1
802 | if f > mean(obj.hist_f(obj.eval_count>1,1)) + 3*std(obj.hist_f(obj.eval_count>1,1))
803 | step_failure = true;
804 | end
805 | end
806 | elseif f > obj.hist_f(indx,1)
807 | % if this subfunction has increased in value, then look whether it's larger;
808 | % than its predicted value by enough to trigger a failure;
809 | % calculate the predicted value of this subfunction;
810 | f_pred = obj.get_predicted_subf(indx, obj.theta_proj);
811 | % if the subfunction exceeds its predicted value by more than the predicted average gain;
812 | % in the other subfunctions, then mark the step as a failure
813 | % (note that it's been about N steps since this has been evaluated, & that this subfunction can lay;
814 | % claim to about 1/N fraction of the objective change);
815 | predicted_improvement_others = obj.f_predicted_total_improvement - (obj.hist_f(indx,1) - f_pred);
816 | if f - f_pred > predicted_improvement_others
817 | step_failure = true;
818 | end
819 | end
820 |
821 | if ~step_failure
822 | % decay the step_scale back towards 1
823 | obj.step_scale = 1./obj.N + obj.step_scale .* (1. - 1./obj.N);
824 | else
825 | % shorten the step length
826 | obj.step_scale = obj.step_scale / 2.;
827 |
828 | % the subspace may be updated during the function calls
829 | % so store this in the full space
830 | df = (obj.P) * (df_proj);
831 |
832 | [f_lastpos, df_lastpos_proj] = obj.f_df_wrapper(obj.theta_prior_step, indx);
833 | df_lastpos = (obj.P) * (df_lastpos_proj);
834 |
835 | %% if the function value exploded, then back it off until it's a
836 | %% reasonable order of magnitude before adding anything to the history
837 | f_pred = obj.get_predicted_subf(indx, obj.theta_proj);
838 | if isfinite(obj.hist_f(indx,1))
839 | predicted_f_diff = abs(f_pred - obj.hist_f(indx,1));
840 | else
841 | predicted_f_diff = abs(f - f_lastpos);
842 | end
843 | if ~isfinite(predicted_f_diff) || predicted_f_diff < obj.eps
844 | predicted_f_diff = obj.eps;
845 | end
846 |
847 | for i_ls = 1:10
848 | if f - f_lastpos < 10*predicted_f_diff
849 | % the failed update is already within an order of magnitude
850 | % of the target update value -- no backoff required
851 | break;
852 | end
853 | if obj.display > 4
854 | fprintf('ls %d f_diff %g predicted_f_diff %g ', i_ls, f - f_lastpos, predicted_f_diff);
855 | end
856 | % make the step length a factor of 100 shorter
857 | obj.theta = 0.99.*obj.theta_prior_step + 0.01.*obj.theta;
858 | obj.theta_proj = (obj.P.') * (obj.theta);
859 | % and recompute f and df at this new location
860 | [f, df_proj] = obj.f_df_wrapper(obj.theta, indx);
861 | df = (obj.P) * (df_proj);
862 | end
863 |
864 | % we're done with function calls. can move these back into the subspace.
865 | df_proj = (obj.P.') * (df);
866 | df_lastpos_proj = (obj.P.') * (df_lastpos);
867 |
868 | if f < f_lastpos
869 | % the proposed position is better than the last position after all -- keep it, but add the
870 | % point evaluated at the last position to the history so that function call is not wasted
871 | theta_lastpos_proj = (obj.P.') * (obj.theta_prior_step);
872 | obj.update_history(indx, theta_lastpos_proj, f_lastpos, df_lastpos_proj);
873 | if obj.display > 2
874 | fprintf('step failed, but last position was even worse ( f %g, std f %g), ', f_lastpos, std(obj.hist_f(obj.eval_count>0,1)));
875 | end
876 | else
877 | % add the change in theta and the change in gradient to the history for this subfunction
878 | % before failing over to the last position
879 | if isfinite(f) && sum(~isfinite(df_proj))==0
880 | obj.update_history(indx, obj.theta_proj, f, df_proj);
881 | end
882 | if obj.display > 2
883 | fprintf('step failed, proposed f %g, std f %g, ', f, std(obj.hist_f(obj.eval_count>0,1)));
884 | end
885 | if (obj.display > -1) && (sum(obj.eval_count>1) < 2)
886 | fprintf([ '\nStep failed on the very first subfunction. This is\n' ...
887 | 'either due to an incorrect gradient, or a very large\n' ...
888 | 'Hessian. Try:\n' ...
889 | ' - Calling check_grad() (see README.md for details)\n' ...
890 | ' - Setting sfo.hessian_init to a larger value.\n']);
891 | end
892 | f = f_lastpos;
893 | df_proj = df_lastpos_proj;
894 | obj.theta = obj.theta_prior_step;
895 | obj.theta_proj = (obj.P.') * (obj.theta);
896 | end
897 | end
898 |
899 | % don't let steps get so short that they don't provide any usable Hessian information
900 | % TODO use a more principled cutoff here
901 | obj.step_scale = max([obj.step_scale, 1e-5]);
902 | end
903 |
904 |
905 | function expand_active_subfunctions(obj, full_H_inv, step_failure)
906 | % expand the set of active subfunctions as appropriate
907 |
908 | % power in the average gradient direction
909 | df_avg = mean(obj.last_df(:,obj.active), 2);
910 | p_df_avg = sum(df_avg .* (full_H_inv * df_avg));
911 | % power of the standard error
912 | ldfs = obj.last_df(:,obj.active) - df_avg*ones(1,sum(obj.active));
913 | num_active = sum(obj.active);
914 | p_df_sum = sum(sum(ldfs .* (full_H_inv * ldfs))) / num_active / (num_active - 1);
915 | % if the standard error in the estimated gradient is the same order of magnitude as the gradient
916 | % we want to increase the size of the active set
917 | increase_desirable = (p_df_sum >= p_df_avg.*obj.max_gradient_noise);
918 | % increase the active set on step failure
919 | increase_desirable = increase_desirable || step_failure;
920 | % increase the active set if we've done a full pass without updating it
921 | increase_desirable = increase_desirable || (obj.iter_since_active_growth > num_active);
922 | % make sure all the subfunctions have enough evaluations for a Hessian approximation
923 | % before bringing in new subfunctions
924 | eligible_for_increase = (min(obj.eval_count(obj.active)) >= 2);
925 | % one more iteration has passed since the active set was last expanded
926 | obj.iter_since_active_growth = obj.iter_since_active_growth + 1;
927 | if increase_desirable && eligible_for_increase && sum(~obj.active) > 0
928 | % the index of the new subfunction to activate
929 | new_gd = find(~obj.active);
930 | new_gd = new_gd(randperm(length(new_gd), 1));
931 | if ~isempty(new_gd)
932 | obj.iter_since_active_growth = 0;
933 | obj.active(new_gd) = true;
934 | end
935 | end
936 | end
937 |
938 |
939 | function optimization_step(obj)
940 | % Perform a single optimization step. This function is typically called by SFO.optimize().
941 |
942 | time_pass_start = tic();
943 |
944 | %% choose an index to update
945 | indx = obj.get_target_index();
946 |
947 | if obj.display > 2
948 | fprintf('||dtheta|| %g, ', sqrt(sum((obj.theta - obj.theta_prior_step).^2)));
949 | fprintf('index %d, last f %g, ', indx, obj.hist_f(indx,1));
950 | fprintf('step scale %g, ', obj.step_scale);
951 | end
952 | if obj.display > 8
953 | C = obj.P.' * obj.P;
954 | eC = eig(C);
955 | fprintf('mne %g, mxe %g, ', min(eC(eC>0)), max(eC));
956 | end
957 |
958 | % evaluate subfunction value and gradient at new position
959 | [f, df_proj] = obj.f_df_wrapper(obj.theta, indx);
960 |
961 | % check for a failed update step, and adjust f, df, and obj.theta
962 | % as appropriate if one occurs.
963 | [step_failure, f, df_proj] = obj.handle_step_failure(f, df_proj, indx);
964 |
965 | % add the change in theta and the change in gradient to the history for this subfunction
966 | obj.update_history(indx, obj.theta_proj, f, df_proj);
967 |
968 | % increment the total distance traveled using the last update
969 | obj.total_distance = obj.total_distance + sqrt(sum((obj.theta - obj.theta_prior_step).^2));
970 |
971 | % the current contribution from this subfunction to the total Hessian approximation
972 | H_pre_update = real(obj.b(:,:,indx) * obj.b(:,:,indx).');
973 | %% update this subfunction's Hessian estimate
974 | obj.update_hessian(indx);
975 | % the new contribution from this subfunction to the total approximate Hessian
976 | H_new = real(obj.b(:,:,indx) * obj.b(:,:,indx).');
977 | % update total Hessian using this subfunction's updated contribution
978 | obj.full_H = obj.full_H + H_new - H_pre_update;
979 |
980 | % calculate the total gradient, total Hessian, and total function value at the current location
981 | full_df = 0.;
982 | for i = 1:obj.N
983 | dtheta = obj.theta_proj - obj.last_theta(:,i);
984 | bdtheta = obj.b(:,:,i).' * dtheta;
985 | Hdtheta = real(obj.b(:,:,i) * bdtheta);
986 | Hdtheta = Hdtheta + dtheta.*obj.min_eig_sub(i); % the diagonal contribution
987 | full_df = full_df + Hdtheta + obj.last_df(:,i);
988 | end
989 | full_H_combined = obj.get_full_H_with_diagonal();
990 | % TODO - Use Woodbury identity instead of recalculating full inverse
991 | full_H_inv = inv(full_H_combined);
992 |
993 | % calculate an update step
994 | dtheta_proj = -(full_H_inv) * (full_df) .* obj.step_scale;
995 |
996 | dtheta_proj_length = sqrt(sum(dtheta_proj(:).^2));
997 | if dtheta_proj_length < obj.minimum_step_length
998 | dtheta_proj = dtheta_proj*obj.minimum_step_length/dtheta_proj_length;
999 | dtheta_proj_length = obj.minimum_step_length;
1000 | if obj.display > 3
1001 | fprintf('forcing minimum step length');
1002 | end
1003 | end
1004 | if sum(obj.eval_count) > obj.N && dtheta_proj_length > obj.eps
1005 | % only allow a step to be up to a factor of max_step_length_ratio longer than the
1006 | % average step length
1007 | avg_length = obj.total_distance / sum(obj.eval_count);
1008 | length_ratio = dtheta_proj_length / avg_length;
1009 | ratio_scale = obj.max_step_length_ratio;
1010 | if length_ratio > ratio_scale
1011 | if obj.display > 3
1012 | fprintf('truncating step length from %g to %g', dtheta_proj_length, ratio_scale*avg_length);
1013 | end
1014 | dtheta_proj_length = dtheta_proj_length/(length_ratio/ratio_scale);
1015 | dtheta_proj = dtheta_proj/(length_ratio/ratio_scale);
1016 | end
1017 | end
1018 |
1019 | % the update to theta, in the full dimensional space
1020 | dtheta = (obj.P) * (dtheta_proj);
1021 |
1022 | % back up the prior position, in case this is a failed step
1023 | obj.theta_prior_step = obj.theta;
1024 | % update theta to the new location
1025 | obj.theta = obj.theta + dtheta;
1026 | obj.theta_proj = obj.theta_proj + dtheta_proj;
1027 | % the predicted improvement from this update step
1028 | obj.f_predicted_total_improvement = 0.5 .* dtheta_proj.' * (full_H_combined * dtheta_proj);
1029 |
1030 | %% expand the set of active subfunctions as appropriate
1031 | obj.expand_active_subfunctions(full_H_inv, step_failure);
1032 |
1033 | % record how much time was taken by this learning step
1034 | time_diff = toc(time_pass_start);
1035 | obj.time_pass = obj.time_pass + time_diff;
1036 | end
1037 | end
1038 |
1039 |
1040 | methods(Static)
1041 |
1042 | function A = reshape_wrapper(A, shape)
1043 | % a wrapper for reshape which duplicates the numpy behavior, and sets
1044 | % any -1 dimensions to the appropriate length
1045 |
1046 | total_dims = numel(A);
1047 | total_assigned = prod(shape(shape>0));
1048 | shape(shape==-1) = total_dims/total_assigned;
1049 | A = reshape(A, shape);
1050 | end
1051 |
1052 |
1053 | function [U, V] = eigh_wrapper(A)
1054 | % A wrapper which duplicates the order and format of the numpy
1055 | % eigh routine. (note, eigh further assumes symmetric matrix. don't
1056 | % think there's an equivalent MATLAB function?)
1057 |
1058 | % Note: this function enforces A to be symmetric
1059 |
1060 | [V,U] = eig(0.5 * (A + A'));
1061 | U = diag(U);
1062 | end
1063 |
1064 | end
1065 | end
1066 |
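
The step-length control in `handle_step_failure` above halves `step_scale` on a failed step, relaxes it back towards 1 otherwise, and floors it at `1e-5`. The following is a minimal Python sketch of that schedule only, not of the full failure test. The function name `update_step_scale` and the value `N = 10` are hypothetical and chosen for illustration.

```python
def update_step_scale(step_scale, step_failure, N, min_scale=1e-5):
    """Illustrative copy of the step_scale update in handle_step_failure."""
    if step_failure:
        # shorten the step length after a failed step
        step_scale = step_scale / 2.0
    else:
        # relax back towards 1; the fixed point of this update is exactly 1
        step_scale = 1.0 / N + step_scale * (1.0 - 1.0 / N)
    # same floor as sfo.m, so steps never become too short to be informative
    return max(step_scale, min_scale)


N = 10  # hypothetical number of subfunctions
scale = 1.0

# three consecutive failures halve the scale three times
for _ in range(3):
    scale = update_step_scale(scale, True, N)
after_failures = scale

# successful steps then pull the scale back towards 1
for _ in range(50):
    scale = update_step_scale(scale, False, N)
after_recovery = scale

print(after_failures, after_recovery)
```

Because the recovery update moves only a fraction `1/N` of the way back to 1 on each successful step, a single failure influences roughly the next `N` steps, which is about one pass through the subfunctions.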
--------------------------------------------------------------------------------
/sfo_demo.m:
--------------------------------------------------------------------------------
1 | % Demonstrates usage of the Sum of Functions Optimizer (SFO) MATLAB
2 | % package. See sfo.m and
3 | % https://github.com/Sohl-Dickstein/Sum-of-Functions-Optimizer
4 | % for additional documentation.
5 | %
6 | % Copyright 2014 Jascha Sohl-Dickstein
7 | %
8 | % Licensed under the Apache License, Version 2.0 (the "License");
9 | % you may not use this file except in compliance with the License.
10 | % You may obtain a copy of the License at
11 | %
12 | % http://www.apache.org/licenses/LICENSE-2.0
13 | %
14 | % Unless required by applicable law or agreed to in writing, software
15 | % distributed under the License is distributed on an "AS IS" BASIS,
16 | % WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | % See the License for the specific language governing permissions and
18 | % limitations under the License.
19 |
20 | function sfo_demo()
21 | % set model and training data parameters
22 | M = 20; % number visible units
23 | J = 10; % number hidden units
24 | D = 100000; % full data batch size
25 | N = floor(sqrt(D)/10.); % number minibatches
26 | % generate random training data
27 | v = randn(M,D);
28 |
29 | % create the cell array of subfunction specific arguments
30 | sub_refs = cell(N,1);
31 | for i = 1:N
32 | % extract a single minibatch of training data.
33 | sub_refs{i} = v(:,i:N:end);
34 | end
35 |
36 | % initialize parameters
37 | % Parameters can be stored as a vector, a matrix, or a cell array with a
38 | % vector or matrix in each cell. Here the parameters are
39 | % {[weight matrix], [hidden bias], [visible bias]}.
40 | theta_init = {randn(J,M), randn(J,1), randn(M,1)};
41 | % initialize the optimizer
42 | optimizer = sfo(@f_df_autoencoder, theta_init, sub_refs);
43 | % uncomment the following line to test the gradient of f_df
44 | %optimizer.check_grad();
45 | % run the optimizer for half a pass through the data
46 | theta = optimizer.optimize(0.5);
47 | % run the optimizer for another 20 passes through the data, continuing from
48 | % the theta value where the prior call to optimize() ended
49 | theta = optimizer.optimize(20);
50 | % plot the convergence trace
51 | plot(optimizer.hist_f_flat);
52 | xlabel('Iteration');
53 | ylabel('Minibatch Function Value');
54 | title('Convergence Trace');
55 | end
56 |
57 | % define an objective function and gradient
58 | function [f, dfdtheta] = f_df_autoencoder(theta, v)
59 | % [f, dfdtheta] = f_df_autoencoder(theta, v)
60 | % Calculate L2 reconstruction error and gradient for an autoencoder
61 | % with sigmoid nonlinearity.
62 | % Parameters:
63 | % theta - A cell array containing
64 | % {[weight matrix], [hidden bias], [visible bias]}.
65 | % v - A [# visible, # datapoints] matrix containing training data.
66 | % v will be different for each subfunction.
67 | % Returns:
68 | % f - The L2 reconstruction error for data v and parameters theta.
69 | % dfdtheta - A cell array containing the gradient of f with respect to
70 | % each of the parameters in theta.
71 |
72 | W = theta{1};
73 | b_h = theta{2};
74 | b_v = theta{3};
75 |
76 | h = 1./(1 + exp(-bsxfun(@plus, W * v, b_h)));
77 | v_hat = bsxfun(@plus, W' * h, b_v);
78 | f = sum(sum((v_hat - v).^2)) / size(v, 2);
79 | dv_hat = 2*(v_hat - v) / size(v, 2);
80 | db_v = sum(dv_hat, 2);
81 | dW = h * dv_hat';
82 | dh = W * dv_hat;
83 | db_h = sum(dh.*h.*(1-h), 2);
84 | dW = dW + dh.*h.*(1-h) * v';
85 | % give the gradients the same order as the parameters
86 | dfdtheta = {dW, db_h, db_v};
87 | end
88 |
--------------------------------------------------------------------------------
/sfo_demo.py:
--------------------------------------------------------------------------------
1 | """
2 | Train an autoencoder using SFO.
3 |
4 | Demonstrates usage of the Sum of Functions Optimizer (SFO) Python
5 | package. See sfo.py and
6 | https://github.com/Sohl-Dickstein/Sum-of-Functions-Optimizer
7 | for additional documentation.
8 |
9 | Copyright 2014 Jascha Sohl-Dickstein
10 |
11 | Licensed under the Apache License, Version 2.0 (the "License");
12 | you may not use this file except in compliance with the License.
13 | You may obtain a copy of the License at
14 |
15 | http://www.apache.org/licenses/LICENSE-2.0
16 |
17 | Unless required by applicable law or agreed to in writing, software
18 | distributed under the License is distributed on an "AS IS" BASIS,
19 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20 | See the License for the specific language governing permissions and
21 | limitations under the License.
22 | """
23 |
24 | import matplotlib.pyplot as plt
25 | import numpy as np
26 | from numpy.random import randn
27 | from sfo import SFO
28 |
29 | # define an objective function and gradient
30 | def f_df(theta, v):
31 | """
32 | Calculate reconstruction error and gradient for an autoencoder with sigmoid
33 | nonlinearity.
34 | v contains the training data, and will be different for each subfunction.
35 | """
36 | h = 1./(1. + np.exp(-(np.dot(theta['W'], v) + theta['b_h'])))
37 | v_hat = np.dot(theta['W'].T, h) + theta['b_v']
38 | f = np.sum((v_hat - v)**2) / v.shape[1]
39 | dv_hat = 2.*(v_hat - v) / v.shape[1]
40 | db_v = np.sum(dv_hat, axis=1).reshape((-1,1))
41 | dW = np.dot(h, dv_hat.T)
42 | dh = np.dot(theta['W'], dv_hat)
43 | db_h = np.sum(dh*h*(1.-h), axis=1).reshape((-1,1))
44 | dW += np.dot(dh*h*(1.-h), v.T)
45 | dfdtheta = {'W':dW, 'b_h':db_h, 'b_v':db_v}
46 | return f, dfdtheta
47 |
48 | # set model and training data parameters
49 | M = 20 # number visible units
50 | J = 10 # number hidden units
51 | D = 100000 # full data batch size
52 | N = int(np.sqrt(D)/10.) # number minibatches
53 | # generate random training data
54 | v = randn(M,D)
55 |
56 | # create the array of subfunction specific arguments
57 | sub_refs = []
58 | for i in range(N):
59 | # extract a single minibatch of training data.
60 | sub_refs.append(v[:,i::N])
61 |
62 | # initialize parameters
63 | theta_init = {'W':randn(J,M), 'b_h':randn(J,1), 'b_v':randn(M,1)}
64 | # initialize the optimizer
65 | optimizer = SFO(f_df, theta_init, sub_refs)
66 | # # uncomment the following line to test the gradient of f_df
67 | # optimizer.check_grad()
68 | # run the optimizer for 1 pass through the data
69 | theta = optimizer.optimize(num_passes=1)
70 | # continue running the optimizer for another 20 passes through the data
71 | theta = optimizer.optimize(num_passes=20)
72 |
73 | # plot the convergence trace
74 | plt.plot(np.array(optimizer.hist_f_flat))
75 | plt.xlabel('Iteration')
76 | plt.ylabel('Minibatch Function Value')
77 | plt.title('Convergence Trace')
78 | plt.show()
79 |
--------------------------------------------------------------------------------
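
Both demos suggest testing the gradient of `f_df` before optimizing. The sketch below shows one way to do that by hand with central finite differences. It is independent of SFO's own `check_grad()`, whose internals are not shown here. The helper name `max_gradient_error`, the problem sizes, and the error scaling are assumptions made for illustration; `f_df` is copied from `sfo_demo.py`.

```python
import numpy as np


def f_df(theta, v):
    """Autoencoder reconstruction error and gradient, as in sfo_demo.py."""
    h = 1. / (1. + np.exp(-(np.dot(theta['W'], v) + theta['b_h'])))
    v_hat = np.dot(theta['W'].T, h) + theta['b_v']
    f = np.sum((v_hat - v)**2) / v.shape[1]
    dv_hat = 2. * (v_hat - v) / v.shape[1]
    db_v = np.sum(dv_hat, axis=1).reshape((-1, 1))
    dW = np.dot(h, dv_hat.T)
    dh = np.dot(theta['W'], dv_hat)
    db_h = np.sum(dh * h * (1. - h), axis=1).reshape((-1, 1))
    dW += np.dot(dh * h * (1. - h), v.T)
    return f, {'W': dW, 'b_h': db_h, 'b_v': db_v}


def max_gradient_error(f_df, theta, v, eps=1e-6, checks_per_key=5, seed=0):
    """Compare analytic and central-difference gradients at random entries.

    Returns the largest |numeric - analytic| / max(1, |numeric|, |analytic|)
    over all checked entries (hypothetical helper, not part of SFO).
    """
    rng = np.random.RandomState(seed)
    _, df = f_df(theta, v)
    worst = 0.0
    for key in theta:
        for _ in range(checks_per_key):
            idx = tuple(rng.randint(s) for s in theta[key].shape)
            original = theta[key][idx]
            theta[key][idx] = original + eps
            f_plus, _ = f_df(theta, v)
            theta[key][idx] = original - eps
            f_minus, _ = f_df(theta, v)
            theta[key][idx] = original  # restore the parameter
            numeric = (f_plus - f_minus) / (2. * eps)
            analytic = df[key][idx]
            scale = max(1.0, abs(numeric), abs(analytic))
            worst = max(worst, abs(numeric - analytic) / scale)
    return worst


# small hypothetical problem: 5 visible units, 3 hidden units, 50 datapoints
rng = np.random.RandomState(1)
theta = {'W': rng.randn(3, 5), 'b_h': rng.randn(3, 1), 'b_v': rng.randn(5, 1)}
v = rng.randn(5, 50)
err = max_gradient_error(f_df, theta, v)
print('largest scaled gradient error:', err)
```

A value close to zero suggests the analytic gradient agrees with the objective. An incorrect gradient is one of the two causes named in the "Step failed on the very first subfunction" message in `sfo.m`, so a check like this is a reasonable first step when that message appears.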