├── .gitignore ├── README.md ├── kfac ├── README.md ├── autoencoders.py ├── curves.py ├── kfac.py ├── kfac_util.py └── mnist.py └── lec02 ├── core.py ├── sensitivity.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | 3 | *.pyc 4 | 5 | *.aux 6 | *.bbl 7 | *.blg 8 | *.log 9 | *.out 10 | *.nav 11 | *.pyg 12 | *.snm 13 | *.toc 14 | *.rel 15 | *.fdb_latexmk 16 | *.fls 17 | *.synctex.gz 18 | 19 | *.ipynb 20 | 21 | 22 | .DS_Store 23 | *.pygtex 24 | *.pygstyle 25 | 26 | kfac/digs3pts_1.mat 27 | 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CSC2541 Code Examples 2 | 3 | This repository will contain JAX example code for my topics course, [CSC2541: Neural Net Training Dynamics](https://www.cs.toronto.edu/~rgrosse/courses/csc2541_2021/). So far, it contains an implementation of the original K-FAC algorithm, as well as the sensitivity analysis from Lecture 2. Additional examples from Chapters 2 and 3 will be added soon. 4 | -------------------------------------------------------------------------------- /kfac/README.md: -------------------------------------------------------------------------------- 1 | ## K-FAC Example 2 | 3 | This is a JAX reimplementation of some of the experiments from Martens and Grosse (2015), [Optimizing neural networks with Kronecker-factored Approximate Curvature](https://arxiv.org/abs/1503.05671), in particular the MNIST and Curves autoencoders. 4 | 5 | This code is provided for pedagogical purposes to my course [CSC2541: Neural Net Training Dynamics](https://www.cs.toronto.edu/~rgrosse/courses/csc2541_2021/). I do not recommend it for use in production or for rigorous empirical comparisons. This only includes the block diagonal (not block tridiagonal) version of K-FAC, and there are likely other differences as well. See [the original MATLAB code](https://www.cs.toronto.edu/~jmartens/docs/KFAC3-MATLAB.zip) for a faithful implementation of the original experiments. 6 | 7 | To run the MNIST autoencoder experiment: 8 | ``` 9 | python mnist.py 10 | ``` 11 | 12 | To run the Curves autoexperiment, first download the [Curves dataset](https://www.cs.toronto.edu/~jmartens/digs3pts_1.mat) to the code directory, and then run: 13 | ``` 14 | python curves.py 15 | ``` 16 | -------------------------------------------------------------------------------- /kfac/autoencoders.py: -------------------------------------------------------------------------------- 1 | from jax import nn, numpy as np 2 | from jax.experimental import stax 3 | import numpy as onp 4 | import time 5 | 6 | import kfac 7 | import kfac_util 8 | 9 | 10 | def get_architecture(input_size, layer_sizes): 11 | """Construct a sigmoid MLP autoencoder architecture with the given layer sizes. 
12 | The code layer, given by the name 'code', is linear.""" 13 | 14 | layers = [] 15 | param_info = [] 16 | act_name = 'in' 17 | for name, lsize in layer_sizes: 18 | if name == 'code': 19 | # Code layer is special because it's linear 20 | param_info.append((act_name, name)) 21 | act_name = name 22 | 23 | layers.append((name, stax.Dense( 24 | lsize, W_init=kfac_util.sparse_init(), b_init=nn.initializers.zeros))) 25 | else: 26 | preact_name = name + 'z' 27 | param_info.append((act_name, preact_name)) 28 | act_name = name + 'a' 29 | 30 | layers.append((preact_name, stax.Dense( 31 | lsize, W_init=kfac_util.sparse_init(), b_init=nn.initializers.zeros))) 32 | layers.append((act_name, stax.elementwise(nn.sigmoid))) 33 | 34 | layers.append(('out', stax.Dense( 35 | input_size, W_init=kfac_util.sparse_init(), b_init=nn.initializers.zeros))) 36 | 37 | param_info.append((act_name, 'out')) 38 | param_info = tuple(param_info) 39 | 40 | net_init, net_apply = kfac_util.named_serial(*layers) 41 | 42 | in_shape = (-1, input_size) 43 | flatten, unflatten = kfac_util.get_flatten_fns(net_init, in_shape) 44 | 45 | return kfac_util.Architecture(net_init, net_apply, in_shape, flatten, unflatten, param_info) 46 | 47 | 48 | def default_config(): 49 | config = {} 50 | config['max_iter'] = 20000 51 | config['initial_batch_size'] = 1000 52 | config['final_batch_size_iter'] = 500 53 | config['batch_size_granularity'] = 50 54 | config['chunk_size'] = 5000 55 | 56 | config['cov_update_interval'] = 1 57 | config['cov_batch_ratio'] = 1/8 58 | config['cov_timescale'] = 20 59 | 60 | config['eig_update_interval'] = 20 61 | 62 | config['lambda_update_interval'] = 5 63 | config['init_lambda'] = 150 64 | config['lambda_drop'] = 0.95 65 | config['lambda_boost'] = 1 / config['lambda_drop'] 66 | config['lambda_min'] = 0 67 | config['lambda_max'] = onp.infty 68 | 69 | config['weight_cost'] = 1e-5 70 | 71 | config['gamma_update_interval'] = 20 72 | config['init_gamma'] = onp.sqrt(config['init_lambda'] + config['weight_cost']) 73 | config['gamma_drop'] = onp.sqrt(config['lambda_drop']) 74 | config['gamma_boost'] = 1 / config['gamma_drop'] 75 | config['gamma_max'] = 1 76 | config['gamma_min'] = onp.sqrt(config['weight_cost']) 77 | 78 | config['param_timescale'] = 100 79 | 80 | return config 81 | 82 | def squared_error(logits, T): 83 | """Compute the squared error. 
For consistency with James Martens's original MATLAB code, don't
    rescale by 0.5."""
    y = nn.sigmoid(logits)
    return np.sum((y-T)**2)

def run_training(X_train, X_test, arch, config):
    nll_fn = kfac_util.BernoulliModel.nll_fn
    state = kfac.kfac_init(arch, kfac_util.BernoulliModel, X_train, X_train, config)
    for i in range(config['max_iter']):
        t0 = time.time()
        state = kfac.kfac_iter(state, arch, kfac_util.BernoulliModel, X_train, X_train, config)

        print('Step', i)
        print('Time:', time.time() - t0)
        print('Alpha:', state['coeffs'][0])
        if i > 0:
            print('Beta:', state['coeffs'][1])
        print('Quadratic decrease:', state['quad_dec'])

        if i % 20 == 0:
            print()
            cost = kfac.compute_cost(arch, nll_fn, state['w'], X_train, X_train,
                                     config['weight_cost'], config['chunk_size'])
            print('Training objective:', cost)
            cost = kfac.compute_cost(
                arch, nll_fn, state['w_avg'], X_train, X_train,
                config['weight_cost'], config['chunk_size'])
            print('Training objective (averaged):', cost)

            cost = kfac.compute_cost(arch, nll_fn, state['w'], X_test, X_test,
                                     config['weight_cost'], config['chunk_size'])
            print('Test objective:', cost)
            cost = kfac.compute_cost(
                arch, nll_fn, state['w_avg'], X_test, X_test,
                config['weight_cost'], config['chunk_size'])
            print('Test objective (averaged):', cost)

            print()
            cost = kfac.compute_cost(arch, squared_error, state['w'], X_train, X_train,
                                     0., config['chunk_size'])
            print('Training error:', cost)
            cost = kfac.compute_cost(arch, squared_error, state['w_avg'], X_train, X_train,
                                     0., config['chunk_size'])
            print('Training error (averaged):', cost)

            cost = kfac.compute_cost(arch, squared_error, state['w'], X_test, X_test,
                                     0., config['chunk_size'])
            print('Test error:', cost)
            cost = kfac.compute_cost(arch, squared_error, state['w_avg'], X_test, X_test,
                                     0., config['chunk_size'])
            print('Test error (averaged):', cost)
            print()


        if i % config['lambda_update_interval'] == 0:
            print('New lambda:', state['lambda'])
        if i % config['gamma_update_interval'] == 0:
            print('New gamma:', state['gamma'])
        print()


--------------------------------------------------------------------------------
/kfac/curves.py:
--------------------------------------------------------------------------------
import scipy.io

import autoencoders


def get_architecture():
    layer_sizes = [('enc1', 400),
                   ('enc2', 200),
                   ('enc3', 100),
                   ('enc4', 50),
                   ('enc5', 25),
                   ('code', 6),
                   ('dec1', 25),
                   ('dec2', 50),
                   ('dec3', 100),
                   ('dec4', 200),
                   ('dec5', 400)]

    return autoencoders.get_architecture(784, layer_sizes)

def get_config():
    return autoencoders.default_config()

def run():
    try:
        obj = scipy.io.loadmat('digs3pts_1.mat')
    except OSError:
        print("To run this script, first download https://www.cs.toronto.edu/~jmartens/digs3pts_1.mat to this directory.")
        return

    X_train = obj['bdata']
    X_test = obj['bdatatest']

    config = get_config()
    arch = get_architecture()

    autoencoders.run_training(X_train, X_test, arch, config)


if __name__ == '__main__':
    run()


--------------------------------------------------------------------------------
/kfac/kfac.py:
-------------------------------------------------------------------------------- 1 | from jax import grad, jit, numpy as np, random, vjp, jvp 2 | from jax.scipy.linalg import eigh 3 | import numpy as onp 4 | 5 | 6 | import kfac_util 7 | 8 | 9 | def L2_penalty(arch, w): 10 | # FIXME: don't regularize the biases 11 | return 0.5 * np.sum(w**2) 12 | 13 | def get_batch_size(step, ndata, config): 14 | """Exponentially increasing batch size schedule.""" 15 | step = np.floor(step/config['batch_size_granularity']) * config['batch_size_granularity'] 16 | pwr = onp.minimum(step / (config['final_batch_size_iter']-1), 1.) 17 | return onp.floor(config['initial_batch_size'] * (ndata / config['initial_batch_size'])**pwr).astype(onp.uint32) 18 | 19 | def get_sample_batch_size(batch_size, config): 20 | """Batch size to use for sampling the activation statistics.""" 21 | return onp.ceil(config['cov_batch_ratio'] * batch_size).astype(onp.uint32) 22 | 23 | def get_chunks(batch_size, chunk_size): 24 | """Iterator that breaks a range into smaller chunks. Useful for simulating 25 | larger batches than can fit on the GPU.""" 26 | start = 0 27 | 28 | while start < batch_size: 29 | end = min(start+chunk_size, batch_size) 30 | yield slice(start, end) 31 | start = end 32 | 33 | 34 | def make_instrumented_vjp(apply_fn, params, inputs): 35 | """Returns a function which takes in the output layer gradients and returns a dict 36 | containing the gradients for all the intermediate layers.""" 37 | dummy_input = np.zeros((2,) + inputs.shape[1:]) 38 | _, dummy_activations = apply_fn(params, dummy_input, ret_all=True) 39 | 40 | batch_size = inputs.shape[0] 41 | add_to = {name: np.zeros((batch_size,) + dummy_activations[name].shape[1:]) 42 | for name in dummy_activations} 43 | apply_wrap = lambda a: apply_fn(params, inputs, a, ret_all=True) 44 | primals_out, vjp_fn, activations = vjp(apply_wrap, add_to, has_aux=True) 45 | return primals_out, vjp_fn, activations 46 | 47 | 48 | def estimate_covariances_chunk(apply_fn, param_info, output_model, net_params, X_chunk, rng): 49 | """Compute the empirical covariances on a chunk of data.""" 50 | logits, vjp_fn, activations = make_instrumented_vjp(apply_fn, net_params, X_chunk) 51 | key, rng = random.split(rng) 52 | output_grads = output_model.sample_grads_fn(logits, key) 53 | act_grads = vjp_fn(output_grads)[0] 54 | 55 | A = {} 56 | G = {} 57 | for in_name, out_name in param_info: 58 | a = activations[in_name] 59 | a_hom = np.hstack([a, np.ones((a.shape[0], 1))]) 60 | A[in_name] = a_hom.T @ a_hom 61 | 62 | ds = act_grads[out_name] 63 | G[out_name] = ds.T @ ds 64 | 65 | return A, G 66 | 67 | estimate_covariances_chunk = jit(estimate_covariances_chunk, static_argnums=(0,1,2)) 68 | 69 | 70 | def estimate_covariances(arch, output_model, w, X, rng, chunk_size): 71 | """Compute the empirical covariances on a batch of data.""" 72 | batch_size = X.shape[0] 73 | net_params = arch.unflatten(w) 74 | A_sum = {in_name: 0. for in_name, out_name in arch.param_info} 75 | G_sum = {out_name: 0. 
             for in_name, out_name in arch.param_info}
    for chunk_idxs in get_chunks(batch_size, chunk_size):
        X_chunk = X[chunk_idxs,:]
        key, rng = random.split(rng)

        A_curr, G_curr = estimate_covariances_chunk(
            arch.net_apply, arch.param_info, output_model, net_params, X_chunk, key)
        A_sum = {name: A_sum[name] + A_curr[name] for name in A_sum}
        G_sum = {name: G_sum[name] + G_curr[name] for name in G_sum}

    A_mean = {name: A_sum[name] / batch_size for name in A_sum}
    G_mean = {name: G_sum[name] / batch_size for name in G_sum}

    return A_mean, G_mean

def update_covariances(A, G, arch, output_model, w, X, rng, cov_timescale, chunk_size):
    """Exponential moving average of the covariances."""
    A, G = dict(A), dict(G)
    curr_A, curr_G = estimate_covariances(arch, output_model, w, X, rng, chunk_size)
    ema_param = kfac_util.get_ema_param(cov_timescale)
    for k in A.keys():
        A[k] = ema_param * A[k] + (1-ema_param) * curr_A[k]
    for k in G.keys():
        G[k] = ema_param * G[k] + (1-ema_param) * curr_G[k]
    return A, G

def compute_pi(A, G):
    return np.sqrt((np.trace(A) * G.shape[0]) / (A.shape[0] * np.trace(G)))

def compute_inverses(arch, A, G, gamma):
    A_inv, G_inv = {}, {}
    for in_name, out_name in arch.param_info:
        pi = compute_pi(A[in_name], G[out_name])

        A_damp = gamma * pi
        A_inv[in_name] = np.linalg.inv(A[in_name] + A_damp * np.eye(A[in_name].shape[0]))

        G_damp = gamma / pi
        G_inv[out_name] = np.linalg.inv(G[out_name] + G_damp * np.eye(G[out_name].shape[0]))

    return A_inv, G_inv

def compute_eigs(arch, A, G):
    A_eig, G_eig, pi = {}, {}, {}
    for in_name, out_name in arch.param_info:
        A_eig[in_name] = eigh(A[in_name])
        G_eig[out_name] = eigh(G[out_name])
        pi[out_name] = compute_pi(A[in_name], G[out_name])
    return A_eig, G_eig, pi

def nll_cost(apply_fn, nll_fn, unflatten_fn, w, X, T):
    logits = apply_fn(unflatten_fn(w), X)
    return nll_fn(logits, T)

nll_cost = jit(nll_cost, static_argnums=(0, 1, 2))
grad_nll_cost = jit(grad(nll_cost, 3), static_argnums=(0, 1, 2))

def compute_cost(arch, nll_fn, w, X, T, weight_cost, chunk_size):
    batch_size = X.shape[0]
    total = 0

    for chunk_idxs in get_chunks(batch_size, chunk_size):
        X_chunk, T_chunk = X[chunk_idxs, :], T[chunk_idxs, :]
        total += nll_cost(arch.net_apply, nll_fn, arch.unflatten,
                          w, X_chunk, T_chunk)

    return total / batch_size + weight_cost * L2_penalty(arch, w)

def compute_gradient(arch, output_model, w, X, T, weight_cost, chunk_size):
    batch_size = X.shape[0]
    grad_w = 0

    for chunk_idxs in get_chunks(batch_size, chunk_size):
        X_chunk, T_chunk = X[chunk_idxs, :], T[chunk_idxs, :]

        grad_w += grad_nll_cost(arch.net_apply, output_model.nll_fn, arch.unflatten,
                                w, X_chunk, T_chunk)

    grad_w /= batch_size
    grad_w += weight_cost * grad(L2_penalty, 1)(arch, w)
    return grad_w

def compute_natgrad_from_inverses(arch, grad_w, A_inv, G_inv):
    param_grad = arch.unflatten(grad_w)
    natgrad = {}
    for in_name, out_name in arch.param_info:
        grad_W, grad_b = param_grad[out_name]
        grad_Wb = np.vstack([grad_W, grad_b.reshape((1, -1))])

        natgrad_Wb = A_inv[in_name] @ grad_Wb @ G_inv[out_name]

        natgrad_W, natgrad_b = natgrad_Wb[:-1, :], natgrad_Wb[-1, :]
167 | natgrad[out_name] = (natgrad_W, natgrad_b) 168 | return arch.flatten(natgrad) 169 | 170 | def compute_natgrad_from_eigs_helper(param_info, param_grad, A_eig, G_eig, pi, gamma): 171 | natgrad = {} 172 | for in_name, out_name in param_info: 173 | grad_W, grad_b = param_grad[out_name] 174 | grad_Wb = np.vstack([grad_W, grad_b.reshape((1, -1))]) 175 | 176 | A_d, A_Q = A_eig[in_name] 177 | G_d, G_Q = G_eig[out_name] 178 | 179 | # rotate into Kronecker eigenbasis 180 | grad_rot = A_Q.T @ grad_Wb @ G_Q 181 | 182 | # add damping and divide 183 | denom = np.outer(A_d + gamma * pi[out_name], 184 | G_d + gamma / pi[out_name]) 185 | natgrad_rot = grad_rot / denom 186 | 187 | # rotate back to the original basis 188 | natgrad_Wb = A_Q @ natgrad_rot @ G_Q.T 189 | 190 | natgrad_W, natgrad_b = natgrad_Wb[:-1, :], natgrad_Wb[-1, :] 191 | natgrad[out_name] = (natgrad_W, natgrad_b) 192 | return natgrad 193 | 194 | compute_natgrad_from_eigs_helper = jit(compute_natgrad_from_eigs_helper, static_argnums=(0,)) 195 | 196 | def compute_natgrad_from_eigs(arch, grad_w, A_eig, G_eig, pi, gamma): 197 | param_grad = arch.unflatten(grad_w) 198 | natgrad = compute_natgrad_from_eigs_helper( 199 | arch.param_info, param_grad, A_eig, G_eig, pi, gamma) 200 | return arch.flatten(natgrad) 201 | 202 | def compute_A_chunk(apply_fn, nll_fn, unflatten_fn, w, X, T, dirs, grad_w): 203 | ndir = len(dirs) 204 | predict_wrap = lambda w: apply_fn(unflatten_fn(w), X) 205 | 206 | RY, RgY = [], [] 207 | for v in dirs: 208 | Y, RY_ = jvp(predict_wrap, (w,), (v,)) 209 | nll_wrap = lambda Y: nll_fn(Y, T) 210 | RgY_ = kfac_util.hvp(nll_wrap, Y, RY_) 211 | RY.append(RY_) 212 | RgY.append(RgY_) 213 | 214 | A = np.array([[onp.sum(RY[i] * RgY[j]) 215 | for j in range(ndir)] 216 | for i in range(ndir)]) 217 | 218 | return A 219 | 220 | compute_A_chunk = jit(compute_A_chunk, static_argnums=(0, 1, 2)) 221 | 222 | 223 | def compute_step_coeffs(arch, output_model, w, X, T, dirs, grad_w, 224 | weight_cost, lmbda, chunk_size): 225 | """Compute the coefficients alpha and beta which minimize the quadratic 226 | approximation to the cost in the update: 227 | 228 | new_update = sum of coeffs[i] * dirs[i] 229 | 230 | Note that, unlike the rest of the K-FAC algorithm, this function assumes 231 | the loss function is negative log-likelihood for an exponential family. 232 | (This is because it relies on the Fisher information matrix approximating 233 | the Hessian of the NLL.) 234 | """ 235 | ndir = len(dirs) 236 | 237 | # First, compute the "function space" portion of the quadratic approximation. 238 | # This is based on the Gauss-Newton approximation to the NLL, or equivalently, 239 | # the Fisher information matrix. 240 | 241 | A_func = onp.zeros((ndir, ndir)) 242 | batch_size = X.shape[0] 243 | for chunk_idxs in get_chunks(batch_size, chunk_size): 244 | X_chunk, T_chunk = X[chunk_idxs, :], T[chunk_idxs, :] 245 | 246 | A_func += compute_A_chunk(arch.net_apply, output_model.nll_fn, arch.unflatten, 247 | w, X_chunk, T_chunk, dirs, grad_w) 248 | 249 | A_func /= batch_size 250 | 251 | # Now compute the weight space terms, which include both the Hessian of the 252 | # L2 regularizer and the damping term. This is almost a multiple of the 253 | # identity matrix, except that the L2 penalty only applies to weights, not 254 | # biases. Hence, we need to apply a mask to zero out the entries corresponding 255 | # to biases. 
    # This can be done using a Hessian-vector product with the L2
    # regularizer, which has the added benefit that the solution generalizes
    # to non-uniform L2 regularizers as well.

    wrap = lambda w: L2_penalty(arch, w)
    Hv = [kfac_util.hvp(wrap, w, v) for v in dirs]
    A_L2 = onp.array([[weight_cost * Hv[i] @ dirs[j]
                       for i in range(ndir)]
                      for j in range(ndir)])
    A_prox = onp.array([[lmbda * dirs[i] @ dirs[j]
                         for i in range(ndir)]
                        for j in range(ndir)])
    A = A_func + A_L2 + A_prox

    # The linear term is much simpler: it's just the dot product with the gradient.
    b = onp.array([v @ grad_w for v in dirs])

    # Minimize the quadratic approximation by solving the linear system.
    coeffs = onp.linalg.solve(A, -b)

    # The decrease in the quadratic objective is used to adapt lambda.
    quad_decrease = -0.5 * coeffs @ A @ coeffs - b @ coeffs

    return coeffs, quad_decrease

def compute_update(coeffs, dirs):
    ans = 0
    for coeff, v in zip(coeffs, dirs):
        ans = ans + coeff * v
    return ans

def update_gamma(state, arch, output_model, X, T, config):
    curr_gamma = state['gamma']
    gamma_less = onp.maximum(
        curr_gamma * config['gamma_drop']**config['gamma_update_interval'],
        config['gamma_min'])
    gamma_more = onp.minimum(
        curr_gamma * config['gamma_boost']**config['gamma_update_interval'],
        config['gamma_max'])
    gammas = [gamma_less, curr_gamma, gamma_more]

    grad_w = compute_gradient(
        arch, output_model, state['w'], X, T, config['weight_cost'],
        config['chunk_size'])

    results = []
    for gamma in gammas:
        natgrad_w = compute_natgrad_from_eigs(
            arch, grad_w, state['A_eig'], state['G_eig'], state['pi'], gamma)

        prev_update = state['update']
        coeffs, _ = compute_step_coeffs(
            arch, output_model, state['w'], X, T, [-natgrad_w, prev_update],
            grad_w, config['weight_cost'], state['lambda'], config['chunk_size'])
        update = compute_update(coeffs, [-natgrad_w, prev_update])
        new_w = state['w'] + update

        results.append(compute_cost(
            arch, output_model.nll_fn, new_w, X, T, config['weight_cost'],
            config['chunk_size']))

    best_idx = onp.argmin(results)
    return gammas[best_idx]

def update_lambda(arch, output_model, lmbda, old_w, new_w, X, T, quad_dec, config):
    old_cost = compute_cost(
        arch, output_model.nll_fn, old_w, X, T, config['weight_cost'], config['chunk_size'])
    new_cost = compute_cost(
        arch, output_model.nll_fn, new_w, X, T, config['weight_cost'], config['chunk_size'])
    rho = (old_cost - new_cost) / quad_dec

    if np.isnan(rho) or rho < 0.25:
        new_lambda = np.minimum(
            lmbda * config['lambda_boost']**config['lambda_update_interval'],
            config['lambda_max'])
    elif rho > 0.75:
        new_lambda = np.maximum(
            lmbda * config['lambda_drop']**config['lambda_update_interval'],
            config['lambda_min'])
    else:
        new_lambda = lmbda

    return new_lambda, rho


def kfac_init(arch, output_model, X_train, T_train, config, random_seed=0):
    state = {}

    state['step'] = 0
    state['rng'] = random.PRNGKey(random_seed)

    state['gamma'] = config['init_gamma']
    state['lambda'] = config['init_lambda']

    key, state['rng'] = random.split(state['rng'])
    _, params =
arch.net_init(key, X_train.shape) 351 | state['w'] = arch.flatten(params) 352 | state['w_avg'] = state['w'] 353 | 354 | key, state['rng'] = random.split(state['rng']) 355 | state['A'], state['G'] = estimate_covariances( 356 | arch, output_model, state['w'], X_train, key, config['chunk_size']) 357 | 358 | state['A_eig'], state['G_eig'], state['pi'] = compute_eigs( 359 | arch, state['A'], state['G']) 360 | 361 | return state 362 | 363 | 364 | def kfac_iter(state, arch, output_model, X_train, T_train, config): 365 | old_state = state 366 | state = dict(state) # shallow copy 367 | 368 | state['step'] += 1 369 | 370 | ndata = X_train.shape[0] 371 | batch_size = get_batch_size(state['step'], ndata, config) 372 | 373 | # Sample with replacement 374 | key, state['rng'] = random.split(state['rng']) 375 | idxs = random.permutation(key, np.arange(ndata))[:batch_size] 376 | X_batch, T_batch = X_train[idxs, :], T_train[idxs, :] 377 | 378 | # Update statistics by running backprop on the sampled targets 379 | if state['step'] % config['cov_update_interval'] == 0: 380 | batch_size_samp = get_sample_batch_size(batch_size, config) 381 | X_samp = X_batch[:batch_size_samp, :] 382 | state['A'], state['G'] = update_covariances( 383 | state['A'], state['G'], arch, output_model, state['w'], X_samp, state['rng'], 384 | config['cov_timescale'], config['chunk_size']) 385 | 386 | # Update the inverses 387 | if state['step'] % config['eig_update_interval'] == 0: 388 | state['A_eig'], state['G_eig'], state['pi'] = compute_eigs( 389 | arch, state['A'], state['G']) 390 | 391 | # Update gamma 392 | if state['step'] % config['gamma_update_interval'] == 0: 393 | state['gamma'] = update_gamma(state, arch, output_model, X_batch, T_batch, config) 394 | 395 | # Compute the gradient and approximate natural gradient 396 | grad_w = compute_gradient( 397 | arch, output_model, state['w'], X_batch, T_batch, 398 | config['weight_cost'], config['chunk_size']) 399 | natgrad_w = compute_natgrad_from_eigs( 400 | arch, grad_w, state['A_eig'], state['G_eig'], state['pi'], state['gamma']) 401 | 402 | # Determine the step size parameters using MVPs 403 | if 'update' in state: 404 | prev_update = state['update'] 405 | dirs = [-natgrad_w, prev_update] 406 | else: 407 | dirs = [-natgrad_w] 408 | state['coeffs'], state['quad_dec'] = compute_step_coeffs( 409 | arch, output_model, state['w'], X_batch, T_batch, dirs, 410 | grad_w, config['weight_cost'], state['lambda'], config['chunk_size']) 411 | state['update'] = compute_update(state['coeffs'], dirs) 412 | state['w'] = state['w'] + state['update'] 413 | 414 | # Update lambda 415 | if state['step'] % config['lambda_update_interval'] == 0: 416 | state['lambda'], state['rho'] = update_lambda( 417 | arch, output_model, state['lambda'], old_state['w'], state['w'], X_batch, 418 | T_batch, state['quad_dec'], config) 419 | 420 | # Iterate averaging 421 | ema_param = kfac_util.get_ema_param(config['param_timescale']) 422 | state['w_avg'] = ema_param * state['w_avg'] + (1-ema_param) * state['w'] 423 | 424 | return state 425 | 426 | 427 | 428 | -------------------------------------------------------------------------------- /kfac/kfac_util.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from jax import grad, numpy as np, random, jvp, nn 3 | from jax.flatten_util import ravel_pytree 4 | from jax.ops import index_update 5 | from jax.tree_util import tree_map 6 | 7 | 8 | 9 | def named_serial(*layers): 10 | # based on 
jax.experimental.stax.serial 11 | nlayers = len(layers) 12 | names, fns = zip(*layers) 13 | init_fns, apply_fns = zip(*fns) 14 | output_name = names[-1] 15 | 16 | def init_fn(rng, input_shape): 17 | params = {} 18 | for name, init_fn in zip(names, init_fns): 19 | rng, layer_rng = random.split(rng) 20 | input_shape, param = init_fn(layer_rng, input_shape) 21 | params[name] = param 22 | return input_shape, params 23 | 24 | def apply_fn(params, inputs, add_to={}, ret_all=False, **kwargs): 25 | rng = kwargs.pop('rng', None) 26 | rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers 27 | result = {'in': inputs} 28 | for fun, name, rng in zip(apply_fns, names, rngs): 29 | inputs = fun(params[name], inputs, rng=rng, **kwargs) 30 | if name in add_to: 31 | inputs = inputs + add_to[name] 32 | if ret_all: 33 | result[name] = inputs 34 | 35 | if ret_all: 36 | return inputs, result 37 | else: 38 | return inputs 39 | 40 | return init_fn, apply_fn 41 | 42 | Architecture = namedtuple('Architecture', ['net_init', 'net_apply', 'in_shape', 43 | 'flatten', 'unflatten', 'param_info']) 44 | 45 | OutputModel = namedtuple('OutputModel', ['nll_fn', 'sample_grads_fn']) 46 | 47 | 48 | def bernoulli_nll(logits, T): 49 | """Compute the sum (not the mean) of the losses on a batch.""" 50 | log_p = -np.logaddexp(0, -logits) 51 | log_1_minus_p = -np.logaddexp(0, logits) 52 | return -np.sum(T * log_p + (1-T) * log_1_minus_p) 53 | 54 | def bernoulli_sample_grads(logits, key): 55 | """Sample a vector whose covariance is the output layer metric (i.e. Fisher information).""" 56 | Y = nn.sigmoid(logits) 57 | T = random.bernoulli(key, Y) 58 | return Y - T 59 | 60 | BernoulliModel = OutputModel(bernoulli_nll, bernoulli_sample_grads) 61 | 62 | 63 | 64 | def make_float64(params): 65 | return tree_map(lambda x: x.astype(np.float64), params) 66 | 67 | def get_flatten_fns(init_fn, in_shape, float64=False): 68 | rng = random.PRNGKey(0) 69 | _, dummy_params = init_fn(rng, in_shape) 70 | if float64: 71 | dummy_params = make_float64(dummy_params) 72 | _, unflatten = ravel_pytree(dummy_params) 73 | def flatten(p): 74 | return ravel_pytree(p)[0] 75 | return flatten, unflatten 76 | 77 | def hvp(J, w, v): 78 | return jvp(grad(J), (w,), (v,))[1] 79 | 80 | def get_ema_param(timescale): 81 | return 1 - 1 / timescale 82 | 83 | 84 | def sparse_init(num_conn=15, stdev=1.): 85 | def init(rng, shape): 86 | k1, k2 = random.split(rng) 87 | in_dim, out_dim = shape 88 | num_conn_ = np.minimum(num_conn, in_dim) 89 | W = np.zeros(shape) 90 | row_idxs = np.outer(np.arange(in_dim), np.ones(out_dim)).astype(np.uint32) 91 | row_idxs = random.shuffle(k1, row_idxs)[:num_conn_, :].ravel() 92 | col_idxs = np.outer(np.ones(num_conn_), np.arange(out_dim)).astype(np.uint32).ravel() 93 | vals = random.normal(k2, shape=(num_conn_*out_dim,)) * stdev 94 | return index_update(W, (row_idxs, col_idxs), vals) 95 | return init 96 | 97 | -------------------------------------------------------------------------------- /kfac/mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow_datasets as tfds 3 | 4 | import autoencoders 5 | 6 | 7 | 8 | ## def MNISTArchitecture(): 9 | ## net_init, net_apply = named_serial( 10 | ## ('enc1z', Dense(1000, W_init=sparse_init(), b_init=initializers.zeros)), 11 | ## ('enc1a', elementwise(nn.sigmoid)), 12 | ## ('enc2z', Dense(500, W_init=sparse_init(), b_init=initializers.zeros)), 13 | ## ('enc2a', elementwise(nn.sigmoid)), 14 | ## ('enc3z', 
Dense(250, W_init=sparse_init(), b_init=initializers.zeros)), 15 | ## ('enc3a', elementwise(nn.sigmoid)), 16 | ## ('code', Dense(30, W_init=sparse_init(), b_init=initializers.zeros)), 17 | ## ('dec1z', Dense(250, W_init=sparse_init(), b_init=initializers.zeros)), 18 | ## ('dec1a', elementwise(nn.sigmoid)), 19 | ## ('dec2z', Dense(500, W_init=sparse_init(), b_init=initializers.zeros)), 20 | ## ('dec2a', elementwise(nn.sigmoid)), 21 | ## ('dec3z', Dense(1000, W_init=sparse_init(), b_init=initializers.zeros)), 22 | ## ('dec3a', elementwise(nn.sigmoid)), 23 | ## ('out', Dense(784, W_init=sparse_init(), b_init=initializers.zeros)), 24 | ## ) 25 | ## param_info = (('in', 'enc1z'), 26 | ## ('enc1a', 'enc2z'), 27 | ## ('enc2a', 'enc3z'), 28 | ## ('enc3a', 'code'), 29 | ## ('code', 'dec1z'), 30 | ## ('dec1a', 'dec2z'), 31 | ## ('dec2a', 'dec3z'), 32 | ## ('dec3a', 'out') 33 | ## ) 34 | ## in_shape=(-1, 784) 35 | ## flatten, unflatten = get_flatten_fns(net_init, in_shape) 36 | ## return Architecture(net_init, net_apply, in_shape, flatten, unflatten, param_info) 37 | 38 | def get_architecture(): 39 | layer_sizes = [('enc1', 1000), 40 | ('enc2', 500), 41 | ('enc3', 250), 42 | ('code', 30), 43 | ('dec1', 250), 44 | ('dec2', 500), 45 | ('dec3', 1000)] 46 | 47 | return autoencoders.get_architecture(784, layer_sizes) 48 | 49 | 50 | def get_config(): 51 | return autoencoders.default_config() 52 | 53 | 54 | 55 | 56 | def run(): 57 | mnist_data, info = tfds.load(name="mnist", batch_size=-1, with_info=True) 58 | mnist_data = tfds.as_numpy(mnist_data) 59 | train_data, test_data = mnist_data['train'], mnist_data['test'] 60 | X_train = train_data['image'].reshape((-1, 784)).astype(np.float32) / 255 61 | X_test = test_data['image'].reshape((-1, 784)).astype(np.float32) / 255 62 | 63 | config = get_config() 64 | arch = get_architecture() 65 | 66 | autoencoders.run_training(X_train, X_test, arch, config) 67 | 68 | if __name__ == '__main__': 69 | run() 70 | 71 | 72 | -------------------------------------------------------------------------------- /lec02/core.py: -------------------------------------------------------------------------------- 1 | from jax.config import config 2 | config.update("jax_enable_x64", True) 3 | from jax import grad, jvp, vjp 4 | import scipy.sparse 5 | 6 | 7 | def hvp(J, w, v): 8 | return jvp(grad(J), (w,), (v,))[1] 9 | 10 | def gnhvp(f, L, w, v): 11 | y, R_y = jvp(f, (w,), (v,)) 12 | R_gy = hvp(L, y, R_y) 13 | _, f_vjp = vjp(f, w) 14 | return f_vjp(R_gy)[0] 15 | 16 | def approx_solve(A_mvp, b, niter): 17 | dim = b.size 18 | A_linop = scipy.sparse.linalg.LinearOperator((dim,dim), matvec=A_mvp) 19 | res = scipy.sparse.linalg.cg(A_linop, b, maxiter=niter) 20 | return res[0] 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /lec02/sensitivity.py: -------------------------------------------------------------------------------- 1 | from jax.config import config 2 | config.update("jax_enable_x64", True) 3 | from jax import grad, jvp, jit, numpy as np, random 4 | from jax.experimental.stax import Dense, Tanh 5 | from matplotlib import pyplot as plt 6 | import numpy as onp 7 | from collections import namedtuple 8 | 9 | from core import gnhvp, approx_solve 10 | from util import named_serial, make_float64, get_flatten_fns 11 | 12 | 13 | ######################################### 14 | ## Implementation of response Jacobian ## 15 | ######################################### 16 | 17 | def mixed_second_mvp(J_param, w, t, R_t): 18 | grad_cost = 
grad(J_param, 0) # gradient w.r.t. w 19 | grad_cost_t = lambda t: grad_cost(w, t) 20 | return jvp(grad_cost_t, (t,), (R_t,))[1] # forward-over-reverse 21 | 22 | def dampen(mvp, lam): 23 | def new_mvp(x): 24 | return mvp(x) + lam*x 25 | return new_mvp 26 | 27 | def approx_solve_H(f_param, L_param, w, phi, Rg_w, lam, niter): 28 | mvp = lambda v: gnhvp(lambda w: f_param(w, phi), 29 | lambda y: L_param(y, phi), w, v) 30 | mvp_damp = dampen(mvp, lam) 31 | return approx_solve(mvp_damp, Rg_w, niter) 32 | 33 | def response_jacobian_vector_product(f_param, L_param, w, phi, R_phi, lam, niter): 34 | def J_param(w, phi): 35 | return L_param(f_param(w, phi), phi) 36 | Rg_w = mixed_second_mvp(J_param, w, phi, R_phi) 37 | return approx_solve_H(f_param, L_param, w, phi, -Rg_w, lam, niter) 38 | 39 | 40 | ######################################## 41 | ## 1-D Regression example ## 42 | ######################################## 43 | 44 | 45 | Architecture = namedtuple('Architecture', ['net_init', 'net_apply', 'in_shape', 46 | 'flatten', 'unflatten']) 47 | 48 | def ToyMLP(): 49 | net_init, net_apply = named_serial( 50 | ('z1', Dense(256)), 51 | ('h1', Tanh), 52 | ('z2', Dense(256)), 53 | ('h2', Tanh), 54 | ('y', Dense(1))) 55 | in_shape = (-1, 1) 56 | flatten, unflatten = get_flatten_fns(net_init, in_shape) 57 | return Architecture(net_init, net_apply, in_shape, flatten, unflatten) 58 | 59 | 60 | def f_net(arch, w, x): 61 | x_in = x.reshape((-1, 1)) 62 | return arch.net_apply(arch.unflatten(w), x_in).ravel() 63 | def L(y, t): 64 | return 0.5 * np.sum((y-t)**2) 65 | 66 | def make_parameterized_cost(arch, x, L): 67 | """Make a cost function parameterized by a vector phi. Here, the cost 68 | is squared error, and phi represents the targets.""" 69 | def f_param(w, phi): 70 | return f_net(arch, w, x) 71 | def L_param(y, phi): 72 | return 0.5 * np.sum((y - phi)**2) 73 | return f_param, L_param 74 | 75 | 76 | def generate_toy_data2(): 77 | x1 = onp.random.uniform(-5, -2, size=50) 78 | fx1 = onp.sin(2*x1) - 2 79 | x2 = onp.random.uniform(2, 5, size=50) 80 | fx2 = onp.sin(2*x2) + 2 81 | x3 = np.array([0.]) 82 | fx3 = np.array([-1.]) 83 | x = onp.concatenate([x1, x2, x3]) 84 | fx = onp.concatenate([fx1, fx2, fx3]) 85 | t = onp.random.normal(fx, 0.5) 86 | return x, t 87 | 88 | def train_toy_network(): 89 | onp.random.seed(0) 90 | 91 | x, t = generate_toy_data2() 92 | x *= 0.2 93 | arch = ToyMLP() 94 | rng = random.PRNGKey(0) 95 | ALPHA = 1e-1 96 | 97 | out_shape, net_params = arch.net_init(rng, arch.in_shape) 98 | w_init = make_float64(arch.flatten(net_params)) 99 | 100 | def train_obj(w): 101 | net_params = arch.unflatten(w) 102 | x_in = x.reshape((-1, 1)) 103 | y = arch.net_apply(net_params, x_in).ravel() 104 | 105 | return 0.5 * np.mean((y-t)**2) 106 | 107 | grad_train_obj = jit(grad(train_obj)) 108 | 109 | w_curr = w_init.copy() 110 | for i in range(10000): 111 | w_curr -= ALPHA * grad_train_obj(w_curr) 112 | w_opt = w_curr.copy() 113 | 114 | plt.figure() 115 | plt.plot(x, t, 'bx') 116 | 117 | x_in = np.linspace(-1, 1, 100).reshape((-1, 1)) 118 | y = arch.net_apply(arch.unflatten(w_opt), x_in) 119 | plt.plot(x_in.ravel(), y, 'r-') 120 | 121 | return x, t, w_opt 122 | 123 | def make_figures(x, t, w_opt, idx): 124 | """Generate the sensitivity analysis figures. The argument idx is the index 125 | of the training example to perturb. 
The indices used in the figure are 0 126 | and 100 (the outlier).""" 127 | LAM = 1e-3 128 | OFFSET = 5 129 | NITER_VALS = [1, 2, 5, 10, 20, 50] 130 | 131 | arch = ToyMLP() 132 | 133 | R_t = onp.zeros(t.shape, dtype=onp.float64) 134 | R_t[idx] = OFFSET 135 | 136 | plt.figure() 137 | plt.plot(x, t, 'bx', alpha=0.5) 138 | plt.ylim(-4, 10) 139 | plt.xticks([]) 140 | plt.yticks([]) 141 | 142 | x_in = np.linspace(-1, 1, 100).reshape((-1, 1)) 143 | y = arch.net_apply(arch.unflatten(w_opt), x_in).ravel() 144 | plt.plot(x_in.ravel(), y, 'r-') 145 | 146 | plt.figure() 147 | plt.plot(x, t, 'bx', alpha=0.3) 148 | plt.plot(x[idx], t[idx]+OFFSET, 'rx', ms=10) 149 | plt.plot(x[idx], t[idx], 'gx', ms=10) 150 | plt.ylim(-4, 10) 151 | plt.xticks([]) 152 | plt.yticks([]) 153 | 154 | x_in = np.linspace(-1, 1, 100).reshape((-1, 1)) 155 | y = arch.net_apply(arch.unflatten(w_opt), x_in).ravel() 156 | 157 | f_param, L_param = make_parameterized_cost(arch, x, L) 158 | 159 | y_list = [y] 160 | for niter in NITER_VALS: 161 | R_w = response_jacobian_vector_product(f_param, L_param, w_opt, t, R_t, LAM, niter) 162 | R_y = jvp(lambda w: f_net(arch, w, x_in), (w_opt,), (R_w,))[1] 163 | y_list.append(y + R_y) 164 | 165 | labels = ['0'] + [str(i) for i in NITER_VALS] 166 | for i, curr_y in enumerate(y_list): 167 | r = i / (len(y_list) - 1) 168 | plt.plot(x_in.ravel(), curr_y, color=(r, 1-r, 0), alpha=0.5, label=labels[i]) 169 | 170 | plt.legend(loc='upper left') 171 | 172 | def run(): 173 | x, t, w_opt = train_toy_network() 174 | make_figures(x, t, w_opt, 0) 175 | make_figures(x, t, w_opt, 100) 176 | 177 | 178 | -------------------------------------------------------------------------------- /lec02/util.py: -------------------------------------------------------------------------------- 1 | from jax.config import config 2 | config.update("jax_enable_x64", True) 3 | from jax import grad, jvp, vjp, jit, numpy as np, random 4 | from jax.experimental.stax import serial, Dense, Relu, Tanh 5 | from jax.flatten_util import ravel_pytree 6 | from jax.tree_util import tree_map 7 | from jax.nn import relu 8 | from matplotlib import pyplot as plt 9 | import numpy as onp 10 | import scipy.sparse 11 | from collections import namedtuple 12 | 13 | 14 | 15 | def named_serial(*layers): 16 | # based on jax.experimental.stax.serial 17 | nlayers = len(layers) 18 | names, fns = zip(*layers) 19 | init_fns, apply_fns = zip(*fns) 20 | output_name = names[-1] 21 | 22 | def init_fn(rng, input_shape): 23 | params = {} 24 | for name, init_fn in zip(names, init_fns): 25 | rng, layer_rng = random.split(rng) 26 | input_shape, param = init_fn(layer_rng, input_shape) 27 | params[name] = param 28 | return input_shape, params 29 | 30 | def apply_fn(params, inputs, ret=None, **kwargs): 31 | rng = kwargs.pop('rng', None) 32 | rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers 33 | result = {} 34 | for fun, name, rng in zip(apply_fns, names, rngs): 35 | inputs = result[name] = fun(params[name], inputs, rng=rng, **kwargs) 36 | 37 | if ret is None: 38 | return inputs 39 | elif ret == 'all': 40 | return result 41 | else: 42 | return result[ret] 43 | 44 | return init_fn, apply_fn 45 | 46 | def make_float64(params): 47 | return tree_map(lambda x: x.astype(np.float64), params) 48 | 49 | def get_flatten_fns(init_fn, in_shape, float64=True): 50 | rng = random.PRNGKey(0) 51 | _, dummy_params = init_fn(rng, in_shape) 52 | if float64: 53 | dummy_params = make_float64(dummy_params) 54 | _, unflatten = ravel_pytree(dummy_params) 55 | def flatten(p): 
56 | return ravel_pytree(p)[0] 57 | return flatten, unflatten 58 | 59 | 60 | 61 | 62 | 63 | --------------------------------------------------------------------------------
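The following is a small, self-contained illustration (not part of the repository) of the block-diagonal K-FAC update that `kfac/kfac.py` applies to each layer: estimate the Kronecker factors A and G from layer inputs and backpropagated output gradients, split the damping between the two factors using the pi rescaling from `compute_pi`, and apply the factored inverse to the layer's weight/bias gradient as in `compute_natgrad_from_inverses`. All shapes, the damping value, and the random data below are made up for illustration.

```python
from jax import numpy as jnp, random

key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)

batch_size, in_dim, out_dim = 64, 20, 10
gamma = 0.1   # overall damping strength (illustrative value)

# Layer inputs with a homogeneous coordinate appended (to fold in the bias),
# and backpropagated gradients with respect to the layer's pre-activations.
a_hom = jnp.hstack([random.normal(k1, (batch_size, in_dim)),
                    jnp.ones((batch_size, 1))])
ds = random.normal(k2, (batch_size, out_dim))

# Kronecker factors: second moments of the inputs and of the output gradients.
A = a_hom.T @ a_hom / batch_size      # (in_dim + 1, in_dim + 1)
G = ds.T @ ds / batch_size            # (out_dim, out_dim)

# Factored Tikhonov damping: split gamma between the factors using pi.
pi = jnp.sqrt((jnp.trace(A) * G.shape[0]) / (A.shape[0] * jnp.trace(G)))
A_inv = jnp.linalg.inv(A + gamma * pi * jnp.eye(A.shape[0]))
G_inv = jnp.linalg.inv(G + (gamma / pi) * jnp.eye(G.shape[0]))

# Applying the Kronecker-factored inverse to the combined weight/bias gradient
# reduces to two small matrix products (cf. compute_natgrad_from_inverses).
grad_Wb = random.normal(k3, (in_dim + 1, out_dim))   # stand-in for a real gradient
natgrad_Wb = A_inv @ grad_Wb @ G_inv
print(natgrad_Wb.shape)   # (in_dim + 1, out_dim)
```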
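`compute_step_coeffs` in `kfac/kfac.py` chooses the step size alpha and the momentum coefficient beta by minimizing a quadratic model of the cost over the span of the proposed update directions. A stripped-down sketch of that subproblem follows, with a made-up positive definite matrix `M` standing in for the Fisher, L2, and proximal terms that the real code assembles from matrix-vector products.

```python
import numpy as onp

onp.random.seed(0)
dim, ndir = 5, 2
Q = onp.random.randn(dim, dim)
M = Q @ Q.T + onp.eye(dim)                # made-up positive definite curvature
grad_w = onp.random.randn(dim)            # made-up gradient
dirs = [onp.random.randn(dim) for _ in range(ndir)]   # e.g. -natgrad and previous update

# Quadratic model restricted to span(dirs): 0.5 c^T A c + b^T c.
A = onp.array([[dirs[i] @ M @ dirs[j] for j in range(ndir)] for i in range(ndir)])
b = onp.array([d @ grad_w for d in dirs])

coeffs = onp.linalg.solve(A, -b)          # optimal (alpha, beta)
update = sum(c * d for c, d in zip(coeffs, dirs))

# Predicted decrease of the quadratic model; kfac.py uses this to adapt lambda.
quad_dec = -0.5 * coeffs @ A @ coeffs - b @ coeffs
print(coeffs, quad_dec)
```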
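The `lec02` code computes sensitivities of the trained weights to perturbations of the training targets by combining a mixed second derivative with a damped curvature solve (see `response_jacobian_vector_product` in `lec02/sensitivity.py`). The sketch below (not part of the repository) reproduces that computation on a linear least-squares problem made up here so the answer can be checked in closed form; the data, `lam`, and `niter` are illustrative.

```python
from jax.config import config
config.update("jax_enable_x64", True)
from jax import grad, jvp, numpy as jnp
import numpy as onp
import scipy.sparse.linalg

onp.random.seed(0)
X = jnp.array(onp.random.randn(20, 3))
t = jnp.array(onp.random.randn(20))
lam = 1e-3     # damping, as in approx_solve_H
niter = 50     # CG iterations

def f_param(w, phi):       # model outputs; phi plays the role of the targets
    return X @ w

def L_param(y, phi):       # squared-error cost parameterized by phi
    return 0.5 * jnp.sum((y - phi)**2)

def J_param(w, phi):
    return L_param(f_param(w, phi), phi)

w_opt = jnp.linalg.solve(X.T @ X, X.T @ t)    # exact minimizer of J(., t)
R_phi = jnp.zeros(20).at[0].set(1.0)          # perturb the first target

# Mixed second derivative (d/dphi of dJ/dw) applied to R_phi, forward-over-reverse.
Rg_w = jvp(lambda phi: grad(J_param, 0)(w_opt, phi), (t,), (R_phi,))[1]

# Damped Hessian-vector products of J in w; for this linear model the Hessian
# coincides with the Gauss-Newton matrix that core.gnhvp would compute.
def damped_hvp(v):
    v = jnp.asarray(v).ravel()
    hv = jvp(grad(lambda w: J_param(w, t)), (w_opt,), (v,))[1]
    return onp.array(hv + lam * v)

H_op = scipy.sparse.linalg.LinearOperator((3, 3), matvec=damped_hvp)
R_w, _ = scipy.sparse.linalg.cg(H_op, -onp.array(Rg_w), maxiter=niter)

# Closed-form response of ridge-damped least squares, for comparison.
print(R_w)
print(onp.linalg.solve(onp.array(X.T @ X) + lam * onp.eye(3), onp.array(X.T @ R_phi)))
```

For the neural network in `sensitivity.py`, the exact Hessian is replaced by the Gauss-Newton approximation via `core.gnhvp`, but the structure of the computation is the same.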