├── .gitignore
├── README.md
├── correlation.py
├── curvature_utils.py
├── displacement_integration.py
├── gradient_descent_mlp_utils.py
├── init_directory.py
├── learn-to-teleport
│   ├── gradient_descent_mlp.py
│   ├── lstm.py
│   ├── plot.py
│   ├── run_mlp_regression.py
│   └── teleportation.py
├── models.py
├── plot.py
├── teleport_optimization.py
└── teleport_sharpness_curvature.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | data
3 | figures
4 | logs
5 | __pycache__
6 | learn-to-teleport/__pycache__
7 | learn-to-teleport/figures
8 | learn-to-teleport/results.pkl
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Paper
2 | Bo Zhao, Robert M. Gower, Robin Walters, Rose Yu. [Improving Convergence and Generalization Using Parameter Symmetries](https://arxiv.org/abs/2305.13404). *International Conference on Learning Representations (ICLR)*, 2024.
3 |
4 | ## Abstract
5 | In many neural networks, different values of the parameters may result in the same loss value. Parameter space symmetries are loss-invariant transformations that change the model parameters. Teleportation applies such transformations to accelerate optimization. However, the exact mechanism behind this algorithm's success is not well understood. In this paper, we show that teleportation not only speeds up optimization in the short-term, but gives overall faster time to convergence. Additionally, teleporting to minima with different curvatures improves generalization, which suggests a connection between the curvature of the minimum and generalization ability. Finally, we show that integrating teleportation into a wide range of optimization algorithms and optimization-based meta-learning improves convergence. Our results showcase the versatility of teleportation and demonstrate the potential of incorporating symmetry in optimization.
6 |
7 | ## Requirements
8 | * [PyTorch](https://pytorch.org/)
9 | * [Matplotlib](https://matplotlib.org/)
10 | * [SciPy](https://scipy.org/install/)
11 | * [Shapely](https://shapely.readthedocs.io/en/stable/)
12 |
13 | ## Initializing directories
14 | ```
15 | python init_directory.py
16 | ```
17 |
18 | ## Reproducing experiments in the paper
19 | Correlation between sharpness/curvature and validation loss (Table 1, Figure 9, Figure 10):
20 |
21 | ```
22 | python correlation.py
23 | ```
24 |
25 | How curvature influences the expected displacement of minima under distribution shifts (Figure 7):
26 |
27 | ```
28 | python displacement_integration.py
29 | ```
30 |
31 | Teleportation to change sharpness or curvature (Figure 4):
32 |
33 | ```
34 | python teleport_sharpness_curvature.py
35 | ```
36 |
37 | Integrating teleportation with various optimizers (Figure 5, Figure 13):
38 |
39 | ```
40 | python teleport_optimization.py
41 | ```
42 |
43 | Meta-learning (Figure 6):
44 |
45 | ```
46 | cd learn-to-teleport
47 | python run_mlp_regression.py
48 | ```
49 | Figures are saved in directories `figures/` and `learn-to-teleport/figures/`.
50 |
51 | ## Cite
52 | ```
53 | @article{zhao2024improving,
54 | title={Improving Convergence and Generalization Using Parameter Symmetries},
55 | author={Bo Zhao and Robert M.
Gower and Robin Walters and Rose Yu}, 56 | journal={International Conference on Learning Representations}, 57 | year={2024} 58 | } 59 | ``` -------------------------------------------------------------------------------- /correlation.py: -------------------------------------------------------------------------------- 1 | """ Compute correlation between sharpness/curvature and validation loss. """ 2 | 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | import pickle 6 | from scipy.stats import pearsonr 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from torchvision import datasets 11 | import torchvision.transforms as transforms 12 | from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler 13 | 14 | from gradient_descent_mlp_utils import init_param_MLP, valid_MLP, loss_MLP_from_vec, loss_multi_layer 15 | from models import MLP 16 | from curvature_utils import W_list_to_vec, vec_to_W_list, compute_curvature, compute_gamma_12 17 | from plot import plot_correlation 18 | 19 | device = 'cpu' 20 | dataset = 'CIFAR10' # 'MNIST', 'FashionMNIST', 'CIFAR10' 21 | sigma_name = 'leakyrelu' # 'leakyrelu' 22 | sigma = nn.LeakyReLU(0.1) 23 | criterion = nn.CrossEntropyLoss() 24 | 25 | num_run = 100 26 | total_epoch = 40 27 | check_epoch = 40 # compute curvature/sharpness using W_lists at this epoch. 28 | 29 | # dataset and hyper-parameters 30 | batch_size = 20 31 | valid_size = 0.2 32 | tele_batch_size = 200 33 | if dataset == 'MNIST': 34 | lr = 1e-2 35 | t_start = 0.001 36 | t_end = 0.2 37 | t_interval = 0.01 38 | dim = [batch_size, 28*28, 16, 10, 10] # [batch_size, 28*28, 512, 512, 10] 39 | teledim = [tele_batch_size, 28*28, 16, 10, 10] # [tele_batch_size, 28*28, 512, 512, 10] 40 | train_data = datasets.MNIST(root = 'data', train = True, download = True, transform = transforms.ToTensor()) 41 | test_data = datasets.MNIST(root = 'data', train = False, download = True, transform = transforms.ToTensor()) 42 | elif dataset == 'FashionMNIST': 43 | lr = 1e-2 44 | t_start = 0.0001 45 | t_end = 0.02 46 | t_interval = 0.001 47 | dim = [batch_size, 28*28, 16, 10, 10] 48 | teledim = [tele_batch_size, 28*28, 16, 10, 10] 49 | train_data = datasets.FashionMNIST(root = 'data', train = True, download = True, transform = transforms.ToTensor()) 50 | test_data = datasets.FashionMNIST(root = 'data', train = False, download = True, transform = transforms.ToTensor()) 51 | elif dataset == 'CIFAR10': 52 | lr = 2e-2 53 | t_start = 0.0001 54 | t_end = 0.02 55 | t_interval = 0.001 56 | dim = [batch_size, 32*32*3, 128, 32, 10] 57 | teledim = [tele_batch_size, 32*32*3, 128, 32, 10] 58 | train_data = datasets.CIFAR10(root = 'data', train = True, download = True, transform = transforms.ToTensor()) 59 | test_data = datasets.CIFAR10(root = 'data', train = False, download = True, transform = transforms.ToTensor()) 60 | else: 61 | raise ValueError('dataset should be one of MNIST, fashion, and CIFAR10') 62 | 63 | # data loaders 64 | if dataset in ['MNIST', 'FashionMNIST']: 65 | train_subset, val_subset = torch.utils.data.random_split( 66 | train_data, [50000, 10000], generator=torch.Generator().manual_seed(1)) 67 | train_sampler = SequentialSampler(train_subset) 68 | train_loader = torch.utils.data.DataLoader(train_subset, batch_size = batch_size, 69 | sampler = train_sampler, num_workers = 0) 70 | test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, 71 | num_workers = 0) 72 | teleport_loader = torch.utils.data.DataLoader(train_subset, batch_size 
= tele_batch_size, 73 | shuffle=True, num_workers = 0) 74 | teleport_loader_iterator = iter(teleport_loader) 75 | else: #CIFAR10 76 | train_sampler = SequentialSampler(train_data) 77 | train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, 78 | sampler = train_sampler, num_workers = 0) 79 | test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, 80 | num_workers = 0) 81 | teleport_loader = torch.utils.data.DataLoader(train_data, batch_size = tele_batch_size, 82 | shuffle=True, num_workers = 0) 83 | teleport_loader_iterator = iter(teleport_loader) 84 | 85 | 86 | def train_step_SGD(x_train, y_train, model, criterion, optimizer): 87 | model.train() 88 | optimizer.zero_grad() 89 | output = model.forward(x_train) 90 | loss = criterion(output.T, y_train) 91 | loss.backward() 92 | optimizer.step() 93 | return loss 94 | 95 | def run_SGD_rand(seed): 96 | # run SGD with initial weights determined by seed 97 | W_list = init_param_MLP(dim, seed) 98 | loss_arr_SGD = [] 99 | dL_dt_arr_SGD = [] 100 | valid_loss_SGD = [] 101 | valid_correct_SGD = [] 102 | 103 | model = MLP(init_W_list=W_list, activation=sigma) 104 | model.to(device) 105 | criterion = nn.CrossEntropyLoss() 106 | optimizer = torch.optim.SGD(model.parameters(), lr=lr) 107 | 108 | for epoch in range(total_epoch): 109 | epoch_loss = 0.0 110 | for data, label in train_loader: 111 | batch_size = data.shape[0] 112 | data = torch.t(data.view(batch_size, -1)) 113 | loss = train_step_SGD(data, label, model, criterion, optimizer) 114 | epoch_loss += loss.item() * data.size(1) 115 | loss_arr_SGD.append(epoch_loss / len(train_loader.sampler)) 116 | 117 | if epoch == 39: 118 | W_list_40 = model.get_W_list() 119 | 120 | if epoch in [39]: 121 | valid_loss, valid_correct = valid_MLP(model, criterion, test_loader) 122 | valid_loss_SGD.append(valid_loss) 123 | valid_correct_SGD.append(valid_correct) 124 | else: 125 | valid_loss_SGD.append(0) 126 | valid_correct_SGD.append(0) 127 | 128 | if epoch % 10 == 9: 129 | print(epoch, loss_arr_SGD[-1], valid_loss_SGD[-1], valid_correct_SGD[-1]) 130 | 131 | results = (loss_arr_SGD, valid_loss_SGD, dL_dt_arr_SGD, valid_correct_SGD, W_list_40) 132 | return results 133 | 134 | 135 | # run SGD with random seed 100 times and save 136 | for i in range(num_run): 137 | print(i) 138 | results = run_SGD_rand((i+1)*54321) 139 | print(i, results[1][-1]) 140 | with open('logs/correlation/{}/{}_final_W_lists/W_lists_{}_{}.pkl'.format(dataset, dataset, sigma_name, i), 'wb') as f: 141 | pickle.dump(results, f) 142 | 143 | # compute sharpness and curvature and save 144 | perturb_mean_list = [] 145 | curvature_mean_list = [] 146 | valid_loss_list = [] 147 | train_loss_list = [] 148 | 149 | for i in range(num_run): 150 | if i % 10 == 0: 151 | print(i) 152 | with open('logs/correlation/{}/{}_final_W_lists/W_lists_{}_{}.pkl'.format(dataset, dataset, sigma_name, i), 'rb') as f: 153 | train_loss, valid_loss, _, _, W_list = pickle.load(f) 154 | 155 | train_loss_list.append(train_loss[check_epoch-1]) 156 | valid_loss_list.append(valid_loss[check_epoch-1]) 157 | W_vec_all = W_list_to_vec(W_list) 158 | 159 | curvature_list = [] 160 | perturb_list = [] 161 | for curve_idx in range(200): 162 | # load data batch 163 | try: 164 | tele_data, tele_target = next(teleport_loader_iterator) 165 | except StopIteration: 166 | teleport_loader_iterator = iter(teleport_loader) 167 | tele_data, tele_target = next(teleport_loader_iterator) 168 | 169 | batch_size = tele_data.shape[0] 170 | tele_data = 
torch.t(tele_data.view(batch_size, -1)) 171 | X = tele_data 172 | Y = tele_target 173 | 174 | # curvature (Equation 5) 175 | M_list = [] 176 | torch.manual_seed(12345 * curve_idx) 177 | 178 | for m in range(0, len(W_list)-1): 179 | M = torch.rand(dim[m+2], dim[m+2]) 180 | M = M / torch.norm(M, p='fro', dim=None) 181 | M_list.append(M) 182 | 183 | gamma_1_list, gamma_2_list = compute_gamma_12(M_list, W_list, X) 184 | curvature = compute_curvature(gamma_1_list, gamma_2_list).item() 185 | curvature_list.append(curvature) 186 | 187 | # sharpness (Equation 4) 188 | W_list_perturb = [] 189 | for t in np.arange(t_start, t_end, t_interval): 190 | random_dir = torch.rand(W_vec_all.size()[0]) 191 | random_dir = random_dir / torch.norm(random_dir) * t 192 | W_vec_all_perturb = W_vec_all + random_dir 193 | loss_perturb = loss_MLP_from_vec(W_vec_all_perturb, X, Y, dim, sigma) 194 | perturb_list.append(loss_perturb) 195 | 196 | curvature_mean_list.append(np.average(curvature_list)) 197 | perturb_mean_list.append(np.average(perturb_list) / tele_batch_size) # correct 198 | 199 | curvature_mean_list = np.array(curvature_mean_list) 200 | valid_loss_list = np.array(valid_loss_list) 201 | train_loss_list = np.array(train_loss_list) 202 | 203 | results = (curvature_mean_list, perturb_mean_list, valid_loss_list, train_loss_list) 204 | with open('logs/correlation/{}/{}_final_W_lists/curvatures_all_{}.pkl'.format(dataset, dataset, sigma_name), 'wb') as f: 205 | pickle.dump(results, f) 206 | 207 | plot_correlation(dataset, sigma_name) -------------------------------------------------------------------------------- /curvature_utils.py: -------------------------------------------------------------------------------- 1 | """ Functions for computing curvatures. """ 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | sigma = nn.LeakyReLU(0.1) 8 | sigma_inv = nn.LeakyReLU(10) 9 | 10 | def sigma_1(x): 11 | # first derivative of sigma (LeakyReLU) 12 | x[x > 0] = 1.0 13 | x[x < 0] = 0.1 14 | return x 15 | 16 | def sigma_inv_1(x): 17 | # first derivative of sigma^{-1} (LeakyReLU) 18 | x[x > 0] = 1.0 19 | x[x < 0] = 10.0 20 | return x 21 | 22 | def W_list_to_vec(W_list): 23 | # Flatten and concatenate all weight matrices to a vector. 24 | W_vec_all = torch.flatten(W_list[0]) 25 | for i in range(1, len(W_list)): 26 | W_vec = torch.flatten(W_list[i]) 27 | W_vec_all = torch.concat((W_vec_all, W_vec)) 28 | return W_vec_all 29 | 30 | def vec_to_W_list(W_vec_all, dim): 31 | # Reshape vectorized weight to matrices. 32 | # dim: list of dimensions of weight matrices. Example: [4, 5, 6, 7, 8] -> X: 5x4, W1:6x5, W2:7x6, W3:8x7, Y:8x4 33 | W_list = [] 34 | start_idx = 0 35 | for i in range(len(dim)-2): 36 | end_idx = start_idx + dim[i+2]*dim[i+1] 37 | W_list.append(torch.reshape(W_vec_all[start_idx:end_idx], (dim[i+2], dim[i+1]))) 38 | start_idx = end_idx 39 | return W_list 40 | 41 | def compute_curvature(gamma_1_list, gamma_2_list): 42 | """Compute curvature of gamma from its first and second derivatives. (Equation (47) in paper) 43 | 44 | Args: 45 | gamma_1_list: First derivative of curve gamma(t), d gamma / dt. 46 | gamma_2_list: Second derivative of curve gamma(t), d^2 gamma / dt^2. 47 | 48 | Returns: 49 | Curvature of gamma. 
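        Concretely, with gamma' and gamma'' flattened into single vectors, the returned value is
        kappa = sqrt(||gamma'||^2 ||gamma''||^2 - <gamma', gamma''>^2) / ||gamma'||^3.

    Example (random tensors, used here only to illustrate the expected list-of-matrices shapes):
        >>> g1 = [torch.randn(6, 5), torch.randn(7, 6)]
        >>> g2 = [torch.randn(6, 5), torch.randn(7, 6)]
        >>> kappa = compute_curvature(g1, g2)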
50 | """ 51 | 52 | gamma_1_vec = W_list_to_vec(gamma_1_list) 53 | gamma_2_vec = W_list_to_vec(gamma_2_list) 54 | gamma_1_norm = torch.norm(gamma_1_vec) 55 | gamma_2_norm = torch.norm(gamma_2_vec) 56 | numerator = torch.sqrt(gamma_1_norm**2 * gamma_2_norm**2 - torch.dot(gamma_1_vec, gamma_2_vec)**2) 57 | denominator = gamma_1_norm**3 58 | return numerator / denominator 59 | 60 | def compute_gamma_12(M_list, W_list, X): 61 | """Compute the first and second derivatives of curve gamma(t). 62 | See Equation (57) and (60) in paper. Note that the second derivative of leaky ReLU is 0. 63 | 64 | Args: 65 | M_list: List of Lie algebras (random square matrices). 66 | W_list: List of weight matrices. 67 | X: Data matrix. 68 | 69 | Returns: 70 | gamma_1_list: First derivative of curve gamma(t), d gamma / dt. 71 | gamma_2_list: Second derivative of curve gamma(t), d^2 gamma / dt^2. 72 | """ 73 | 74 | gamma_1_list = [] 75 | gamma_2_list = [] 76 | for m in range(0, len(W_list)): 77 | gamma_1_list.append(torch.zeros_like(W_list[m])) 78 | gamma_2_list.append(torch.zeros_like(W_list[m])) 79 | 80 | h = X 81 | for m in range(0, len(W_list)-1): 82 | M = M_list[m] 83 | U = W_list[m+1] 84 | V = W_list[m] 85 | 86 | h_inv = torch.linalg.pinv(h) 87 | M_2 = torch.matmul(M, M) 88 | M_3 = torch.matmul(M_2, M) 89 | sigma_VX = sigma(torch.matmul(V, h)) 90 | sigma_1_VX = sigma_1(torch.matmul(V, h)) 91 | M_sigma_VX = torch.matmul(M, sigma_VX) 92 | 93 | 94 | gamma_1_list[m+1] += (-1) * torch.matmul(U, M) 95 | gamma_1_list[m] += torch.matmul(M_sigma_VX / sigma_1_VX, h_inv) 96 | 97 | gamma_2_list[m+1] += torch.matmul(U, M_2) 98 | gamma_2_list[m] += torch.matmul(torch.matmul(M_2, sigma_VX) / sigma_1_VX, h_inv) 99 | 100 | h = sigma(torch.matmul(V, h)) 101 | 102 | return gamma_1_list, gamma_2_list 103 | -------------------------------------------------------------------------------- /displacement_integration.py: -------------------------------------------------------------------------------- 1 | # Script for computing the expected distance between w0 and a curve as a function of the curvature at w0. 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import shapely.geometry as geom 6 | from scipy.integrate import quad, nquad 7 | 8 | # define integrand = dist((r\cos{\theta}, r\sin{\theta}), \gamma) r 9 | def integrand(t, r, gamma): 10 | point = geom.Point(r * np.cos(t), r * np.sin(t)) 11 | dist = point.distance(gamma) * r 12 | return dist 13 | 14 | # set up color maps 15 | N = 5 16 | base_cmaps = ['Blues','Oranges'] 17 | n_base = len(base_cmaps) 18 | colors = np.concatenate([plt.get_cmap(name)(np.linspace(0.2,0.8,N)) for name in base_cmaps]) 19 | 20 | 21 | # gamma: (x, kx^2) 22 | plt.figure() 23 | plt.rcParams["axes.prop_cycle"] = plt.cycler("color", colors) 24 | for r in [2e0, 1e0, 5e-1, 2e-1, 1e-1]: 25 | k = np.arange(0, 40, 0.3) 26 | int_dist = np.zeros_like(k) 27 | for i in range(len(int_dist)): 28 | # construct gamma, which is a parabola here 29 | x_arr = np.arange(-10, 10, 0.1) 30 | y_arr = k[i] * np.abs(x_arr**2) 31 | 32 | point_arr = [] 33 | for p_idx in range(len(x_arr)): 34 | point_arr.append(geom.Point(x_arr[p_idx], y_arr[p_idx])) 35 | parabola = geom.LineString(point_arr) 36 | 37 | # compute \int_0^{2\pi} dist((r\cos{\theta}, r\sin{\theta}), \gamma) r d\theta 38 | int_dist[i] = quad(integrand, 0, 2*np.pi, args=(r, parabola))[0] 39 | 40 | # plot int_dist / (2 \pi r) / r. The additional r aligns all curves in one plot. 
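    # The horizontal axis is 2k because the curvature of the parabola y = k x^2 at its
    # vertex is kappa = |y''| / (1 + y'^2)^(3/2) = 2k, matching the xlabel set below.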
41 | plt.plot(2*k, int_dist / (2 * np.pi * r * r), linewidth=3, label='r={}'.format(r)) 42 | 43 | plt.xlabel(chr(954)+r'$ = 2k_1$', fontsize=26) 44 | plt.ylabel(r'$\frac{1}{r}\mathbb{E}_{S_r} dist(w, \gamma_1)$', fontsize=26) 45 | plt.xticks([0, 20, 40, 60, 80], fontsize=20) 46 | plt.yticks([0.625, 0.675, 0.725, 0.775], fontsize=20) 47 | plt.legend(fontsize=20) 48 | plt.savefig('figures/curvature_displacement_integral_kx_sqr.pdf', dpi=400, bbox_inches='tight') 49 | 50 | 51 | # gamma: x^2 + (y-k)^2 = k^2 52 | plt.figure() 53 | plt.rcParams["axes.prop_cycle"] = plt.cycler("color", colors) 54 | for r in [1e-1, 5e-2, 2e-2, 1e-2, 5e-3]: 55 | k = np.concatenate((np.arange(0.1, 1.0, 0.01), np.arange(1.0, 10.0, 0.1))) 56 | int_dist = np.zeros_like(k) 57 | for i in range(len(int_dist)): 58 | # construct gamma 59 | theta_arr = np.arange(0, 2*np.pi, 1e-3) 60 | x_arr = k[i] * np.cos(theta_arr) 61 | y_arr = k[i] * np.sin(theta_arr) + k[i] 62 | 63 | point_arr = [] 64 | for p_idx in range(len(x_arr)): 65 | point_arr.append(geom.Point(x_arr[p_idx], y_arr[p_idx])) 66 | gamma = geom.LineString(point_arr) 67 | 68 | # compute \int_0^{2\pi} dist((r\cos{\theta}, r\sin{\theta}), \gamma) r d\theta 69 | int_dist[i] = quad(integrand, 0, 2*np.pi, args=(r, gamma), limit=500)[0] 70 | 71 | # plot int_dist / (2 \pi r) / r. The additional r aligns all curves in one plot. 72 | plt.plot(1 / k, int_dist / (2 * np.pi * r * r), linewidth=3, label='r={}'.format(r)) 73 | 74 | plt.xlabel(r'$\kappa = \frac{1}{k_2}$', fontsize=26) 75 | plt.ylabel(r'$\frac{1}{r}\mathbb{E}_{S_r} dist(w, \gamma_2)$', fontsize=26) 76 | plt.xticks(fontsize= 20) 77 | plt.yticks([0.6, 0.61, 0.62, 0.63], fontsize=20) 78 | plt.legend(fontsize=20) 79 | plt.savefig('figures/curvature_displacement_integral_circle.pdf', dpi=400, bbox_inches='tight') 80 | -------------------------------------------------------------------------------- /gradient_descent_mlp_utils.py: -------------------------------------------------------------------------------- 1 | """ Functions for gradient descent and teleportations. """ 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from curvature_utils import W_list_to_vec, vec_to_W_list, compute_curvature 8 | 9 | def init_param_MLP(dim, seed=54321): 10 | # dim: list of dimensions of weight matrices. 
11 | # Example: [4, 5, 6, 7, 8] -> X: 5x4, W1:6x5, W2:7x6, W3:8x7, Y:8x4 12 | torch.manual_seed(seed) 13 | W_list = [] 14 | for i in range(len(dim) - 2): 15 | k = 1 / np.sqrt(dim[i+1]) # https://pytorch.org/docs/stable/generated/torch.nn.Linear.html 16 | W = 2 * k * torch.rand(dim[i+2], dim[i+1], requires_grad=True) - k 17 | W_list.append(W) 18 | return W_list 19 | 20 | def loss_multi_layer(W_list, X, Y, sigma): 21 | h = X 22 | for i in range(len(W_list)-1): 23 | h = sigma(torch.matmul(W_list[i], h)) 24 | pred = torch.matmul(W_list[-1], h) 25 | pred = F.log_softmax(pred, dim=0) 26 | return F.nll_loss(torch.t(pred), Y), pred 27 | 28 | def loss_MLP_from_vec(W_vec_all, X, Y, dim, sigma): 29 | W_list = vec_to_W_list(W_vec_all, dim) 30 | L, _ = loss_multi_layer(W_list, X, Y, sigma) 31 | return L 32 | 33 | def valid_MLP(model, criterion, valid_loader): 34 | model.eval() 35 | test_loss = 0.0 36 | test_correct = 0 37 | for data, target in valid_loader: 38 | batch_size = data.shape[0] 39 | data = torch.t(data.view(batch_size, -1)) 40 | output = model(data) 41 | L = criterion(output.T, target) 42 | test_loss += L.item()*data.size(1) 43 | 44 | _, pred = torch.max(output, 0) 45 | test_correct += pred.eq(target.data.view_as(pred)).sum().item() 46 | 47 | test_loss = test_loss / len(valid_loader.sampler) 48 | test_correct = 100.0 * test_correct / len(valid_loader.sampler) 49 | return test_loss, test_correct 50 | 51 | def train_step(x_train, y_train, model, criterion, optimizer): 52 | model.train() 53 | optimizer.zero_grad() 54 | output = model.forward(x_train) 55 | loss = criterion(output.T, y_train) 56 | loss.backward() 57 | optimizer.step() 58 | return loss 59 | 60 | def test_MLP(model, criterion, test_loader): 61 | model.eval() 62 | test_loss = 0.0 63 | class_correct = list(0. for i in range(10)) 64 | class_total = list(0. for i in range(10)) 65 | 66 | for data, target in test_loader: 67 | batch_size = data.shape[0] 68 | data = torch.t(data.view(batch_size, -1)) 69 | output = model(data) 70 | L = criterion(output.T, target) 71 | test_loss += L.item()*data.size(1) 72 | _, pred = torch.max(output, 0) 73 | correct = np.squeeze(pred.eq(target.data.view_as(pred))) 74 | for i in range(len(target)): 75 | label = target.data[i] 76 | class_correct[label] += correct[i].item() 77 | class_total[label] += 1 78 | 79 | test_loss = test_loss/len(test_loader.sampler) 80 | print('Test Loss: {:.6f}\n'.format(test_loss)) 81 | 82 | for i in range(10): 83 | print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % ( 84 | str(i), 100 * class_correct[i] / class_total[i], 85 | np.sum(class_correct[i]), np.sum(class_total[i]))) 86 | print('\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % ( 87 | 100. 
* np.sum(class_correct) / np.sum(class_total), 88 | np.sum(class_correct), np.sum(class_total))) 89 | return test_loss, np.sum(class_correct) / np.sum(class_total) 90 | 91 | 92 | ############################################################## 93 | # group actions 94 | def group_action(U, V, X, X_inv, T, sigma): 95 | # U, V -> U sigma(VX) sigma((I+T)VX)^+, (I+T)V 96 | 97 | k = list(T.size())[0] 98 | I = torch.eye(k) 99 | 100 | V_out = torch.matmul((I+T), V) 101 | Wh = torch.matmul(V, X) 102 | sigma_Wh = sigma(Wh) 103 | sigma_gWh = sigma(torch.matmul((I+T), Wh)) 104 | sigma_gWh_inv = torch.linalg.pinv(sigma_gWh) 105 | U_out = torch.matmul(torch.matmul(U, sigma_Wh), sigma_gWh_inv) 106 | return U_out, V_out 107 | 108 | def group_action_large(U, V, X, X_inv, g, g_inv, sigma): 109 | # U, V -> U sigma(VX) sigma(gVX)^+, gV 110 | 111 | k = list(g.size())[0] 112 | I = torch.eye(k) 113 | 114 | V_out = torch.matmul(g, V) 115 | Wh = torch.matmul(V, X) 116 | sigma_Wh = sigma(Wh) 117 | sigma_gWh = sigma(torch.matmul(g, Wh)) 118 | sigma_gWh_inv = torch.linalg.pinv(sigma_gWh) 119 | U_out = torch.matmul(torch.matmul(U, sigma_Wh), sigma_gWh_inv) 120 | return U_out, V_out 121 | 122 | def group_action_exp(t, U, V, X, X_inv, M, sigma): 123 | # U, V -> U sigma(VX) sigma(exp(tM)VX)^+, exp(tM)V 124 | 125 | g = torch.linalg.matrix_exp(t * M) 126 | g_inv = torch.linalg.pinv(g) 127 | 128 | V_out = torch.matmul(g, V) 129 | Wh = torch.matmul(V, X) 130 | sigma_Wh = sigma(Wh) 131 | sigma_gWh = sigma(torch.matmul(g, Wh)) 132 | sigma_gWh_inv = torch.linalg.pinv(sigma_gWh) 133 | U_out = torch.matmul(torch.matmul(U, sigma_Wh), sigma_gWh_inv) 134 | return U_out, V_out 135 | 136 | ############################################################## 137 | # first (or second) derivatives of the component of gamma corresponding to U (or V) 138 | def compute_gamma_1_U(t, U, V, h, h_inv, M, sigma): 139 | func = lambda t_: group_action_exp(t_, U, V, h, h_inv, M, sigma)[0] 140 | gamma_1 = torch.autograd.functional.jacobian(func, t, create_graph=True) 141 | gamma_1 = torch.squeeze(gamma_1) 142 | return gamma_1 143 | 144 | def compute_gamma_1_V(t, U, V, h, h_inv, M, sigma): 145 | func = lambda t_: group_action_exp(t_, U, V, h, h_inv, M, sigma)[1] 146 | gamma_1 = torch.autograd.functional.jacobian(func, t, create_graph=True) 147 | gamma_1 = torch.squeeze(gamma_1) 148 | return gamma_1 149 | 150 | def compute_gamma_2_U(t, U, V, h, h_inv, M, sigma): 151 | func = lambda t_: compute_gamma_1_U(t_, U, V, h, h_inv, M, sigma) 152 | gamma_2 = torch.autograd.functional.jacobian(func, t, create_graph=True) 153 | gamma_2 = torch.squeeze(gamma_2) 154 | return gamma_2 155 | 156 | def compute_gamma_2_V(t, U, V, h, h_inv, M, sigma): 157 | func = lambda t_: compute_gamma_1_V(t_, U, V, h, h_inv, M, sigma) 158 | gamma_2 = torch.autograd.functional.jacobian(func, t, create_graph=True) 159 | gamma_2 = torch.squeeze(gamma_2) 160 | return gamma_2 161 | 162 | ############################################################## 163 | # teleportation 164 | 165 | def teleport_curvature(W_list, X, Y, lr_teleport, dim, sigma, telestep=10, reverse=False): 166 | # reverse = True if minimizing curvature, False if maximizing curvature. 
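    # Overview: a random Lie-algebra element M defines a curve gamma(t) = exp(tM) . W through
    # the current weights of one layer pair; the curvature of gamma at t = 0 is computed with
    # autograd, and t is moved along d(kappa)/dt (descent if `reverse`, ascent otherwise)
    # before the group action is applied, so the curvature changes while the loss stays
    # (approximately) unchanged.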
167 | print("before teleport", loss_multi_layer(W_list, X, Y, sigma)[0]) 168 | 169 | X_inv = torch.linalg.pinv(X) 170 | h_list = [X] 171 | h_inv_list = [X_inv] 172 | for m in range(0, len(W_list)-2): 173 | h = sigma(torch.matmul(W_list[m], h_list[-1])) 174 | h_list.append(h) 175 | h_inv_list.append(torch.linalg.pinv(h)) 176 | 177 | for teleport_step in range(telestep): 178 | layer = 1 179 | t = torch.zeros(1, requires_grad=True) 180 | M = torch.rand(dim[layer+2], dim[layer+2], requires_grad=True) 181 | 182 | # compute curvature using autograd 183 | gamma_1_U = compute_gamma_1_U(t, W_list[1+layer], W_list[0+layer], h_list[0+layer], h_inv_list[0+layer], M, sigma) 184 | gamma_1_V = compute_gamma_1_V(t, W_list[1+layer], W_list[0+layer], h_list[0+layer], h_inv_list[0+layer], M, sigma) 185 | 186 | gamma_2_U = compute_gamma_2_U(t, W_list[1+layer], W_list[0+layer], h_list[0+layer], h_inv_list[0+layer], M, sigma) 187 | gamma_2_V = compute_gamma_2_V(t, W_list[1+layer], W_list[0+layer], h_list[0+layer], h_inv_list[0+layer], M, sigma) 188 | 189 | gamma_1_list = [] 190 | gamma_2_list = [] 191 | for m in range(0, len(W_list)): 192 | gamma_1_list.append(torch.zeros_like(W_list[m])) 193 | gamma_2_list.append(torch.zeros_like(W_list[m])) 194 | 195 | gamma_1_list[0+layer] = gamma_1_U 196 | gamma_1_list[1+layer] = gamma_1_V 197 | gamma_2_list[0+layer] = gamma_2_U 198 | gamma_2_list[1+layer] = gamma_2_V 199 | 200 | kappa = compute_curvature(gamma_1_list, gamma_2_list) # curvature 201 | kappa_1 = torch.autograd.grad(kappa, inputs=t, create_graph=True)[0] # derivative of curvature 202 | 203 | # gradient descent/ascent on t to decrease/increase curvature 204 | if reverse: 205 | t = t - lr_teleport * kappa_1 206 | else: 207 | t = t + lr_teleport * kappa_1 208 | print(kappa, kappa_1, t) 209 | 210 | # transform weights using the updated t 211 | g = torch.linalg.matrix_exp(t * M) 212 | g_inv = torch.linalg.pinv(g) 213 | W_list[1+layer], W_list[0+layer] = group_action_exp(t, W_list[1+layer], W_list[0+layer], h_list[0+layer], h_inv_list[0+layer], M, sigma) 214 | 215 | h_list = [X] 216 | h_inv_list = [X_inv] 217 | for m in range(0, len(W_list)-2): 218 | h = sigma(torch.matmul(W_list[m], h_list[-1])) 219 | h_list.append(h) 220 | h_inv_list.append(torch.linalg.pinv(h)) 221 | 222 | print("after teleport", loss_multi_layer(W_list, X, Y, sigma)[0]) 223 | 224 | return W_list 225 | 226 | 227 | def teleport_sharpness(W_list, X, Y, lr_teleport, dim, sigma, telestep=10, loss_perturb_cap=2.0, reverse=False, \ 228 | t_start=0.001, t_end=0.2, t_interval=0.001): 229 | # reverse = True if minimizing sharpness, False if maximizing sharpness. 
230 | 231 | X_inv = torch.linalg.pinv(X) 232 | h_list = [X] 233 | h_inv_list = [X_inv] 234 | for m in range(0, len(W_list)-2): 235 | h = sigma(torch.matmul(W_list[m], h_list[-1])) 236 | h_list.append(h) 237 | h_inv_list.append(torch.linalg.pinv(h)) 238 | 239 | for teleport_step in range(telestep): 240 | gW_list = W_list.copy() 241 | T = [] # list of elements of Lie algebras 242 | 243 | # initialize T[i] = 0 and g.W = (I+T).W 244 | for m in range(0, len(gW_list)-1): 245 | T.append(torch.zeros(dim[m+2], dim[m+2], requires_grad=True)) 246 | gW_list[m+1], gW_list[m] = group_action(gW_list[m+1], gW_list[m], h_list[m], h_inv_list[m], T[m], sigma) 247 | 248 | # compute sharpness (loss_perturb_mean) 249 | num_t = len(np.arange(0.1, 5.0, 0.5)) 250 | num_d = 100 251 | loss_perturb_mean = 0.0 252 | for (idx, t) in enumerate(np.arange(0.1, 1.0, 0.1)): 253 | for d_idx in range(num_d): 254 | W_vec_all = W_list_to_vec(gW_list) 255 | random_dir = torch.rand(W_vec_all.size()[0], requires_grad=True) 256 | random_dir = random_dir / torch.norm(random_dir) * t 257 | W_vec_all_perturb = W_vec_all + random_dir 258 | loss_perturb = loss_MLP_from_vec(W_vec_all_perturb, X, Y, dim, sigma) 259 | loss_perturb_mean += loss_perturb 260 | loss_perturb_mean = loss_perturb_mean / num_t / num_d 261 | print(teleport_step, loss_perturb_mean) 262 | if loss_perturb_mean > loss_perturb_cap: 263 | break 264 | 265 | # gradient descent/ascent on T to decrease/increase sharpness (loss_perturb_mean) 266 | dLdt_dT_list = torch.autograd.grad(loss_perturb_mean, inputs=T, create_graph=True) 267 | for i in range(len(T)): 268 | if reverse: 269 | T[i] = T[i] - lr_teleport * dLdt_dT_list[i] 270 | else: 271 | T[i] = T[i] + lr_teleport * dLdt_dT_list[i] 272 | 273 | # transform weights using the updated T 274 | for m in range(0, len(W_list)-1): 275 | W_list[m+1], W_list[m] = group_action(W_list[m+1], W_list[m], h_list[m], h_inv_list[m], T[m], sigma) 276 | 277 | # update the list of hidden representations h_list 278 | for m in range(1, len(h_list)): 279 | k = list(T[m-1].size())[0] 280 | I = torch.eye(k) 281 | h_list[m] = torch.matmul(I + T[m-1], h_list[m]) 282 | h_inv_list[m] = torch.matmul(h_inv_list[m], I - T[m-1]) 283 | 284 | return W_list 285 | 286 | 287 | def teleport(W_list, X, Y, lr_teleport, dim, sigma, telestep=10, dL_dt_cap=100, random_teleport=False, reverse=False): 288 | # teleportation to increase gradient norm 289 | 290 | # print("before teleport", loss_multi_layer(W_list, X, Y, sigma)[0]) 291 | X_inv = torch.linalg.pinv(X) 292 | h_list = [X] 293 | h_inv_list = [X_inv] 294 | for m in range(0, len(W_list)-2): 295 | h = sigma(torch.matmul(W_list[m], h_list[-1])) 296 | h_list.append(h) 297 | h_inv_list.append(torch.linalg.pinv(h)) 298 | 299 | if random_teleport == True: 300 | for m in range(0, len(W_list)-1): 301 | g = torch.rand(dim[m+2], dim[m+2]) 302 | g = g / torch.norm(g, p='fro', dim=None) * 0.01 + torch.eye(dim[m+2]) * 1e0 303 | g_inv = torch.linalg.pinv(g) 304 | W_list[m+1], W_list[m] = group_action_large(W_list[m+1], W_list[m], h_list[m], h_inv_list[m], g, g_inv, sigma) 305 | return W_list 306 | 307 | 308 | for teleport_step in range(telestep): 309 | # populate gW_list with T.W, where T=I 310 | gW_list = W_list.copy() 311 | T = [] 312 | for m in range(0, len(gW_list)-1): 313 | T.append(torch.zeros(dim[m+2], dim[m+2], requires_grad=True)) 314 | gW_list[m+1], gW_list[m] = group_action(gW_list[m+1], gW_list[m], h_list[m], h_inv_list[m], T[m], sigma) 315 | 316 | # compute L(T.W) and dL/d(T.W) 317 | L, _ = 
loss_multi_layer(gW_list, X, Y, sigma) 318 | dL_dW_list = torch.autograd.grad(L, inputs=gW_list, create_graph=True) 319 | 320 | # compute dL/dt=||dL/d(T.W)||^2 and d/dT dL/dt 321 | dL_dt = 0 322 | for i in range(len(gW_list)): 323 | dL_dt += torch.norm(dL_dW_list[i])**2 324 | 325 | if dL_dt.detach().numpy() > dL_dt_cap: 326 | break 327 | 328 | # gradient ascent step on T, in the direction of d/dT dL/dt 329 | dLdt_dT_list = torch.autograd.grad(dL_dt, inputs=T) 330 | for i in range(len(T)): 331 | if reverse: 332 | T[i] = T[i] - lr_teleport * dLdt_dT_list[i] 333 | else: 334 | T[i] = T[i] + lr_teleport * dLdt_dT_list[i] 335 | 336 | # replace original W's with T.W, using the new T's 337 | for m in range(0, len(W_list)-1): 338 | W_list[m+1], W_list[m] = group_action(W_list[m+1], W_list[m], h_list[m], h_inv_list[m], T[m], sigma) 339 | 340 | 341 | for m in range(1, len(h_list)): 342 | k = list(T[m-1].size())[0] 343 | I = torch.eye(k) 344 | h_list[m] = torch.matmul(I + T[m-1], h_list[m]) 345 | h_inv_list[m] = torch.matmul(h_inv_list[m], I - T[m-1]) 346 | 347 | # print("after teleport", loss_multi_layer(W_list, X, Y, sigma)[0]) 348 | 349 | return W_list 350 | -------------------------------------------------------------------------------- /init_directory.py: -------------------------------------------------------------------------------- 1 | """ Scripts for setting up directories. """ 2 | 3 | import os 4 | 5 | if not os.path.exists('data'): 6 | os.mkdir('data') 7 | 8 | if not os.path.exists('figures'): 9 | os.mkdir('figures') 10 | if not os.path.exists('figures/correlation'): 11 | os.mkdir('figures/correlation') 12 | if not os.path.exists('figures/generalization'): 13 | os.mkdir('figures/generalization') 14 | if not os.path.exists('figures/optimization'): 15 | os.mkdir('figures/optimization') 16 | 17 | if not os.path.exists('logs'): 18 | os.mkdir('logs') 19 | if not os.path.exists('logs/correlation'): 20 | os.mkdir('logs/correlation') 21 | if not os.path.exists('logs/generalization'): 22 | os.mkdir('logs/generalization') 23 | if not os.path.exists('logs/optimization'): 24 | os.mkdir('logs/optimization') 25 | 26 | dataset_list = ['MNIST', 'FashionMNIST', 'CIFAR10'] 27 | for dataset in dataset_list: 28 | if not os.path.exists('logs/correlation/{}'.format(dataset)): 29 | os.mkdir('logs/correlation/{}'.format(dataset)) 30 | if not os.path.exists('logs/generalization/{}'.format(dataset)): 31 | os.mkdir('logs/generalization/{}'.format(dataset)) 32 | if not os.path.exists('logs/optimization/{}'.format(dataset)): 33 | os.mkdir('logs/optimization/{}'.format(dataset)) 34 | 35 | if not os.path.exists('logs/generalization/{}/{}_SGD'.format(dataset, dataset)): 36 | os.mkdir('logs/generalization/{}/{}_SGD'.format(dataset, dataset)) 37 | if not os.path.exists('logs/generalization/{}/teleport_curvature'.format(dataset)): 38 | os.mkdir('logs/generalization/{}/teleport_curvature'.format(dataset)) 39 | if not os.path.exists('logs/generalization/{}/teleport_sharpness'.format(dataset)): 40 | os.mkdir('logs/generalization/{}/teleport_sharpness'.format(dataset)) 41 | 42 | if not os.path.exists('logs/correlation/{}/{}_final_W_lists'.format(dataset, dataset)): 43 | os.mkdir('logs/correlation/{}/{}_final_W_lists'.format(dataset, dataset)) -------------------------------------------------------------------------------- /learn-to-teleport/gradient_descent_mlp.py: -------------------------------------------------------------------------------- 1 | """ Meta-learning algorithms for training MLPs. 
""" 2 | 3 | import numpy as np 4 | import time 5 | import torch 6 | from torch import nn 7 | from torch.autograd import Variable 8 | from teleportation import teleport_MLP_random, teleport_MLP_gradient_ascent, teleport_MLP 9 | from lstm import LSTM_tele, LSTM_tele_lr, LSTM_local_update 10 | 11 | def detach_var(v): 12 | # make gradient an independent variable that is independent from the rest of the computational graph 13 | var = Variable(v.data, requires_grad=True) 14 | var.retain_grad() 15 | return var 16 | 17 | def init_param(dim, seed=12345): 18 | # dim: list of dimensions of weight matrices. 19 | # Example: [4, 5, 6, 7, 8] -> X: 5x4, W1:6x5, W2:7x6, W3:8x7, Y:8x4 20 | torch.manual_seed(seed) 21 | W_list = [] 22 | for i in range(len(dim) - 2): 23 | W_list.append(torch.rand(dim[i+2], dim[i+1], requires_grad=True)) 24 | X = torch.rand(dim[1], dim[0], requires_grad=True) 25 | Y = torch.rand(dim[-1], dim[0], requires_grad=True) 26 | return W_list, X, Y 27 | 28 | def W_list_to_vec(W_list): 29 | W_vec_all = torch.flatten(W_list[0]) 30 | for i in range(1, len(W_list)): 31 | W_vec = torch.flatten(W_list[i]) 32 | W_vec_all = torch.concat((W_vec_all, W_vec)) 33 | return W_vec_all 34 | 35 | def vec_to_W_list(W_vec_all, dim): 36 | W_list = [] 37 | start_idx = 0 38 | for i in range(len(dim)-2): 39 | end_idx = start_idx + dim[i+2]*dim[i+1] 40 | W_list.append(torch.reshape(W_vec_all[start_idx:end_idx], (dim[i+2], dim[i+1]))) 41 | start_idx = end_idx 42 | return W_list 43 | 44 | def loss_multi_layer(W_list, X, Y, sigma=nn.LeakyReLU(0.1)): 45 | h = X 46 | for i in range(len(W_list)-1): 47 | h = sigma(torch.matmul(W_list[i], h)) 48 | return 0.5 * torch.norm(Y - torch.matmul(W_list[-1], h)) ** 2 49 | 50 | def train_epoch_GD(W_list, X, Y, lr): 51 | L = loss_multi_layer(W_list, X, Y) 52 | dL_dW_list = torch.autograd.grad(L, inputs=W_list, retain_graph=True) 53 | dL_dt = 0 54 | for i in range(len(W_list)): 55 | W_list[i] = W_list[i] - lr * dL_dW_list[i] 56 | dL_dt += torch.norm(dL_dW_list[i])**2 57 | return W_list, L, dL_dt, dL_dW_list 58 | 59 | 60 | def train_GD(dim, n_run=5, n_epoch=300, K=[5], teleport=False, random=False, lr=1e-4, lr_teleport=1e-7, T_magnitude=1.0): 61 | """ Run gradient descent with or without teleportation. 62 | 63 | Args: 64 | dim: list of dimensions of weight matrices. Example: [4, 5, 6, 7, 8] -> 65 | X: 5x4, W1:6x5, W2:7x6, W3:8x7, Y:8x4 66 | n_run: number of runs using different random seeds for initialization. 67 | n_epoch: number of epochs of each run. 68 | K: teleportation schedule. Teleportation is performed at epochs in K if teleport=True. 69 | teleport: True if using teleportation, False otherwise. 70 | random: True if teleporting using random group element, False if using optimized group element. 71 | lr: learning rate for gradient descent. 72 | lr_teleport: learning rate for gradient ascent on group element during teleportation. 73 | T_magnitude: frobenius norm of elements in T after normalization. 74 | 75 | Returns: 76 | time_arr_SGD_n: Wall-clock time at each epoch. Dimension n_run x n_epoch. 77 | loss_arr_SGD_n: Loss after each epoch. Dimension n_run x n_epoch. 78 | dL_dt_arr_SGD_n: Squared gradient norm at each epoch. Dimension n_run x n_epoch. 
79 | """ 80 | 81 | time_arr_SGD_n = [] 82 | loss_arr_SGD_n = [] 83 | dL_dt_arr_SGD_n = [] 84 | 85 | for n in range(n_run): 86 | W_list, X, Y = init_param(dim, seed=n*n*12345) 87 | time_arr_SGD = [] 88 | loss_arr_SGD = [] 89 | dL_dt_arr_SGD = [] 90 | 91 | t0 = time.time() 92 | for epoch in range(n_epoch): 93 | if teleport == True and epoch in K: 94 | if random: 95 | W_list = teleport_MLP_random(W_list, X, T_magnitude, dim) 96 | else: 97 | W_list = teleport_MLP_gradient_ascent(W_list, X, Y, lr_teleport, dim, loss_multi_layer, 8) 98 | 99 | W_list, loss, dL_dt, _ = train_epoch_GD(W_list, X, Y, lr) 100 | t1 = time.time() 101 | time_arr_SGD.append(t1 - t0) 102 | loss_arr_SGD.append(loss.detach().numpy()) 103 | dL_dt_arr_SGD.append(dL_dt.detach().numpy()) 104 | 105 | time_arr_SGD_n.append(time_arr_SGD) 106 | loss_arr_SGD_n.append(loss_arr_SGD) 107 | dL_dt_arr_SGD_n.append(dL_dt_arr_SGD) 108 | 109 | return time_arr_SGD_n, loss_arr_SGD_n, dL_dt_arr_SGD_n 110 | 111 | 112 | def train_meta_opt(dim, n_run=20, n_epoch=20, unroll=5, lr=1e-4, lr_meta=1e-3, learn_lr=True, learn_tele=True, learn_update=True, T_magnitude=0.01): 113 | """ Run gradient descent with or without teleportation. 114 | 115 | Args: 116 | dim: list of dimensions of weight matrices. Example: [4, 5, 6, 7, 8] -> 117 | X: 5x4, W1:6x5, W2:7x6, W3:8x7, Y:8x4 118 | n_run: number of runs using different random seeds for initialization. 119 | n_epoch: number of epochs of each run. 120 | unroll: number of steps before each update of the parameters in the meta-optimizers 121 | lr: learning rate for gradient descent of the MLP parameters. 122 | lr_meta: learning rate for meta-optimizers. 123 | learn_lr: True if meta-optimizers learn lr, False otherwise. 124 | learn_tele: True if meta-optimizers learn teleportation, False if no teleportation is applied. 125 | learn_update: True if meta-optimizers learn local updates, False otherwise. 126 | T_magnitude: frobenius norm of elements in T after normalization. 127 | 128 | Returns: 129 | meta_opt_list: a meta_optimizer that outputs the group elements for teleportation 130 | meta_opt_update: a meta_optimizer that outputs the local update for MLP parameters. None if learn_update=False. 
131 | """ 132 | 133 | # initialize a list of meta_opt, one for each pair of weights (for teleportation) 134 | meta_opt_list = [] 135 | optimizer_list = [] 136 | n_meta_opt = len(dim) - 3 137 | 138 | for i in range(n_meta_opt): 139 | n_param = dim[i+1] * dim[i+2] + dim[i+2] * dim[i+3] 140 | if learn_update: 141 | meta_opt = LSTM_tele(n_param, 300, dim[i+2]) 142 | else: 143 | if learn_lr==False: 144 | meta_opt = LSTM_tele(n_param, 20, dim[i+2]) 145 | else: 146 | meta_opt = LSTM_tele_lr(n_param, 20, dim[i+2]) 147 | 148 | meta_opt_list.append(meta_opt) 149 | optimizer = torch.optim.Adam(meta_opt_list[i].parameters(), lr=1e-4) 150 | optimizer_list.append(optimizer) 151 | 152 | meta_opt_list[i].train() 153 | 154 | # initialize a meta_opt for all weights (for update step) 155 | if learn_update: 156 | W_list, _, _ = init_param(dim, seed=12345) 157 | n_param = W_list_to_vec(W_list).shape[0] 158 | meta_opt_update = LSTM_local_update(n_param, 300, n_param) 159 | optimizer_update = torch.optim.Adam(meta_opt_update.parameters(), lr=lr_meta) 160 | else: 161 | meta_opt_update = None 162 | optimizer_update = None 163 | 164 | # for each of the n_run training trajectories 165 | for n in range(n_run): 166 | if n % 100 == 0: 167 | print("run", n) 168 | W_list, X, Y = init_param(dim, seed=n*n*12345-1) 169 | X_inv = torch.linalg.pinv(X) 170 | loss_sum = None 171 | loss_sum_all = 0.0 172 | 173 | # initialize LSTM hidden and cell 174 | hidden = [] 175 | cell = [] 176 | for i in range(n_meta_opt): 177 | cell.append(Variable(torch.zeros(2, 1, meta_opt.lstm_hidden_dim), requires_grad=True)) 178 | hidden.append(Variable(torch.zeros(2, 1, meta_opt.lstm_hidden_dim), requires_grad=True)) 179 | 180 | if learn_update: # learn local updates 181 | cell_update = Variable(torch.zeros(2, 1, meta_opt_update.lstm_hidden_dim), requires_grad=True) 182 | hidden_update = Variable(torch.zeros(2, 1, meta_opt_update.lstm_hidden_dim), requires_grad=True) 183 | 184 | for epoch in range(n_epoch): 185 | # compute loss gradients, compute local updates from meta optimizer, perform local updates for MLP parameters 186 | loss = loss_multi_layer(W_list, X, Y) 187 | dL_dW_list = torch.autograd.grad(loss, inputs=W_list, retain_graph=True) 188 | W_update, hidden_update, cell_update = meta_opt_update(W_list_to_vec(dL_dW_list), hidden_update, cell_update) 189 | W_list = vec_to_W_list(W_list_to_vec(W_list) - W_update, dim) 190 | 191 | loss_sum_all += loss.data 192 | if loss_sum is None: 193 | loss_sum = loss 194 | else: 195 | loss_sum += loss 196 | 197 | # update meta optimizers 198 | if epoch % unroll == 0 and epoch != 0: 199 | for i in range(n_meta_opt): 200 | optimizer_list[i].zero_grad() 201 | optimizer_update.zero_grad() 202 | loss_sum.backward(retain_graph=True) 203 | 204 | for i in range(n_meta_opt): 205 | optimizer_list[i].step() 206 | optimizer_update.step() 207 | 208 | loss_sum = None 209 | hidden = [detach_var(v) for v in hidden] 210 | cell = [detach_var(v) for v in cell] 211 | hidden_update = detach_var(hidden_update) 212 | cell_update = detach_var(cell_update) 213 | W_list = [detach_var(v) for v in W_list] 214 | 215 | # compute group elements from meta optimizers and teleport MLP parameters 216 | if learn_tele == True: 217 | g_list = [] 218 | for i in range(n_meta_opt): 219 | g, hidden[i], cell[i] = meta_opt_list[i](dL_dW_list[i], dL_dW_list[i+1], hidden[i], cell[i]) 220 | g_list.append(g) 221 | W_list = teleport_MLP(W_list, X, X_inv, g_list, using_T=True, T_magnitude=T_magnitude) 222 | 223 | else: # does not learn local updates 224 | 
for epoch in range(n_epoch): 225 | # one gradient descent step on MLP parameters, using learned learning rate if learn_lr=True 226 | if learn_lr==False or epoch == 0: 227 | W_list, loss, dL_dt, dL_dW_list = train_epoch_GD(W_list, X, Y, lr) 228 | else: 229 | learned_lr = torch.mean(torch.stack(step_size_list), dim=0) 230 | W_list, loss, dL_dt, dL_dW_list = train_epoch_GD(W_list, X, Y, learned_lr) 231 | 232 | loss_sum_all += loss.data 233 | if loss_sum is None: 234 | loss_sum = loss 235 | else: 236 | loss_sum += loss 237 | 238 | # update meta optimizers 239 | if epoch % unroll == 0 and epoch != 0: 240 | for i in range(n_meta_opt): 241 | optimizer_list[i].zero_grad() 242 | loss_sum.backward(retain_graph=True) 243 | 244 | for i in range(n_meta_opt): 245 | optimizer_list[i].step() 246 | 247 | loss_sum = None 248 | hidden = [detach_var(v) for v in hidden] 249 | cell = [detach_var(v) for v in cell] 250 | W_list = [detach_var(v) for v in W_list] 251 | 252 | # compute group elements from meta optimizers and teleport MLP parameters 253 | g_list = [] 254 | step_size_list = [] 255 | for i in range(n_meta_opt): 256 | if learn_lr == True: 257 | g, step_size, hidden[i], cell[i] = meta_opt_list[i](dL_dW_list[i], dL_dW_list[i+1], hidden[i], cell[i]) 258 | else: 259 | g, hidden[i], cell[i] = meta_opt_list[i](dL_dW_list[i], dL_dW_list[i+1], hidden[i], cell[i]) 260 | step_size = None 261 | g_list.append(g) 262 | step_size_list.append(step_size) 263 | 264 | W_list = teleport_MLP(W_list, X, X_inv, g_list, using_T=True, T_magnitude=0.01) 265 | 266 | return meta_opt_list, meta_opt_update 267 | 268 | 269 | def test_meta_opt(meta_opt_list, meta_opt_update, dim, n_run=5, n_epoch=300, lr=1e-4, learn_lr=False, learn_tele=True, learn_update=True, T_magnitude=0.01): 270 | """ Run gradient descent with or without teleportation. 271 | 272 | Args: 273 | meta_opt_list: a meta_optimizer that outputs the group elements for teleportation 274 | meta_opt_update: a meta_optimizer that outputs the local update for MLP parameters. None if learn_update=False. 275 | dim: list of dimensions of weight matrices. Example: [4, 5, 6, 7, 8] -> 276 | X: 5x4, W1:6x5, W2:7x6, W3:8x7, Y:8x4 277 | n_run: number of runs using different random seeds for initialization. 278 | n_epoch: number of epochs of each run. 279 | lr: learning rate for gradient descent of the MLP parameters. 280 | learn_lr: True if meta-optimizers learn lr, False otherwise. 281 | learn_tele: True if meta-optimizers learn teleportation, False if no teleportation is applied. 282 | learn_update: True if meta-optimizers learn local updates, False otherwise. 283 | T_magnitude: frobenius norm of elements in T after normalization. 284 | 285 | Returns: 286 | time_arr_teleport_n: Wall-clock time at each epoch. Dimension n_run x n_epoch. 287 | loss_arr_teleport_n: Loss after each epoch. Dimension n_run x n_epoch. 288 | dL_dt_arr_teleport_n: Squared gradient norm at each epoch. Dimension n_run x n_epoch. 289 | lr_arr_teleport_n: Learning rate for MLP parameters at each epoch. Dimension n_run x n_epoch. 
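        Example (hypothetical, smaller-scale version of the calls in run_mlp_regression.py):
            >>> dim = [20, 20, 20, 20]
            >>> opts, opt_update = train_meta_opt(dim, n_run=30, n_epoch=100, unroll=10)
            >>> t, losses, grads, lrs = test_meta_opt(opts, opt_update, dim, n_run=5, n_epoch=100)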
290 | """ 291 | 292 | time_arr_teleport_n = [] 293 | loss_arr_teleport_n = [] 294 | dL_dt_arr_teleport_n = [] 295 | lr_arr_teleport_n = [] 296 | 297 | n_meta_opt = len(dim) - 3 298 | for i in range(n_meta_opt): 299 | meta_opt_list[i].eval() # list of meta_opt, one for each pair of weights (for teleportation) 300 | 301 | if learn_update: 302 | meta_opt_update.eval() # meta_opt for all weights (for update step) 303 | 304 | for n in range(n_run): 305 | W_list, X, Y = init_param(dim, seed=n*n*12345) 306 | X_inv = torch.linalg.pinv(X) 307 | 308 | # initialize LSTM hidden and cell 309 | hidden = [] 310 | cell = [] 311 | for i in range(n_meta_opt): 312 | cell.append(Variable(torch.zeros(2, 1, meta_opt_list[0].lstm_hidden_dim), requires_grad=True)) 313 | hidden.append(Variable(torch.zeros(2, 1, meta_opt_list[0].lstm_hidden_dim), requires_grad=True)) 314 | 315 | if learn_update: 316 | cell_update = Variable(torch.zeros(2, 1, meta_opt_update.lstm_hidden_dim), requires_grad=True) 317 | hidden_update = Variable(torch.zeros(2, 1, meta_opt_update.lstm_hidden_dim), requires_grad=True) 318 | 319 | time_arr_teleport = [] 320 | loss_arr_teleport = [] 321 | dL_dt_arr_teleport = [] 322 | lr_arr_teleport = [] 323 | 324 | t0 = time.time() 325 | for epoch in range(n_epoch): 326 | if learn_update: # learn local updates 327 | # compute loss gradients, compute local updates from meta optimizer, perform local updates for MLP parameters 328 | loss = loss_multi_layer(W_list, X, Y) 329 | dL_dW_list = torch.autograd.grad(loss, inputs=W_list, retain_graph=True) 330 | W_update, hidden_update, cell_update = meta_opt_update(W_list_to_vec(dL_dW_list), hidden_update, cell_update) 331 | W_list = vec_to_W_list(W_list_to_vec(W_list) - W_update, dim) 332 | dL_dt = torch.norm(W_list_to_vec(W_list))**2 333 | 334 | # compute group elements from meta optimizers and teleport MLP parameters 335 | g_list = [] 336 | step_size_list = [] 337 | for i in range(n_meta_opt): 338 | g, hidden[i], cell[i] = meta_opt_list[i](dL_dW_list[i], dL_dW_list[i+1], hidden[i], cell[i]) 339 | step_size = None 340 | g_list.append(g) 341 | step_size_list.append(step_size) 342 | if learn_tele == True: 343 | W_list = teleport_MLP(W_list, X, X_inv, g_list, using_T=True, T_magnitude=T_magnitude) 344 | 345 | else: # does not learn local updates 346 | # one gradient descent step on MLP parameters, using learned learning rate if learn_lr=True 347 | if learn_lr==False or epoch == 0: 348 | W_list, loss, dL_dt, dL_dW_list = train_epoch_GD(W_list, X, Y, lr) 349 | else: 350 | learned_lr = torch.mean(torch.stack(step_size_list), dim=0) 351 | W_list, loss, dL_dt, dL_dW_list = train_epoch_GD(W_list, X, Y, learned_lr) 352 | lr_arr_teleport.append(learned_lr.detach().numpy()) 353 | 354 | # compute group elements from meta optimizers and teleport MLP parameters 355 | g_list = [] 356 | step_size_list = [] 357 | for i in range(n_meta_opt): 358 | if learn_lr == True: 359 | g, step_size, hidden[i], cell[i] = meta_opt_list[i](dL_dW_list[i], dL_dW_list[i+1], hidden[i], cell[i]) 360 | else: 361 | g, hidden[i], cell[i] = meta_opt_list[i](dL_dW_list[i], dL_dW_list[i+1], hidden[i], cell[i]) 362 | step_size = None 363 | g_list.append(g) 364 | step_size_list.append(step_size) 365 | 366 | teleport_MLP(W_list, X, X_inv, g_list, using_T=True, T_magnitude=0.01) 367 | 368 | t1 = time.time() 369 | time_arr_teleport.append(t1 - t0) 370 | loss_arr_teleport.append(loss.detach().numpy()) 371 | dL_dt_arr_teleport.append(dL_dt.detach().numpy()) 372 | 373 | time_arr_teleport_n.append(time_arr_teleport) 
374 | loss_arr_teleport_n.append(loss_arr_teleport) 375 | dL_dt_arr_teleport_n.append(dL_dt_arr_teleport) 376 | lr_arr_teleport_n.append(lr_arr_teleport) 377 | 378 | return time_arr_teleport_n, loss_arr_teleport_n, dL_dt_arr_teleport_n, lr_arr_teleport_n 379 | -------------------------------------------------------------------------------- /learn-to-teleport/lstm.py: -------------------------------------------------------------------------------- 1 | """ LSTM models. """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class LSTM_tele(nn.Module): 8 | """ 9 | LSTM model that takes in the gradient of two layers. 10 | Returns a group element for teleportation. 11 | """ 12 | def __init__(self, input_dim, lstm_hidden_dim, out_dim, preproc=False): 13 | super(LSTM_tele, self).__init__() 14 | 15 | self.input_dim = input_dim 16 | self.out_dim = out_dim 17 | self.lstm_hidden_dim = lstm_hidden_dim 18 | 19 | self.lstm1 = nn.LSTMCell(input_dim, lstm_hidden_dim) 20 | self.lstm2 = nn.LSTMCell(lstm_hidden_dim, lstm_hidden_dim) 21 | self.linear = nn.Linear(lstm_hidden_dim, out_dim*out_dim) 22 | 23 | 24 | def forward(self, dL_dU, dL_dV, hidden, cell): 25 | grad_flatten = torch.cat((torch.flatten(dL_dU), torch.flatten(dL_dV)), 0) 26 | grad_flatten = grad_flatten[None, :] 27 | 28 | h0, c0 = self.lstm1(grad_flatten, (hidden[0], cell[0])) 29 | h1, c1 = self.lstm2(h0, (hidden[1], cell[1])) 30 | g = self.linear(h1) 31 | g = torch.reshape(g, (self.out_dim, self.out_dim)) 32 | return g, torch.stack((h0, h1)), torch.stack((c0, c1)) 33 | 34 | 35 | class LSTM_tele_lr(nn.Module): 36 | """ 37 | LSTM model that takes in the gradient of two layers. 38 | Returns a group element for teleportation and a step size for the next gradient descent step. 39 | """ 40 | def __init__(self, input_dim, lstm_hidden_dim, out_dim, preproc=False): 41 | super(LSTM_tele_lr, self).__init__() 42 | 43 | self.input_dim = input_dim 44 | self.out_dim = out_dim 45 | self.lstm_hidden_dim = lstm_hidden_dim 46 | 47 | self.lstm1 = nn.LSTMCell(input_dim, lstm_hidden_dim) 48 | self.lstm2 = nn.LSTMCell(lstm_hidden_dim, lstm_hidden_dim) 49 | self.linear1 = nn.Linear(lstm_hidden_dim, out_dim*out_dim) 50 | self.linear2 = nn.Linear(lstm_hidden_dim, 1) 51 | 52 | 53 | def forward(self, dL_dU, dL_dV, hidden, cell): 54 | grad_flatten = torch.cat((torch.flatten(dL_dU), torch.flatten(dL_dV)), 0) 55 | grad_flatten = grad_flatten[None, :] 56 | 57 | h0, c0 = self.lstm1(grad_flatten, (hidden[0], cell[0])) 58 | h1, c1 = self.lstm2(h0, (hidden[1], cell[1])) 59 | g = self.linear1(h1) 60 | g = torch.reshape(g, (self.out_dim, self.out_dim)) 61 | 62 | step_size = torch.clamp(self.linear2(h1), min=1e-7, max=5e-3) 63 | 64 | return g, step_size, torch.stack((h0, h1)), torch.stack((c0, c1)) 65 | 66 | 67 | class LSTM_local_update(nn.Module): 68 | """ 69 | LSTM model that takes in the gradient of all weights and returns the local update. 70 | input_dim and output_dim are expected to be the same. 
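    The predicted update is clamped to [-1e7, 1e7] and is subtracted directly from the
    flattened weight vector by the caller (see train_meta_opt / test_meta_opt).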
71 | """ 72 | def __init__(self, input_dim, lstm_hidden_dim, out_dim, preproc=False): 73 | super(LSTM_local_update, self).__init__() 74 | 75 | self.input_dim = input_dim 76 | self.out_dim = out_dim 77 | self.lstm_hidden_dim = lstm_hidden_dim 78 | 79 | self.lstm1 = nn.LSTMCell(input_dim, lstm_hidden_dim) 80 | self.lstm2 = nn.LSTMCell(lstm_hidden_dim, lstm_hidden_dim) 81 | self.linear = nn.Linear(lstm_hidden_dim, out_dim) 82 | 83 | 84 | def forward(self, grad_flatten, hidden, cell): 85 | grad_flatten = grad_flatten[None, :] 86 | 87 | h0, c0 = self.lstm1(grad_flatten, (hidden[0], cell[0])) 88 | h1, c1 = self.lstm2(h0, (hidden[1], cell[1])) 89 | update = self.linear(h1) 90 | update = torch.squeeze(update) 91 | update = torch.clamp(update, min=-1e7, max=1e7) 92 | 93 | return update, torch.stack((h0, h1)), torch.stack((c0, c1)) 94 | -------------------------------------------------------------------------------- /learn-to-teleport/plot.py: -------------------------------------------------------------------------------- 1 | """ Helper functions for figures. """ 2 | 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | import os 6 | 7 | 8 | def plot_all(time_arr_list, loss_arr_list, dL_dt_arr_list, label_list, n_epoch=300): 9 | if not os.path.exists('figures'): 10 | os.mkdir('figures') 11 | 12 | # compute mean and std across multiple runs 13 | loss_mean_list = [] 14 | loss_std_list = [] 15 | time_mean_list = [] 16 | time_std_list = [] 17 | for i in range(len(time_arr_list)): 18 | time_arr_list[i] = np.array(time_arr_list[i])[:, :n_epoch] 19 | loss_arr_list[i] = np.array(loss_arr_list[i])[:, :n_epoch] 20 | dL_dt_arr_list[i] = np.array(dL_dt_arr_list[i])[:, :n_epoch] 21 | loss_mean_list.append(np.mean(loss_arr_list[i], axis=0)) 22 | loss_std_list.append(np.std(loss_arr_list[i], axis=0)) 23 | time_mean_list.append(np.mean(time_arr_list[i], axis=0)) 24 | time_std_list.append(np.std(time_arr_list[i], axis=0)) 25 | 26 | # plot loss vs epoch 27 | plt.figure() 28 | for i in range(len(time_arr_list)): 29 | plt.plot(loss_mean_list[i], linewidth=3, label=label_list[i]) 30 | plt.gca().set_prop_cycle(None) 31 | for i in range(len(time_arr_list)): 32 | plt.fill_between(np.arange(n_epoch), loss_mean_list[i]-loss_std_list[i], loss_mean_list[i]+loss_std_list[i], alpha=0.5) 33 | plt.xlabel('Epoch', fontsize=26) 34 | plt.ylabel('Loss', fontsize=26) 35 | plt.yscale('log') 36 | plt.xticks([0, 10, 20, 30], fontsize= 20) 37 | plt.yticks(fontsize= 20) 38 | plt.legend(fontsize=17) 39 | plt.savefig('figures/multi_layer_loss.pdf', bbox_inches='tight') 40 | 41 | # plot loss vs wall-clock time 42 | plt.figure() 43 | for i in range(len(time_arr_list)): 44 | plt.plot(time_mean_list[i], loss_mean_list[i], linewidth=3, label=label_list[i]) 45 | plt.fill_between(time_mean_list[i], loss_mean_list[i]-loss_std_list[i], loss_mean_list[i]+loss_std_list[i], alpha=0.5) 46 | plt.xlabel('time (s)', fontsize=26) 47 | plt.ylabel('Loss', fontsize=26) 48 | max_t = np.max(time_arr_list[0]) 49 | interval = np.round(max_t * 0.3, 2) 50 | plt.xticks([0, interval, interval * 2, interval * 3], fontsize= 20) 51 | plt.yticks(fontsize= 20) 52 | plt.yscale('log') 53 | plt.legend(fontsize=17) 54 | plt.savefig('figures/multi_layer_loss_vs_time.pdf', bbox_inches='tight') 55 | 56 | # plot loss vs dL/dt 57 | plt.figure() 58 | for i in range(len(time_arr_list)): 59 | plt.plot(loss_arr_list[i][-1], dL_dt_arr_list[i][-1], linewidth=3, label=label_list[i]) 60 | plt.xlabel('Loss', fontsize=26) 61 | plt.ylabel('dL/dt', fontsize=26) 62 | 
plt.yscale('log') 63 | plt.xscale('log') 64 | plt.xticks(fontsize= 20) 65 | plt.yticks([1e1, 1e3, 1e5, 1e7], fontsize= 20) 66 | plt.legend(fontsize=17) 67 | plt.savefig('figures/multi_layer_loss_vs_gradient.pdf', bbox_inches='tight') 68 | 69 | return 70 | -------------------------------------------------------------------------------- /learn-to-teleport/run_mlp_regression.py: -------------------------------------------------------------------------------- 1 | """ Scripts for training and evaluating various meta-learning algorithms for MLP. """ 2 | 3 | import numpy as np 4 | import pickle 5 | import torch 6 | from torch import nn 7 | 8 | from gradient_descent_mlp import train_GD, train_meta_opt, test_meta_opt 9 | from plot import plot_all 10 | 11 | sigma = nn.LeakyReLU(0.1) 12 | sigma_inv = nn.LeakyReLU(10) 13 | 14 | dim = [20, 20, 20, 20] 15 | 16 | # do some random things first so that the wall-clock time comparison is fair 17 | train_GD(dim, n_run=1, n_epoch=20, lr=3e-4) 18 | 19 | # training with GD 20 | epoch = 300 21 | lr = 1e-4 22 | time_arr_SGD_n, loss_arr_SGD_n, dL_dt_arr_SGD_n = \ 23 | train_GD(dim, n_run=5, n_epoch=epoch, lr=3e-4) 24 | 25 | # train an lstm that learns teleportation + lr 26 | meta_opt_list, _ = train_meta_opt(dim, n_run=30, n_epoch=epoch, unroll=10, lr=lr, lr_meta=1e-3, learn_lr=True, learn_update=False, T_magnitude=0.01) 27 | time_arr_teleport_lstm_lr_n, loss_arr_teleport_lstm_lr_n, dL_dt_arr_teleport_lstm_lr_n, lr_arr_teleport_lstm_lr_n = \ 28 | test_meta_opt(meta_opt_list, None, dim, n_run=5, n_epoch=epoch, learn_lr=True, learn_update=False, T_magnitude=0.01) 29 | 30 | # train an lstm that learns local update + teleportation 31 | meta_opt_list, meta_opt_update = train_meta_opt(dim, n_run=700, n_epoch=100, unroll=10, lr=lr, lr_meta=1e-3, learn_tele=True, T_magnitude=0.01) 32 | time_arr_lstm_update_tele_n, loss_arr_lstm_update_tele_n, dL_dt_arr_lstm_update_tele_n, lr_arr_lstm_update_tele_n = \ 33 | test_meta_opt(meta_opt_list, meta_opt_update, dim, n_run=5, n_epoch=100, learn_tele=True, T_magnitude=0.01) 34 | 35 | # train an lstm that learns local update only 36 | meta_opt_list, meta_opt_update = train_meta_opt(dim, n_run=600, n_epoch=100, unroll=10, lr=lr, lr_meta=1e-3, learn_tele=False) 37 | time_arr_lstm_update_n, loss_arr_lstm_update_n, dL_dt_arr_lstm_update_n, lr_arr_lstm_update_n = \ 38 | test_meta_opt(meta_opt_list, meta_opt_update, dim, n_run=5, n_epoch=100, learn_tele=False) 39 | 40 | 41 | # save test results 42 | results = (time_arr_SGD_n, loss_arr_SGD_n, dL_dt_arr_SGD_n, \ 43 | time_arr_teleport_lstm_lr_n, loss_arr_teleport_lstm_lr_n, dL_dt_arr_teleport_lstm_lr_n, lr_arr_teleport_lstm_lr_n, \ 44 | time_arr_lstm_update_tele_n, loss_arr_lstm_update_tele_n, dL_dt_arr_lstm_update_tele_n, lr_arr_lstm_update_tele_n, \ 45 | time_arr_lstm_update_n, loss_arr_lstm_update_n, dL_dt_arr_lstm_update_n, lr_arr_lstm_update_n) 46 | with open('results.pkl', 'wb') as f: 47 | pickle.dump(results, f) 48 | 49 | 50 | # load and plot test results 51 | with open('results.pkl', 'rb') as f: 52 | (time_arr_SGD_n, loss_arr_SGD_n, dL_dt_arr_SGD_n, \ 53 | time_arr_teleport_lstm_lr_n, loss_arr_teleport_lstm_lr_n, dL_dt_arr_teleport_lstm_lr_n, lr_arr_teleport_lstm_lr_n, \ 54 | time_arr_lstm_update_tele_n, loss_arr_lstm_update_tele_n, dL_dt_arr_lstm_update_tele_n, lr_arr_lstm_update_tele_n, \ 55 | time_arr_lstm_update_n, loss_arr_lstm_update_n, dL_dt_arr_lstm_update_n, lr_arr_lstm_update_n) = pickle.load(f) 56 | 57 | plot_all([time_arr_SGD_n, time_arr_teleport_lstm_lr_n, 
time_arr_lstm_update_n, time_arr_lstm_update_tele_n], \ 58 | [loss_arr_SGD_n, loss_arr_teleport_lstm_lr_n, loss_arr_lstm_update_n, loss_arr_lstm_update_tele_n], \ 59 | [dL_dt_arr_SGD_n, dL_dt_arr_teleport_lstm_lr_n, dL_dt_arr_lstm_update_n, dL_dt_arr_lstm_update_tele_n], \ 60 | ['GD', 'LSTM(lr,tele)', 'LSTM(update)', 'LSTM(update,tele)'], n_epoch=30) 61 | -------------------------------------------------------------------------------- /learn-to-teleport/teleportation.py: -------------------------------------------------------------------------------- 1 | """ Group actions and teleportation algorithms for MLP. """ 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | sigma = nn.LeakyReLU(0.1) 8 | sigma_inv = nn.LeakyReLU(10) 9 | 10 | def group_action_MLP(U, V, X, X_inv, T, using_T=True, sigma=nn.LeakyReLU(0.1), sigma_inv=nn.LeakyReLU(10)): 11 | """GL(R) group actions on a pair of matrices. 12 | 13 | Performs the group action in equation (8) in https://arxiv.org/pdf/2205.10637.pdf. 14 | U = W_m, V = W_{m-1}, X = h_{m-2} 15 | 16 | Args: 17 | U: Matrix with dimension m x k. Weight acting on sigma(VX) 18 | V: Matrix with dimension k x n. Weight acting on X. 19 | X: Matrix with dimension m x n. Output from the previous layer. 20 | X_inv: Matrix with dimension n x m. Inverse of X. 21 | T: Matrix with dimension k x k. Element in the Lie algebra of GL_k(R) 22 | using_T: If True, U_out = U (I-T), V_out = sigma^{-1}((I+T)sigma(VX)) X^{-1} 23 | If False, U_out = U T^{-1}, V_out = sigma^{-1}(T sigma(VX)) X^{-1} 24 | sigma: Element-wise activation function. 25 | sigma_inv: Inverse of sigma. 26 | 27 | Returns: 28 | U_out: Result of g acting on U. Same dimension as U. 29 | V_out: Result of g acting on V. Same dimension as V. 30 | """ 31 | 32 | k = list(T.size())[0] 33 | I = torch.eye(k) 34 | if using_T: 35 | U_out = torch.matmul(U, (I-T)) 36 | V_out = sigma(torch.matmul(V, X)) 37 | V_out = torch.matmul((I+T), V_out) 38 | V_out = sigma_inv(V_out) 39 | V_out = torch.matmul(V_out, X_inv) 40 | else: 41 | T_inv = torch.linalg.pinv(T) 42 | 43 | U_out = torch.matmul(U, T_inv) 44 | V_out = sigma(torch.matmul(V, X)) 45 | V_out = torch.matmul(T, V_out) 46 | V_out = sigma_inv(V_out) 47 | V_out = torch.matmul(V_out, X_inv) 48 | return U_out, V_out 49 | 50 | def teleport_MLP(W_list, X, X_inv, T, using_T=True, T_magnitude=None): 51 | """ GL(R) group actions on all layers in an MLP. 52 | 53 | Args: 54 | W_list: list of weight matrices. 55 | X: Data matrix, with dimension a x b. 56 | X_inv: Matrix with dimension n x m. Inverse of X. 57 | T: list of Lie algebra elements used to transform the weight matrices 58 | T_magnitude: frobenius norm of elements in T 59 | 60 | Returns: 61 | W_list: Teleported weights. Same shapes as the input W_list. 62 | """ 63 | 64 | # Normalize T's to the specified magnitude 65 | if T_magnitude != None: 66 | for m in range(0, len(W_list)-1): 67 | T[m] = T[m] / torch.norm(T[m], p='fro', dim=None) * T_magnitude 68 | 69 | h = X 70 | h_inv = X_inv 71 | h_inv_list = [h_inv] 72 | for m in range(0, len(W_list)-1): 73 | W_list[m+1], W_list[m] = group_action_MLP(W_list[m+1], W_list[m], h, h_inv, T[m], using_T) 74 | h = sigma(torch.matmul(W_list[m], h)) 75 | h_inv = torch.linalg.pinv(h) 76 | h_inv_list.append(h_inv) 77 | 78 | return W_list 79 | 80 | def teleport_MLP_random(W_list, X, magnitude, dim): 81 | """ Teleportation using random T's with specified magnitude. 82 | 83 | Args: 84 | W_list: list of weight matrices. 85 | X: Data matrix, with dimension a x b. 
86 | magnitude: frobenius norm of elements in T 87 | dim: list of dimensions of weight matrices. Example: [4, 5, 6, 7, 8] -> 88 | X: 5x4, W1:6x5, W2:7x6, W3:8x7, Y:8x4 89 | 90 | Returns: 91 | W_list: Teleported weights. Same shapes as the input W_list. 92 | """ 93 | X_inv = torch.linalg.pinv(X) 94 | T = [] 95 | for m in range(0, len(W_list)-1): 96 | T_m = torch.rand(dim[m+2], dim[m+2], requires_grad=True) 97 | T_m = T_m / torch.norm(T_m, p='fro', dim=None) * magnitude + torch.eye(dim[m+2]) * 1e0 98 | T.append(T_m) 99 | W_list = teleport_MLP(W_list, X, X_inv, T, using_T=False) 100 | return W_list 101 | 102 | def teleport_MLP_gradient_ascent(W_list, X, Y, lr_teleport, dim, loss_func, step=10, sigma=nn.LeakyReLU(0.1)): 103 | """Teleportation on weight matrices in a multi-layer neural network, using gradient ascent. 104 | 105 | Args: 106 | W_list: list of weight matrices. 107 | X: Data matrix, with dimension a x b. 108 | Y: Label matrix, with dimension c x b. 109 | lr_teleport: A scalar. Learning rate used in optimizing the group element. 110 | dim: list of dimensions of weight matrices. Example: [4, 5, 6, 7, 8] -> 111 | X: 5x4, W1:6x5, W2:7x6, W3:8x7, Y:8x4 112 | loss_func: Loss function in the optimization problem. 113 | step: An integer. Number of gradient ascent steps used to optimize the group 114 | element. 115 | sigma: Element-wise activation function. 116 | 117 | Returns: 118 | W_list: Teleported weights. Same shapes as the input W_list. 119 | """ 120 | 121 | X_inv = torch.linalg.pinv(X) 122 | 123 | for teleport_step in range(step): 124 | # populate gW_list with T.W, where T=I 125 | gW_list = W_list.copy() 126 | T = [] 127 | h = X 128 | h_inv = X_inv 129 | for m in range(0, len(gW_list)-1): 130 | T.append(torch.zeros(dim[m+2], dim[m+2], requires_grad=True)) 131 | gW_list[m+1], gW_list[m] = group_action_MLP(gW_list[m+1], gW_list[m], h, h_inv, T[m]) 132 | h = sigma(torch.matmul(gW_list[m], h)) 133 | h_inv = torch.linalg.pinv(h) 134 | 135 | # compute L(T.W) and dL/d(T.W) 136 | L = loss_func(gW_list, X, Y) 137 | dL_dW_list = torch.autograd.grad(L, inputs=gW_list, create_graph=True) 138 | 139 | # compute dL/dt=||dL/d(T.W)||^2 and d/dT dL/dt 140 | dL_dt = 0 141 | for i in range(len(gW_list)): 142 | dL_dt += torch.norm(dL_dW_list[i])**2 143 | dLdt_dT_list = torch.autograd.grad(dL_dt, inputs=T) 144 | 145 | # gradient ascent step on T, in the direction of d/dT dL/dt 146 | for i in range(len(T)): 147 | T[i] = T[i] + lr_teleport * dLdt_dT_list[i] 148 | 149 | # replace original W's with T.W, using the new T's 150 | W_list = teleport_MLP(W_list, X, X_inv, T) 151 | 152 | return W_list -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | """ MLP models """ 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import Parameter 6 | 7 | class MLP(nn.Module): 8 | def __init__(self, init_W_list, activation): 9 | super(MLP, self).__init__() 10 | self.W_list = nn.ParameterList([]) 11 | for i in range(len(init_W_list)): 12 | self.W_list.append(Parameter(init_W_list[i].clone())) 13 | self.activation = activation 14 | 15 | def forward(self, X): 16 | h = X.clone() 17 | for i in range(len(self.W_list)-1): 18 | h = self.activation(torch.matmul(self.W_list[i], h)) 19 | 20 | out = torch.matmul(self.W_list[-1], h) 21 | return out 22 | 23 | def get_W_list(self): 24 | W_list = [] 25 | for param in self.W_list: 26 | W_list.append(param.data.clone()) 27 | return W_list 
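A minimal usage sketch (not part of the repository) showing how the training scripts construct an MLP from a list of weight matrices and feed it column-major data; the layer sizes, seed, and random weights below are illustrative placeholders, not settings from any experiment:

```
import torch
from torch import nn
from models import MLP  # assumes the repository root is on the Python path

torch.manual_seed(0)

# illustrative sizes following the repo's dim convention [batch_size, d_in, hidden, d_out]
dim = [8, 20, 16, 10]

# one weight matrix per layer, W_i with shape dim[i+1] x dim[i]
init_W_list = [torch.rand(dim[i + 1], dim[i]) - 0.5 for i in range(1, len(dim) - 1)]

model = MLP(init_W_list=init_W_list, activation=nn.LeakyReLU(0.1))

# data is column-major (features x batch), matching torch.t(data.view(batch_size, -1))
X = torch.rand(dim[1], dim[0])
out = model(X)               # shape [d_out, batch_size] = [10, 8]
W_list = model.get_W_list()  # detached copies of the current weights
```

Keeping the weights in an explicit ParameterList preserves the list-of-matrices form that the teleportation utilities operate on, so get_W_list() can hand the current weights directly to a teleport routine and the teleported result can be wrapped in a fresh MLP.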
-------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | """ Helper functions for figures. """ 2 | 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | import os 6 | import pickle 7 | from scipy.stats import pearsonr 8 | 9 | def plot_optimization(opt_method_list, dataset, lr): 10 | x_right = 40 11 | for opt_method in opt_method_list: 12 | if opt_method == 'Adagrad': 13 | y_ticks = [0.2, 0.3, 0.4, 0.5] 14 | elif opt_method == 'momentum': 15 | y_ticks = [0.1, 0.3, 0.5, 0.7] 16 | elif opt_method == 'RMSprop': 17 | y_ticks = [0.2, 0.4, 0.6, 0.8] 18 | elif opt_method == 'Adam': 19 | y_ticks = [0.3, 0.5, 0.7, 0.9] 20 | 21 | plt.figure() 22 | train_loss_list = [] 23 | valid_loss_list = [] 24 | time_list = [] 25 | train_loss_teleport_list = [] 26 | valid_loss_teleport_list = [] 27 | time_teleport_list = [] 28 | for run_num in range(5): 29 | with open('logs/optimization/{}/{}_{}_lr_{}_{}.pkl'.format(dataset, dataset, opt_method, lr, run_num), 'rb') as f: 30 | loss_arr_SGD, valid_loss_SGD, _, _, time_SGD, _ = pickle.load(f) 31 | with open('logs/optimization/{}/{}_{}_lr_{}_teleport_{}.pkl'.format(dataset, dataset, opt_method, lr, run_num), 'rb') as f: 32 | loss_arr_teleport, valid_loss_teleport, _, _, time_teleport, _ = pickle.load(f) 33 | train_loss_list.append(loss_arr_SGD) 34 | valid_loss_list.append(valid_loss_SGD) 35 | time_list.append(time_SGD) 36 | time_teleport_list.append(time_teleport) 37 | 38 | train_loss_teleport_list.append(loss_arr_teleport) 39 | valid_loss_teleport_list.append(valid_loss_teleport) 40 | 41 | time_mean = np.mean(time_list, axis=0) 42 | time_teleport_mean = np.mean(time_teleport_list, axis=0) 43 | 44 | train_loss_teleport_mean = np.mean(train_loss_teleport_list, axis=0) 45 | train_loss_teleport_std = np.std(train_loss_teleport_list, axis=0) 46 | valid_loss_teleport_mean = np.mean(valid_loss_teleport_list, axis=0) 47 | valid_loss_teleport_std = np.std(valid_loss_teleport_list, axis=0) 48 | 49 | train_loss_SGD_mean = np.mean(train_loss_list, axis=0) 50 | train_loss_SGD_std = np.std(train_loss_list, axis=0) 51 | valid_loss_SGD_mean = np.mean(valid_loss_list, axis=0) 52 | valid_loss_SGD_std = np.std(valid_loss_list, axis=0) 53 | 54 | 55 | plt.figure() 56 | plt.plot(train_loss_SGD_mean[:x_right], '--', linewidth=3, color='#1f77b4', label='{} train'.format(opt_method)) 57 | plt.plot(valid_loss_SGD_mean[:x_right], '-', linewidth=3, color='#1f77b4', label='{} test'.format(opt_method)) 58 | plt.plot(train_loss_teleport_mean[:x_right], '--', linewidth=3, color='#ff7f0e', label='{}+teleport train'.format(opt_method)) 59 | plt.plot(valid_loss_teleport_mean[:x_right], '-', linewidth=3, color='#ff7f0e', label='{}_teleport test'.format(opt_method)) 60 | 61 | N = len(train_loss_SGD_mean) 62 | plt.fill_between(np.arange(N), \ 63 | valid_loss_SGD_mean[:x_right] - valid_loss_SGD_std[:x_right], \ 64 | valid_loss_SGD_mean[:x_right] + valid_loss_SGD_std[:x_right], \ 65 | color='#1f77b4', alpha=0.5) 66 | plt.fill_between(np.arange(N), \ 67 | valid_loss_teleport_mean[:x_right] - valid_loss_teleport_std[:x_right], \ 68 | valid_loss_teleport_mean[:x_right] + valid_loss_teleport_std[:x_right], \ 69 | color='#ff7f0e', alpha=0.5) 70 | 71 | plt.xlabel('Epoch', fontsize=28) 72 | plt.ylabel('Loss', fontsize=28) 73 | plt.yscale('log') 74 | plt.minorticks_off() 75 | plt.xticks([0, 10, 20, 30, 40], fontsize= 22) 76 | plt.yticks(y_ticks, y_ticks, fontsize= 22) 77 | 
plt.legend(fontsize=19) 78 | plt.savefig('figures/optimization/{}_{}_loss_vs_epoch.pdf'.format(dataset, opt_method), bbox_inches='tight') 79 | 80 | 81 | fig = plt.subplots() 82 | plt.plot(time_mean, train_loss_SGD_mean[:x_right], '--', linewidth=3, color='#1f77b4', label='{} train'.format(opt_method)) 83 | plt.plot(time_mean, valid_loss_SGD_mean[:x_right], '-', linewidth=3, color='#1f77b4', label='{} test'.format(opt_method)) 84 | plt.plot(time_teleport_mean, train_loss_teleport_mean[:x_right], '--', linewidth=3, color='#ff7f0e', label='{}+teleport train'.format(opt_method)) 85 | plt.plot(time_teleport_mean, valid_loss_teleport_mean[:x_right], '-', linewidth=3, color='#ff7f0e', label='{}+teleport test'.format(opt_method)) 86 | 87 | N = len(train_loss_SGD_mean) 88 | plt.fill_between(time_mean, \ 89 | valid_loss_SGD_mean[:x_right] - valid_loss_SGD_std[:x_right], \ 90 | valid_loss_SGD_mean[:x_right] + valid_loss_SGD_std[:x_right], \ 91 | color='#1f77b4', alpha=0.5) 92 | plt.fill_between(time_teleport_mean, \ 93 | valid_loss_teleport_mean[:x_right] - valid_loss_teleport_std[:x_right], \ 94 | valid_loss_teleport_mean[:x_right] + valid_loss_teleport_std[:x_right], \ 95 | color='#ff7f0e', alpha=0.5) 96 | 97 | plt.xlabel('Time (s)', fontsize=28) 98 | plt.ylabel('Loss', fontsize=28) 99 | plt.yscale('log') 100 | plt.minorticks_off() 101 | plt.xticks(fontsize= 22) 102 | plt.yticks(y_ticks, y_ticks, fontsize= 22) 103 | plt.legend(fontsize=19) 104 | plt.savefig('figures/optimization/{}_{}_loss_vs_time.pdf'.format(dataset, opt_method), bbox_inches='tight') 105 | 106 | 107 | def plot_correlation(dataset, sigma_name): 108 | with open('logs/correlation/{}/{}_final_W_lists/curvatures_all_{}.pkl'.format(dataset, dataset, sigma_name), 'rb') as f: 109 | curvature_mean_list, perturb_mean_list, valid_loss_list, train_loss_list = pickle.load(f) 110 | 111 | plt.figure() 112 | corr, _ = pearsonr(curvature_mean_list, valid_loss_list) 113 | plt.scatter(curvature_mean_list, valid_loss_list, label='Corr={:.3f}'.format(corr)) 114 | plt.xlabel(r'$\psi$', fontsize=26) 115 | plt.ylabel('validation loss', fontsize=26) 116 | plt.yticks(fontsize= 20) 117 | if dataset == 'MNIST': 118 | plt.xlim(0.0005, 0.0035) 119 | plt.xticks([0.001, 0.002, 0.003], fontsize=20) 120 | elif dataset == 'FashionMNIST': 121 | plt.xlim(0.0005, 0.0055) 122 | plt.xticks([0.001, 0.003, 0.005], fontsize=20) 123 | else: 124 | plt.xticks([0.0003, 0.0006, 0.0009], fontsize=20) 125 | plt.legend(fontsize=20) 126 | plt.savefig('figures/correlation/{}_{}_loss_vs_curvature.pdf'.format(dataset, sigma_name), bbox_inches='tight') 127 | 128 | plt.figure() 129 | corr, _ = pearsonr(perturb_mean_list, valid_loss_list) 130 | plt.scatter(perturb_mean_list, valid_loss_list, label='Corr={:.3f}'.format(corr)) 131 | plt.xlabel(r'$\phi$', fontsize=26) 132 | plt.ylabel('validation loss', fontsize=26) 133 | plt.yticks(fontsize= 20) 134 | if dataset == 'MNIST': 135 | plt.xticks([0.0005, 0.0006, 0.0007], fontsize=20) 136 | elif dataset == 'FashionMNIST': 137 | plt.xticks([0.00144, 0.00153, 0.00162], fontsize=20) 138 | else: 139 | plt.xticks([0.0057, 0.0060, 0.0063], fontsize=20) 140 | plt.legend(fontsize=20) 141 | plt.savefig('figures/correlation/{}_{}_loss_vs_perturbed_loss.pdf'.format(dataset, sigma_name), bbox_inches='tight') 142 | 143 | plt.figure() 144 | corr, _ = pearsonr(perturb_mean_list, curvature_mean_list) 145 | plt.scatter(perturb_mean_list, curvature_mean_list, label='Corr={:.3f}'.format(corr)) 146 | plt.xlabel(r'$\phi$', fontsize=26) 147 | plt.ylabel(r'$\psi$', 
fontsize=26) 148 | if dataset == 'MNIST': 149 | plt.xticks([0.0005, 0.0006, 0.0007], fontsize=20) 150 | plt.ylim(0.0005, 0.0035) 151 | plt.yticks([0.001, 0.002, 0.003], fontsize=20) 152 | elif dataset == 'FashionMNIST': 153 | plt.xticks([0.00144, 0.00153, 0.00162], fontsize=20) 154 | plt.ylim(0.0005, 0.0055) 155 | plt.yticks([0.001, 0.003, 0.005], fontsize=20) 156 | else: 157 | plt.xticks([0.0057, 0.0060, 0.0063], fontsize=20) 158 | plt.yticks([0.0003, 0.0006, 0.0009], fontsize=20) 159 | plt.legend(fontsize=20) 160 | plt.savefig('figures/correlation/{}_{}_curvature_vs_perturbed_loss.pdf'.format(dataset, sigma_name), bbox_inches='tight') 161 | 162 | 163 | def plot_sharpness_curvature(dataset, objective_list): 164 | x_right = 40 165 | 166 | for objective in objective_list: 167 | if objective == 'sharpness': 168 | variable_name = 'phi' 169 | else: 170 | variable_name = 'psi' 171 | 172 | train_loss_list = [] 173 | valid_loss_list = [] 174 | train_loss_teleport_true_list = [] 175 | valid_loss_teleport_true_list = [] 176 | train_loss_teleport_false_list = [] 177 | valid_loss_teleport_false_list = [] 178 | 179 | for run_num in range(5): 180 | with open('logs/generalization/{}/teleport_{}/teleport_{}_true_{}.plk'.format(dataset, objective, objective, run_num), 'rb') as f: 181 | train_loss, train_loss_teleport, valid_loss, valid_loss_teleport = pickle.load(f) 182 | train_loss_list.append(train_loss) 183 | train_loss_teleport_true_list.append(train_loss_teleport) 184 | valid_loss_list.append(valid_loss) 185 | valid_loss_teleport_true_list.append(valid_loss_teleport) 186 | 187 | with open('logs/generalization/{}/teleport_{}/teleport_{}_false_{}.plk'.format(dataset, objective, objective, run_num), 'rb') as f: 188 | train_loss, train_loss_teleport, valid_loss, valid_loss_teleport = pickle.load(f) 189 | train_loss_teleport_false_list.append(train_loss_teleport) 190 | valid_loss_teleport_false_list.append(valid_loss_teleport) 191 | 192 | train_loss_teleport_true_mean = np.mean(train_loss_teleport_true_list, axis=0) 193 | train_loss_teleport_true_std = np.std(train_loss_teleport_true_list, axis=0) 194 | valid_loss_teleport_true_mean = np.mean(valid_loss_teleport_true_list, axis=0) 195 | valid_loss_teleport_true_std = np.std(valid_loss_teleport_true_list, axis=0) 196 | 197 | train_loss_teleport_false_mean = np.mean(train_loss_teleport_false_list, axis=0) 198 | train_loss_teleport_false_std = np.std(train_loss_teleport_false_list, axis=0) 199 | valid_loss_teleport_false_mean = np.mean(valid_loss_teleport_false_list, axis=0) 200 | valid_loss_teleport_false_std = np.std(valid_loss_teleport_false_list, axis=0) 201 | 202 | 203 | # '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf' 204 | plt.figure() 205 | plt.plot(train_loss_list[0][:x_right], '--', linewidth=3, color='black') # label='SGD train' 206 | plt.plot(valid_loss_list[0][:x_right], '-', linewidth=3, color='black', label='SGD') # label='SGD valid' 207 | plt.plot(train_loss_teleport_true_mean[:x_right], '--', linewidth=3, color='#1f77b4') #label=r'teleport(decrease $\{}$) train'.format(variable_name) 208 | plt.plot(valid_loss_teleport_true_mean[:x_right], '-', linewidth=3, color='#1f77b4', label=r'teleport(decrease $\{}$)'.format(variable_name)) #label=r'teleport(decrease $\{}$) valid'.format(variable_name) 209 | plt.plot(train_loss_teleport_false_mean[:x_right], '--', linewidth=3, color='#ff7f0e') #label=r'teleport(increase $\{}$) train'.format(variable_name) 210 | 
plt.plot(valid_loss_teleport_false_mean[:x_right], '-', linewidth=3, color='#ff7f0e', label=r'teleport(increase $\{}$)'.format(variable_name)) #label=r'teleport(increase $\{}$) valid'.format(variable_name) 211 | 212 | N = len(train_loss_teleport_false_mean) 213 | plt.fill_between(np.arange(N-1), \ 214 | train_loss_teleport_true_mean[:x_right] - train_loss_teleport_true_std[:x_right], \ 215 | train_loss_teleport_true_mean[:x_right] + train_loss_teleport_true_std[:x_right], \ 216 | color='#1f77b4', alpha=0.5) 217 | 218 | plt.fill_between(np.arange(N-1), \ 219 | valid_loss_teleport_true_mean[:x_right] - valid_loss_teleport_true_std[:x_right], \ 220 | valid_loss_teleport_true_mean[:x_right] + valid_loss_teleport_true_std[:x_right], \ 221 | color='#1f77b4', alpha=0.5) 222 | 223 | plt.fill_between(np.arange(N-1), \ 224 | train_loss_teleport_false_mean[:x_right] - train_loss_teleport_false_std[:x_right], \ 225 | train_loss_teleport_false_mean[:x_right] + train_loss_teleport_false_std[:x_right], \ 226 | color='#ff7f0e', alpha=0.5) 227 | 228 | plt.fill_between(np.arange(N-1), \ 229 | valid_loss_teleport_false_mean[:x_right] - valid_loss_teleport_false_std[:x_right], \ 230 | valid_loss_teleport_false_mean[:x_right] + valid_loss_teleport_false_std[:x_right], \ 231 | color='#ff7f0e', alpha=0.5) 232 | 233 | plt.xlabel('Epoch', fontsize=26) 234 | plt.ylabel('Loss', fontsize=26) 235 | plt.xticks([0, 20, 40], fontsize= 20) 236 | if dataset == 'MNIST': 237 | plt.yticks([0.1, 0.5, 0.9, 1.3], fontsize= 20) 238 | elif dataset == 'FashionMNIST': 239 | plt.yticks([0.3, 0.7, 1.1], fontsize= 20) 240 | elif dataset == 'CIFAR10': 241 | plt.yticks([1.4, 1.7, 2.0], fontsize= 20) 242 | plt.legend(fontsize=17) 243 | plt.savefig('figures/generalization/{}_loss_{}.pdf'.format(dataset, objective), bbox_inches='tight') 244 | -------------------------------------------------------------------------------- /teleport_optimization.py: -------------------------------------------------------------------------------- 1 | """ Evaluate various optimizers augmented with teleportation. 
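For each optimizer (Adagrad, momentum, RMSprop, Adam), the script trains the same MLP twice from an identical initialization, once as-is and once with a brief burst of teleportation early in training, then pickles the training/validation loss curves and wall-clock times under logs/optimization/ for plot_optimization.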
""" 2 | 3 | import numpy as np 4 | import time 5 | from matplotlib import pyplot as plt 6 | import pickle 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from torchvision import datasets 11 | import torchvision.transforms as transforms 12 | from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler 13 | 14 | from gradient_descent_mlp_utils import init_param_MLP, train_step, valid_MLP, teleport 15 | from models import MLP 16 | from plot import plot_optimization 17 | 18 | 19 | device = 'cpu' #'cuda' 20 | run_new = True # False if using cached results 21 | dataset = 'MNIST' # 'MNIST', 'FashionMNIST', 'CIFAR10' 22 | opt_method_list = ['Adagrad', 'momentum', 'RMSprop', 'Adam'] 23 | 24 | criterion = nn.CrossEntropyLoss() 25 | sigma = nn.LeakyReLU(0.1) 26 | batch_size = 20 27 | valid_size = 0.2 28 | tele_batch_size = 200 29 | 30 | # dataset and hyper-parameters 31 | if dataset == 'MNIST': 32 | lr = 1e-2 33 | dim = [batch_size, 28*28, 16, 10, 10] 34 | teledim = [tele_batch_size, 28*28, 16, 10, 10] 35 | train_data = datasets.MNIST(root = 'data', train = True, download = True, transform = transforms.ToTensor()) 36 | test_data = datasets.MNIST(root = 'data', train = False, download = True, transform = transforms.ToTensor()) 37 | elif dataset == 'FashionMNIST': 38 | lr = 1e-2 39 | dim = [batch_size, 28*28, 16, 10, 10] 40 | teledim = [tele_batch_size, 28*28, 16, 10, 10] 41 | train_data = datasets.FashionMNIST(root = 'data', train = True, download = True, transform = transforms.ToTensor()) 42 | test_data = datasets.FashionMNIST(root = 'data', train = False, download = True, transform = transforms.ToTensor()) 43 | elif dataset == 'CIFAR10': 44 | lr = 2e-2 45 | dim = [batch_size, 32*32*3, 128, 32, 10] 46 | teledim = [tele_batch_size, 32*32*3, 128, 32, 10] 47 | train_data = datasets.CIFAR10(root = 'data', train = True, download = True, transform = transforms.ToTensor()) 48 | test_data = datasets.CIFAR10(root = 'data', train = False, download = True, transform = transforms.ToTensor()) 49 | else: 50 | raise ValueError('dataset should be one of MNIST, fashion, and CIFAR10') 51 | 52 | # data loaders 53 | if dataset in ['MNIST', 'FashionMNIST']: 54 | train_subset, val_subset = torch.utils.data.random_split( 55 | train_data, [50000, 10000], generator=torch.Generator().manual_seed(1)) 56 | train_sampler = SequentialSampler(train_subset) 57 | train_loader = torch.utils.data.DataLoader(train_subset, batch_size = batch_size, 58 | sampler = train_sampler, num_workers = 0) 59 | test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, 60 | num_workers = 0) 61 | teleport_loader = torch.utils.data.DataLoader(train_subset, batch_size = tele_batch_size, 62 | shuffle=True, num_workers = 0) 63 | teleport_loader_iterator = iter(teleport_loader) 64 | else: #CIFAR10 65 | train_sampler = SequentialSampler(train_data) 66 | train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, 67 | sampler = train_sampler, num_workers = 0) 68 | test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, 69 | num_workers = 0) 70 | teleport_loader = torch.utils.data.DataLoader(train_data, batch_size = tele_batch_size, 71 | shuffle=True, num_workers = 0) 72 | teleport_loader_iterator = iter(teleport_loader) 73 | 74 | 75 | def get_optimizer(model, opt_method, lr, dataset): 76 | if opt_method == 'SGD': 77 | return torch.optim.SGD(model.parameters(), lr=lr) 78 | elif opt_method == 'Adagrad': 79 | return 
torch.optim.Adagrad(model.parameters(), lr=lr) 80 | elif opt_method == 'momentum': 81 | return torch.optim.SGD(model.parameters(), lr=lr/1e1, momentum=0.9) 82 | elif opt_method == 'RMSprop': 83 | return torch.optim.RMSprop(model.parameters(), lr=lr/1e2) 84 | elif opt_method == 'Adam': 85 | return torch.optim.Adam(model.parameters(), lr=lr/1e2) 86 | else: 87 | raise ValueError('opt_method should be one of SGD, AdaGrad, momentum, RMSProp, and Adam') 88 | 89 | 90 | start_epoch = 15 91 | end_epoch = 40 92 | 93 | if run_new == True: 94 | for opt_method in opt_method_list: 95 | if opt_method == 'SGD': 96 | tele_epochs = [2] 97 | tele_lr = 1e-4 98 | tele_step = 10 99 | elif opt_method == 'Adagrad': 100 | tele_epochs = [2] 101 | tele_lr = 1e-4 102 | tele_step = 10 103 | elif opt_method == 'momentum': 104 | tele_epochs = [0] 105 | tele_lr = 5e-2 106 | tele_step = 10 107 | elif opt_method == 'RMSprop': 108 | tele_epochs = [0] 109 | tele_lr = 5e-2 110 | tele_step = 10 111 | elif opt_method == 'Adam': 112 | tele_epochs = [0] 113 | tele_lr = 5e-2 114 | tele_step = 10 115 | else: 116 | raise ValueError('opt_method should be one of SGD, AdaGrad, momentum, RMSProp, and Adam') 117 | 118 | for run_num in range(5): 119 | print(opt_method, 'run', run_num) 120 | 121 | ############################################################## 122 | # training with opt_method without teleportation (e.g. AdaGrad) 123 | W_list = init_param_MLP(dim, seed=(run_num+1)*54321) 124 | loss_arr_SGD = [] 125 | dL_dt_arr_SGD = [] 126 | valid_loss_SGD = [] 127 | valid_correct_SGD = [] 128 | time_SGD = [] 129 | 130 | model = MLP(init_W_list=W_list, activation=sigma) 131 | model.to(device) 132 | optimizer = get_optimizer(model, opt_method, lr, dataset) 133 | 134 | t0 = time.time() 135 | for epoch in range(40): 136 | epoch_loss = 0.0 137 | for data, label in train_loader: 138 | batch_size = data.shape[0] 139 | data = torch.t(data.view(batch_size, -1)) # [20, 1, 28, 28] -> [784, 20] 140 | loss = train_step(data, label, model, criterion, optimizer) 141 | epoch_loss += loss.item() * data.size(1) 142 | 143 | loss_arr_SGD.append(epoch_loss / len(train_loader.sampler)) 144 | valid_loss, valid_correct = valid_MLP(model, criterion, test_loader) 145 | valid_loss_SGD.append(valid_loss) 146 | valid_correct_SGD.append(valid_correct) 147 | 148 | # print(epoch, loss_arr_SGD[-1], valid_loss_SGD[-1], valid_correct_SGD[-1]) 149 | 150 | t1 = time.time() 151 | time_SGD.append(t1 - t0) 152 | 153 | results = (loss_arr_SGD, valid_loss_SGD, dL_dt_arr_SGD, valid_correct_SGD, time_SGD, 0) 154 | with open('logs/optimization/{}/{}_{}_lr_{}_{}.pkl'.format(dataset, dataset, opt_method, lr, run_num), 'wb') as f: 155 | pickle.dump(results, f) 156 | 157 | 158 | ############################################################## 159 | # training with opt_method + teleport 160 | W_list = init_param_MLP(dim, seed=(run_num+1)*54321) 161 | loss_arr_teleport = [] 162 | dL_dt_arr_teleport = [] 163 | valid_loss_teleport = [] 164 | valid_correct_teleport = [] 165 | time_teleport = [] 166 | 167 | model = MLP(init_W_list=W_list, activation=sigma) 168 | model.to(device) 169 | optimizer = get_optimizer(model, opt_method, lr, dataset) 170 | 171 | teleport_count = 0 172 | t0 = time.time() 173 | 174 | for epoch in range(40): 175 | epoch_loss = 0.0 176 | for data, label in train_loader: 177 | batch_size = data.shape[0] 178 | data = torch.t(data.view(batch_size, -1)) # [20, 1, 28, 28] -> [784, 20] 179 | if (epoch in tele_epochs and teleport_count < 8): 180 | teleport_count += 1 181 | W_list = 
model.get_W_list() 182 | 183 | # load data batch 184 | try: 185 | tele_data, tele_target = next(teleport_loader_iterator) 186 | except StopIteration: 187 | teleport_loader_iterator = iter(teleport_loader) 188 | tele_data, tele_target = next(teleport_loader_iterator) 189 | 190 | # teleport 191 | batch_size = tele_data.shape[0] 192 | tele_data = torch.t(tele_data.view(batch_size, -1)) # [tele_batch_size, 1, 28, 28] -> [784, tele_batch_size] 193 | 194 | W_list = teleport(W_list, tele_data, tele_target, tele_lr, dim, sigma, telestep=tele_step, random_teleport=False, reverse=False) 195 | 196 | # update W_list in model 197 | model = MLP(init_W_list=W_list, activation=sigma) 198 | model.to(device) 199 | optimizer = get_optimizer(model, opt_method, lr, dataset) 200 | 201 | 202 | loss = train_step(data, label, model, criterion, optimizer) 203 | epoch_loss += loss.item() * data.size(1) 204 | 205 | loss_arr_teleport.append(epoch_loss / len(train_loader.sampler)) 206 | valid_loss, valid_correct = valid_MLP(model, criterion, test_loader) 207 | valid_loss_teleport.append(valid_loss) 208 | valid_correct_teleport.append(valid_correct) 209 | 210 | # print(epoch, loss_arr_teleport[-1], valid_loss_teleport[-1], valid_correct_teleport[-1]) 211 | 212 | t1 = time.time() 213 | time_teleport.append(t1 - t0) 214 | 215 | results = (loss_arr_teleport, valid_loss_teleport, dL_dt_arr_teleport, valid_correct_teleport, time_teleport, 0) 216 | with open('logs/optimization/{}/{}_{}_lr_{}_teleport_{}.pkl'.format(dataset, dataset, opt_method, lr, run_num), 'wb') as f: 217 | pickle.dump(results, f) 218 | 219 | plot_optimization(opt_method_list, dataset, lr) 220 | -------------------------------------------------------------------------------- /teleport_sharpness_curvature.py: -------------------------------------------------------------------------------- 1 | """ Teleportation to change sharpness or curvature. 
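The script first runs plain SGD for 100 epochs, caching the weights after every epoch, then restarts from the epoch-15 checkpoint and teleports at epoch 20 to either increase or decrease sharpness (phi) or curvature (psi) before resuming SGD; results are pickled under logs/generalization/ and plotted with plot_sharpness_curvature.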
""" 2 | 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | import pickle 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | from torchvision import datasets 10 | import torchvision.transforms as transforms 11 | from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler 12 | 13 | from gradient_descent_mlp_utils import init_param_MLP, train_step, valid_MLP, teleport_curvature, teleport_sharpness 14 | from models import MLP 15 | from plot import plot_sharpness_curvature 16 | 17 | device = 'cpu' #'cuda' 18 | dataset = 'CIFAR10' # 'MNIST', 'FashionMNIST', 'CIFAR10' 19 | objective_list = ['sharpness', 'curvature'] 20 | 21 | sigma = nn.LeakyReLU(0.1) 22 | batch_size = 20 23 | valid_size = 0.2 24 | tele_batch_size = 2000 25 | 26 | # dataset and hyper-parameters 27 | if dataset == 'MNIST': 28 | lr = 1e-2 29 | t_start = 0.001 30 | t_end = 0.2 31 | t_interval = 0.01 32 | dim = [batch_size, 28*28, 16, 10, 10] 33 | teledim = [tele_batch_size, 28*28, 16, 10, 10] 34 | train_data = datasets.MNIST(root = 'data', train = True, download = True, transform = transforms.ToTensor()) 35 | test_data = datasets.MNIST(root = 'data', train = False, download = True, transform = transforms.ToTensor()) 36 | elif dataset == 'FashionMNIST': 37 | lr = 1e-2 38 | t_start = 0.0001 39 | t_end = 0.02 40 | t_interval = 0.001 41 | dim = [batch_size, 28*28, 16, 10, 10] 42 | teledim = [tele_batch_size, 28*28, 16, 10, 10] 43 | train_data = datasets.FashionMNIST(root = 'data', train = True, download = True, transform = transforms.ToTensor()) 44 | test_data = datasets.FashionMNIST(root = 'data', train = False, download = True, transform = transforms.ToTensor()) 45 | elif dataset == 'CIFAR10': 46 | lr = 2e-2 47 | t_start = 0.0001 48 | t_end = 0.02 49 | t_interval = 0.001 50 | dim = [batch_size, 32*32*3, 32, 10, 10] # 32*32*3, 128, 32, 10] 51 | teledim = [tele_batch_size, 32*32*3, 32, 10, 10] 52 | train_data = datasets.CIFAR10(root = 'data', train = True, download = True, transform = transforms.ToTensor()) 53 | test_data = datasets.CIFAR10(root = 'data', train = False, download = True, transform = transforms.ToTensor()) 54 | else: 55 | raise ValueError('dataset should be one of MNIST, fashion, and CIFAR10') 56 | 57 | # data loaders 58 | if dataset in ['MNIST', 'FashionMNIST']: 59 | train_subset, val_subset = torch.utils.data.random_split( 60 | train_data, [50000, 10000], generator=torch.Generator().manual_seed(1)) 61 | train_sampler = SequentialSampler(train_subset) 62 | train_loader = torch.utils.data.DataLoader(train_subset, batch_size = batch_size, 63 | sampler = train_sampler, num_workers = 0) 64 | test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, 65 | num_workers = 0) 66 | teleport_loader = torch.utils.data.DataLoader(train_subset, batch_size = tele_batch_size, 67 | shuffle=True, num_workers = 0) 68 | teleport_loader_iterator = iter(teleport_loader) 69 | else: #CIFAR10 70 | train_sampler = SequentialSampler(train_data) 71 | train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, 72 | sampler = train_sampler, num_workers = 0) 73 | test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, 74 | num_workers = 0) 75 | teleport_loader = torch.utils.data.DataLoader(train_data, batch_size = tele_batch_size, 76 | shuffle=True, num_workers = 0) 77 | teleport_loader_iterator = iter(teleport_loader) 78 | 79 | 80 | ############################################################## 81 | # run SGD without 
teleportation once 82 | W_list = init_param_MLP(dim) 83 | loss_arr_SGD = [] 84 | dL_dt_arr_SGD = [] 85 | valid_loss_SGD = [] 86 | valid_correct_SGD = [] 87 | 88 | model = MLP(init_W_list=W_list, activation=sigma) 89 | model.to(device) 90 | criterion = nn.CrossEntropyLoss() 91 | optimizer = torch.optim.SGD(model.parameters(), lr=lr) 92 | 93 | for epoch in range(100): 94 | epoch_loss = 0.0 95 | for data, label in train_loader: 96 | batch_size = data.shape[0] 97 | data = torch.t(data.view(batch_size, -1)) 98 | loss = train_step(data, label, model, criterion, optimizer) 99 | epoch_loss += loss.item() * data.size(1) 100 | 101 | loss_arr_SGD.append(epoch_loss / len(train_loader.sampler)) 102 | valid_loss, valid_correct = valid_MLP(model, criterion, test_loader) #valid_loader) 103 | valid_loss_SGD.append(valid_loss) 104 | valid_correct_SGD.append(valid_correct) 105 | 106 | if epoch % 10 == 0: 107 | print(epoch, loss_arr_SGD[-1], valid_loss_SGD[-1], valid_correct_SGD[-1]) 108 | 109 | W_list = model.get_W_list() 110 | with open('logs/generalization/{}/{}_SGD/{}_SGD_epoch_{}.pkl'.format(dataset, dataset, dataset, epoch), 'wb') as f: 111 | pickle.dump(W_list, f) 112 | 113 | results = (loss_arr_SGD, valid_loss_SGD, dL_dt_arr_SGD, valid_correct_SGD, 0) 114 | with open('logs/generalization/{}/{}_SGD_lr_{:e}.pkl'.format(dataset, dataset, lr), 'wb') as f: 115 | pickle.dump(results, f) 116 | 117 | 118 | with open('logs/generalization/{}/{}_SGD_lr_{:e}.pkl'.format(dataset, dataset, lr), 'rb') as f: 119 | loss_arr_SGD, valid_loss_SGD, dL_dt_arr_SGD, _, _ = pickle.load(f) 120 | 121 | 122 | ############################################################## 123 | # training with SGD + teleport sharpness/curvature 124 | 125 | if dataset == 'CIFAR10': 126 | lr_teleport_sharpness = {True: 5e-2, False: 5e-2} 127 | lr_teleport_curvature = {True: 5e-2, False: 2e-1} 128 | elif dataset == 'FashionMNIST': 129 | lr_teleport_sharpness = {True: 3e-1, False: 1e-1} 130 | lr_teleport_curvature = {True: 3e-3, False: 5e-3} 131 | elif dataset == 'MNIST': 132 | lr_teleport_sharpness = {True: 5e-2, False: 5e-2} 133 | lr_teleport_curvature = {True: 5e-2, False: 2e-1} 134 | 135 | 136 | start_epoch = 15 # use saved weights from the SGD run 137 | end_epoch = 40 138 | 139 | for objective in objective_list: # sharpness or curvature 140 | for reverse in [False, True]: # teleport to increase or decrease sharpness/curvature 141 | for run_num in range(5): 142 | loss_arr_teleport_rand_curvature = [] 143 | dL_dt_arr_teleport_rand_curvature = [] 144 | valid_loss_teleport_rand_curvature = [] 145 | valid_correct_teleport_rand_curvature = [] 146 | 147 | if start_epoch == 0: 148 | W_list = init_param_MLP(dim) 149 | else: 150 | with open('logs/generalization/{}/{}_SGD/{}_SGD_epoch_{}.pkl'.format(dataset, dataset, dataset, start_epoch), 'rb') as f: 151 | W_list = pickle.load(f) 152 | loss_arr_teleport_rand_curvature = loss_arr_SGD[:start_epoch+1] 153 | valid_loss_teleport_rand_curvature = valid_loss_SGD[:start_epoch+1] 154 | 155 | model = MLP(init_W_list=W_list, activation=sigma) 156 | model.to(device) 157 | criterion = nn.CrossEntropyLoss() #nn.MSELoss() 158 | optimizer = torch.optim.SGD(model.parameters(), lr=lr) 159 | 160 | 161 | teleport_count = 0 162 | for epoch in range(start_epoch, end_epoch): 163 | epoch_loss = 0.0 164 | for data, label in train_loader: 165 | batch_size = data.shape[0] 166 | data = torch.t(data.view(batch_size, -1)) 167 | if (epoch == 20 and teleport_count < 10):# 5 for mnist, fashion 168 | # teleport 169 | teleport_count += 1 
170 | W_list = model.get_W_list() 171 | 172 | # load data batch 173 | try: 174 | tele_data, tele_target = next(teleport_loader_iterator) 175 | except StopIteration: 176 | teleport_loader_iterator = iter(teleport_loader) 177 | tele_data, tele_target = next(teleport_loader_iterator) 178 | 179 | # teleport 180 | batch_size = tele_data.shape[0] 181 | tele_data = torch.t(tele_data.view(batch_size, -1)) 182 | if objective == 'sharpness': 183 | if (reverse == True and teleport_count < 11) or (reverse == False and teleport_count < 3): # teleport once if increasing sharpness 184 | W_list = teleport_sharpness(W_list, tele_data, tele_target, lr_teleport_sharpness[reverse], teledim, sigma, \ 185 | telestep=10, reverse=reverse, t_start=t_start, t_end=t_end, t_interval=t_interval) 186 | elif objective == 'curvature': 187 | if (reverse == True and teleport_count < 6) or (reverse == False and teleport_count < 11): # teleport once if increasing sharpness 188 | W_list = teleport_curvature(W_list, tele_data, tele_target, lr_teleport_curvature[reverse], teledim, sigma, telestep=1, reverse=reverse) 189 | else: 190 | raise ValueError("Teleportation objective should be either sharpness or curvature") 191 | 192 | # update W_list in model 193 | model = MLP(init_W_list=W_list, activation=sigma) 194 | model.to(device) 195 | criterion = nn.CrossEntropyLoss() #nn.MSELoss() 196 | optimizer = torch.optim.SGD(model.parameters(), lr=lr) 197 | 198 | 199 | loss = train_step(data, label, model, criterion, optimizer) 200 | epoch_loss += loss.item() * data.size(1) 201 | 202 | loss_arr_teleport_rand_curvature.append(epoch_loss / len(train_loader.sampler)) 203 | valid_loss, valid_correct = valid_MLP(model, criterion, test_loader) #valid_loader) 204 | valid_loss_teleport_rand_curvature.append(valid_loss) 205 | valid_correct_teleport_rand_curvature.append(valid_correct) 206 | 207 | if epoch % 1 == 0: 208 | print(epoch, loss_arr_teleport_rand_curvature[-1], valid_loss_teleport_rand_curvature[-1], valid_correct_teleport_rand_curvature[-1]) 209 | 210 | 211 | results = (loss_arr_SGD, loss_arr_teleport_rand_curvature, valid_loss_SGD, valid_loss_teleport_rand_curvature) 212 | with open('logs/generalization/{}/teleport_{}/teleport_{}_{}_{}.plk'.format(dataset, objective, objective, reverse,run_num), 'wb') as f: 213 | pickle.dump(results, f) 214 | 215 | plot_sharpness_curvature(dataset, objective_list) 216 | --------------------------------------------------------------------------------
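As a closing note, a minimal sanity-check sketch (not part of the repository) for the group action in learn-to-teleport/teleportation.py: with using_T=False, an invertible T, and a data matrix X of full column rank (so that pinv(X) @ X is the identity), the two-layer output U sigma(V X) is unchanged by the action. The sizes below are illustrative only, and the import assumes the sketch is run from the learn-to-teleport/ directory.

```
import torch
from torch import nn
from teleportation import group_action_MLP  # assumes the learn-to-teleport/ directory

torch.set_default_dtype(torch.float64)  # double precision keeps the pinv round-trips tight
torch.manual_seed(0)
sigma = nn.LeakyReLU(0.1)

d, k, m, batch = 20, 16, 10, 8       # illustrative sizes; d > batch so X has full column rank
X = torch.rand(d, batch)
V = torch.rand(k, d) - 0.5           # first-layer weight, acts on X
U = torch.rand(m, k) - 0.5           # second-layer weight, acts on sigma(V @ X)
T = torch.rand(k, k) + torch.eye(k)  # a generic invertible transformation

U_out, V_out = group_action_MLP(U, V, X, torch.linalg.pinv(X), T, using_T=False)

before = U @ sigma(V @ X)
after = U_out @ sigma(V_out @ X)
print(torch.allclose(before, after, atol=1e-6))  # expected: True
```

The default using_T=True branch instead applies (I - T) and (I + T), which preserves the output only up to O(T^2) terms; that is the form used inside teleport_MLP_gradient_ascent, where T starts at zero and is optimized by gradient ascent on the squared gradient norm.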