├── .gitignore
├── README.md
├── train.ipynb
├── linear_transformer.py
├── rotation demonstration-Adam-p0.ipynb
├── variable_L_exp.ipynb
├── variable_N_exp.ipynb
├── plot_stochastic_noise.ipynb
└── plot_loss.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pdf
2 | *.pth
3 | *.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LinearTransformer
2 | PyTorch code for reproducing the experiments in the following papers:
3 |
4 | [1] [Transformers learn to implement preconditioned gradient descent for in-context learning](https://arxiv.org/abs/2306.00297). *Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, Suvrit Sra*
5 | [2] [Linear attention is (maybe) all you need (to understand Transformer optimization)](https://arxiv.org/abs/2310.01082). *Kwangjun Ahn, Xiang Cheng, Minhak Song, Chulhee Yun, Suvrit Sra, Ali Jadbabaie*
6 |
7 |
8 |
9 |
10 | **'simple demonstration.ipynb'**:
11 | - Training a 3 layer Linear Transformer with SGD/Adam, **covariates have identity covariance**
12 | - Plotting test loss
13 | - Displaying matrices at end of training + distance to identity (similar to Figure 4 of [1])
14 |
15 | **'rotation demonstration-Adam.ipynb'**:
16 | - Training a 3 layer Linear Transformer with Adam, **covariates have non-identity covariance** (Adam requires about 100x more steps to converge compared to the identity covariance case)
17 | - Plotting test loss
18 | - Displaying matrices at end of training + distance to identity (similar to Figure 4 of [1])
19 | 'rotation demonstration-Adam-p0.ipynb' is similar to 'rotation demonstration-Adam.ipynb', but enforces that the P matrix has top left block = 0
20 |
21 | **'variable_L_exp.ipynb'**:
22 | - Compares n-layer linear Transformer against n-step Gradient Descent/ Preconditioned Gradient Descent, for n = 1,2,3,4, for fixed context length N=20
23 |
24 | **'variable_N_exp.ipynb'**:
25 | - Compares 3-layer linear Transformer against 3-step Gradient Descent/ Preconditioned Gradient Descent, for context length N={2,4,6...20}
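
The n-step gradient descent baseline that these two notebooks compare against can be sketched in plain Python. This is only an illustration under stated assumptions, not the notebooks' code: `gd_predict` is a hypothetical helper, the iterate starts at `w = 0`, and the objective is assumed to be the in-context least-squares loss `(1 / 2N) * sum_i (w . x_i - y_i)^2`.

```python
def gd_predict(xs, ys, x_query, n_steps, eta):
    """Predict y for x_query after n_steps of gradient descent from w = 0.

    Hypothetical helper for illustration; xs is a list of N covariate
    vectors, ys the N labels, eta the step size.
    """
    n, d = len(xs), len(xs[0])
    w = [0.0] * d
    for _ in range(n_steps):
        # Gradient of (1 / 2N) * sum_i (w . x_i - y_i)^2 with respect to w.
        grad = [0.0] * d
        for x, y in zip(xs, ys):
            residual = sum(wj * xj for wj, xj in zip(w, x)) - y
            for j in range(d):
                grad[j] += residual * x[j] / n
        w = [wj - eta * gj for wj, gj in zip(w, grad)]
    # The prediction is the inner product of the final iterate and the query.
    return sum(wj * xj for wj, xj in zip(w, x_query))


# Two orthonormal context points with labels generated by w* = (2, 3).
xs = [[1.0, 0.0], [0.0, 1.0]]
ys = [2.0, 3.0]
prediction = gd_predict(xs, ys, [1.0, 1.0], n_steps=1, eta=2.0)
```

In this toy case the empirical covariance is `I / 2`, so one step with `eta = 2` recovers `w*` exactly. Preconditioned gradient descent would replace the scalar `eta` with a matrix applied to the gradient.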
26 |
27 |
28 | **'linear_transformer.py'** contains definition of the Linear Transformer model, along with some other handy functions.
29 |
30 |
31 |
32 | ### Quick Start
33 | Setting: training a 3-layer linear Transformer with Adam/SGD (with clipping); covariates have normal covariance
34 |
35 | 1. Run `train.ipynb` - training linear transformer
36 | 2. Run `plot_loss.ipynb` - generates loss vs iteration plot
37 | 3. Run `plot_stochastic_noise.ipynb` - generates stochastic gradient noise histogram
38 | 4. Run `plot_condition_number.ipynb` - generates robust condition number plot
39 | 5. Run `plot_smoothness_vs_gradnorm.ipynb` - generates smoothness vs gradient norm plot
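
`train.ipynb` saves one checkpoint per (algorithm, learning rate, seed) combination under `log/`. The sketch below lists the paths a sweep would produce; the filename pattern is copied from `train.ipynb`, while `checkpoint_paths` is a hypothetical helper that is not part of the repository.

```python
from itertools import product

# Pattern used in train.ipynb; the fields are, in order:
# n_layer, N, mode, alg, toclip, lr, sd.
FILENAME_FORMAT = "log/train_layer{}_N{}_{}_{}_{}_lr{}_sd{}.pth"


def checkpoint_paths(n_layer, N, mode, algos, lrs, seeds, toclip=True):
    """Return the checkpoint paths for a sweep (hypothetical helper)."""
    return [
        FILENAME_FORMAT.format(n_layer, N, mode, alg, toclip, lr, sd)
        for alg, lr, sd in product(algos, lrs, seeds)
    ]


# Default sweep from train.ipynb: 2 algorithms x 1 learning rate x 6 seeds.
paths = checkpoint_paths(3, 20, "normal", ["sgd", "adam"], [0.02], range(6))
```

Each file stores a dictionary whose `hist_list` entry holds the parameter history saved during training.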
40 |
41 | ### Hyperparameters
42 |
43 | `mode`: method of generating samples (`['normal', 'sphere', 'gamma']`)
44 |
45 | `alg`: Optimization algorithm (`['sgd', 'adam']`)
46 |
47 | `toclip`: `True` to clip the gradient norm before each optimizer step; `False` to step without clipping.
48 |
49 | `lr`: learning rate
50 |
51 | `sd`: random seed
52 |
53 | `max_iters`: maximum number of iterations
54 |
55 | `n_layer`: number of layers of linear transformer
56 |
57 | `N`: number of in-context samples
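
The clipping controlled by `toclip` rescales the whole gradient whenever its norm exceeds a threshold, as `clip_and_step` in `train.ipynb` does on `model.allparam.grad`. Below is a minimal plain-Python sketch of that rule; `clip_gradient` is a hypothetical name, and the gradient is represented as a flat list of floats rather than a tensor.

```python
import math


def clip_gradient(grad, toclip, clip_threshold=1.0):
    """Rescale grad to norm clip_threshold if it is larger (illustrative).

    Returns the (possibly rescaled) gradient and the norm before clipping.
    """
    norm = math.sqrt(sum(g * g for g in grad))
    if toclip and norm > clip_threshold:
        grad = [g * clip_threshold / norm for g in grad]
    return grad, norm


# A gradient of norm 5 is rescaled to norm 1; its direction is unchanged.
clipped, norm = clip_gradient([3.0, 4.0], toclip=True)
```

`train.ipynb` uses a threshold of 0.1 when `mode == 'sphere'` and 1.0 otherwise.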
58 |
59 | ### Learning Rates
60 |
61 | |Setting (n_layer=3)|SGDM (with clipping)|Adam (with clipping)|
62 | |-----|-----:|-----:|
63 | |`mode='normal', N=5`|0.01|0.005|
64 | |`mode='normal', N=20`|0.02|0.02|
65 | |`mode='sphere', N=20`|5|0.1|
66 | |`mode='gamma', N=20`|0.02|0.02|
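
For scripted sweeps, the table above can be kept as a small lookup. This is a convenience sketch only: `LEARNING_RATES` and `lookup_lr` are hypothetical names, and the values are copied from the table (all for `n_layer=3` with clipping).

```python
# Learning rates from the table above, keyed by (mode, N), then by optimizer.
LEARNING_RATES = {
    ("normal", 5): {"sgd": 0.01, "adam": 0.005},
    ("normal", 20): {"sgd": 0.02, "adam": 0.02},
    ("sphere", 20): {"sgd": 5.0, "adam": 0.1},
    ("gamma", 20): {"sgd": 0.02, "adam": 0.02},
}


def lookup_lr(mode, N, alg):
    """Return the tabulated learning rate, or raise if the setting is absent."""
    try:
        return LEARNING_RATES[(mode, N)][alg]
    except KeyError:
        raise ValueError(
            "no tabulated learning rate for mode={!r}, N={}, alg={!r}".format(mode, N, alg)
        ) from None


lr = lookup_lr("sphere", 20, "sgd")
```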
--------------------------------------------------------------------------------
/train.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 19,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "%matplotlib inline\n",
11 | "from matplotlib import pyplot as plt\n",
12 | "import math\n",
13 | "import torch.nn.functional as F\n",
14 | "from torch.nn.functional import relu\n",
15 | "from torch import nn\n",
16 | "import torch.optim as optim\n",
17 | "import torch.optim.lr_scheduler as lr_scheduler\n",
18 | "import random\n",
19 | "import numpy as np\n",
20 | "import gc\n",
21 | "from pylab import *\n",
22 | "import os\n",
23 | "import random\n",
24 | "import json\n",
25 | "import pandas as pd\n",
26 | "from scipy.stats import norm\n",
27 | "pd.set_option('display.float_format', lambda x: '%.5f' % x)\n",
28 | "import sys\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "import time\n",
31 | "\n",
32 | "from linear_transformer import Transformer_F, attention, generate_data, in_context_loss, generate_data_inplace\n",
33 | "\n",
34 | "np.set_printoptions(precision = 4, suppress = True)\n",
35 | "torch.set_printoptions(precision=2)\n",
36 | "device = torch.device(\"cuda\")\n",
37 | "torch.cuda.set_device(0)"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 20,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "# Set Hyperparameters\n",
47 | "\n",
48 | "# Fixed\n",
49 | "n_head = 1\n",
50 | "d = 5\n",
51 | "B = 1000\n",
52 | "var = 0.05\n",
53 | "shape_k = 0.1\n",
54 | "\n",
55 | "# Number of Iterations to run\n",
56 | "max_iters = 10000\n",
57 | "hist_stride = 1\n",
58 | "\n",
59 | "# We vary the following parameters\n",
60 | "n_layer = 3\n",
61 | "mode = 'normal'\n",
62 | "N = 20\n",
63 | "seeds = [0,1,2,3,4,5]\n",
64 | "algos = ['sgd','adam']\n",
65 | "lrs = [0.02]\n"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 21,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "# pipe output to log file\n",
75 | "log_dir = 'log' \n",
76 | "os.makedirs(log_dir, exist_ok=True)\n",
77 | "f = open(log_dir + '/train.log', \"a\", 1)\n",
78 | "sys.stdout = f\n",
79 | "filename_format = log_dir + '/train_layer{}_N{}_{}_{}_{}_lr{}_sd{}.pth'"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 22,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "# one-step update of the (non-)clipping algorithm\n",
89 | "def clip_and_step(allparam, optimizer, toclip, clip_threshold = 1.):\n",
90 | "    grad_all = allparam.grad\n",
91 | "    grad_p = grad_all\n",
92 | "    norm_p = grad_p.norm()\n",
93 | "    if toclip and norm_p > clip_threshold:\n",
94 | "        grad_all.mul_(clip_threshold/norm_p)\n",
95 | "    optimizer.step()\n",
96 | "    return norm_p.item()"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 23,
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "\n",
106 | "## Train linear transformer\n",
107 | "\n",
108 | "for alg in algos:\n",
109 | "    for toclip in [True]:  # True means with clipping, False means without clipping\n",
110 | "        for lr in lrs:\n",
111 | "            for sd in seeds:\n",
112 | "                filename = filename_format.format(n_layer, N, mode, alg, toclip, lr, sd)\n",
113 | "                print(filename)\n",
114 | "                np.random.seed(sd)\n",
115 | "                torch.manual_seed(sd)\n",
116 | "                hist_list = list()\n",
117 | "\n",
118 | "                # initialize model parameters\n",
119 | "                model = Transformer_F(n_layer, n_head, d, var)\n",
120 | "                model.to(device)\n",
121 | "\n",
122 | "                # create optimizer\n",
123 | "                if alg == 'sgd':\n",
124 | "                    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0)\n",
125 | "                elif alg == 'adam':\n",
126 | "                    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.9), weight_decay=0)\n",
127 | "                else: assert False\n",
128 | "\n",
129 | "                for t in range(max_iters):\n",
130 | "                    start = time.time()\n",
131 | "                    # save model parameters\n",
132 | "                    if t%hist_stride ==0:\n",
133 | "                        hist_list.append(model.allparam.clone().detach())\n",
134 | "\n",
135 | "                    # generate a new batch of training data\n",
136 | "                    Z, y = generate_data(mode,N,d,B,shape_k)\n",
137 | "                    Z = Z.to(device)\n",
138 | "                    y = y.to(device)\n",
139 | "\n",
140 | "                    loss = in_context_loss(model, Z, y)\n",
141 | "                    loss_value = loss.item()\n",
142 | "                    loss.backward()\n",
143 | "\n",
144 | "                    if mode == 'sphere':\n",
145 | "                        clip_threshold = 0.1\n",
146 | "                    else:\n",
147 | "                        clip_threshold = 1.0\n",
148 | "\n",
149 | "                    # take optimizer step\n",
150 | "                    norms = clip_and_step(model.allparam, optimizer,toclip,clip_threshold)\n",
151 | "                    optimizer.zero_grad()\n",
152 | "\n",
153 | "                    end=time.time()\n",
154 | "                    if t%100 ==0 or t<5:\n",
155 | "                        print('iter {} | Loss: {} time: {} gradnorm: {}'.format(t,loss_value, end-start, norms))\n",
156 | "\n",
157 | "                torch.save({'hist_list':hist_list}, filename)"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 24,
163 | "metadata": {},
164 | "outputs": [],
165 | "source": [
166 | "sys.stdout = sys.__stdout__\n",
167 | "f.close()"
168 | ]
169 | }
170 | ],
171 | "metadata": {
172 | "kernelspec": {
173 | "display_name": "pytorch",
174 | "language": "python",
175 | "name": "python3"
176 | },
177 | "language_info": {
178 | "codemirror_mode": {
179 | "name": "ipython",
180 | "version": 3
181 | },
182 | "file_extension": ".py",
183 | "mimetype": "text/x-python",
184 | "name": "python",
185 | "nbconvert_exporter": "python",
186 | "pygments_lexer": "ipython3",
187 | "version": "3.11.3"
188 | },
189 | "orig_nbformat": 4
190 | },
191 | "nbformat": 4,
192 | "nbformat_minor": 2
193 | }
194 |
--------------------------------------------------------------------------------
/linear_transformer.py:
--------------------------------------------------------------------------------
1 | ###########################################
2 | # This file contains the following:
3 | # 1. Linear Transformer Model
4 | # 2. Function for clipping gradient
5 | # 3. Function for generating random data
6 | #
7 | # The notation for linear attention follows
8 | # the paper at https://arxiv.org/pdf/2306.00297.pdf
9 | ###########################################
10 |
11 |
12 | import torch
13 | from torch import nn
14 | import numpy as np
15 |
16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17 |
18 | # Definition of a single linear attention unit for linear-regression data
19 | # P is the value matrix
20 | # Q is the product of key,query matrices
21 | # the dimensions of the input are
22 | # B: batch-size of prompts
23 | # N: context length (excluding query)
24 | # d: covariate dimension
25 | # P,Q are d x d matrices
26 | # Z is a B x (N+1) x (d+1) tensor
27 | # Output is also B x (N+1) x (d+1)
28 |
29 | # For linear attention, activation = None
30 | # For standard attention, activation(x) = torch.nn.functional.softmax(x, dim = 2)
31 | # For ReLU attention, activation(x) = torch.nn.functional.relu(x)
32 | def attention(P,Q,Z, activation = None):
33 |     B = Z.shape[0]
34 |     N = Z.shape[1]-1
35 |     d = Z.shape[2]-1
36 |     P_full = torch.cat([P,torch.zeros(1,d).to(device)],dim=0)
37 |     P_full = torch.cat([P_full,torch.zeros(d+1,1).to(device)],dim=1)
38 |     P_full[d,d] = 1
39 |     Q_full = torch.cat([Q, torch.zeros(1,d).to(device)],dim=0)
40 |     Q_full = torch.cat([Q_full, torch.zeros(d+1,1).to(device)],dim=1)
41 |     A = torch.eye(N+1).to(device)
42 |     A[N,N] = 0
43 |     Attn = torch.einsum('BNi, ij, BMj -> BNM', (Z,Q_full,Z))
44 |     if activation is not None:
45 |         Attn = activation(Attn)
46 |     key = torch.einsum('ij, BNj -> BNi', (P_full,Z))
47 |     Output = torch.einsum('BNM,ML, BLi -> BNi', (Attn,A,key))
48 |     return Output / N
49 |
50 |
51 | # The Linear Transformer module
52 | # n_layer denotes the number of layers
53 | # n_head denotes the number of heads. In most of our experiments, n_head = 1
54 | # d denotes the dimension of covariates
55 | # var denotes the variance of initialization. It needs to be sufficiently small, but exact value is not important
56 | # allparam: contains all the parameters, has dimension n_layer x n_head x 2 x d x d
57 | # For example
58 | # - P matrix at layer i, head j is allparam[i,j,0,:,:]
59 | # - Q matrix at layer i, head j is allparam[i,j,1,:,:]
60 | class Transformer_F(nn.Module):
61 |     def __init__(self, n_layer, n_head, d, var):
62 |         super(Transformer_F, self).__init__()
63 |         self.register_parameter('allparam', torch.nn.Parameter(torch.zeros(n_layer, n_head, 2, d, d)))
64 |         with torch.no_grad():
65 |             self.allparam.normal_(0,var)
66 |         self.n_layer = n_layer
67 |         self.n_head = n_head
68 |
69 |     def forward(self, Z):
70 |         for i in range(self.n_layer):
71 |             Zi = Z
72 |             residues = 0
73 |             # the forward map of each layer is given by F(Z) = Z + attention(Z)
74 |             for j in range(self.n_head):
75 |                 Pij = self.allparam[i,j,0,:,:]
76 |                 Qij = self.allparam[i,j,1,:,:]
77 |                 residues = residues + attention(Pij,Qij,Zi)
78 |             Z = Zi + residues
79 |         return Z
80 |
81 |     # enforces top-left d x d block sparsity on P
82 |     def zero_p(self):
83 |         for i in range(self.n_layer):
84 |             for j in range(self.n_head):
85 |                 with torch.no_grad():
86 |                     self.allparam[i,j,0,:,:].zero_()
87 |
88 | # evaluate the loss of model, given data (Z,y)
89 | def in_context_loss(model, Z, y):
90 |     N = Z.shape[1]-1
91 |     d = Z.shape[2]-1
92 |     output = model(Z)
93 |     diff = output[:,N,d]+y
94 |     loss = ((diff)**2).mean()
95 |     return loss
96 |
97 | # generate random data for linear regression
98 | # mode: distribution of samples to generate. Currently supports 'normal', 'gamma', 'sphere'
99 | # N: number of context examples
100 | # d: dimension of covariates
101 | # For gamma distribution:
102 | # - shape_k: shape parameter of gamma distribution (unused otherwise)
103 | # - scale parameter: hard coded so that when shape_k = 5/2 and d=5, the generated data is standard normal
104 | def generate_data(mode='normal',N=20,d=1,B=1000,shape_k=0.1, U=None, D=None):
105 |     W= torch.FloatTensor(B, d).normal_(0,1).to(device)
106 |     X = torch.FloatTensor(B, N, d).normal_(0, 1).to(device)
107 |     X_test = torch.FloatTensor(B,1,d).normal_(0, 1).to(device)
108 |
109 |     if U is not None:
110 |         U = U.to(device)
111 |         D = D.to(device)
112 |         W= torch.FloatTensor(B, d).normal_(0,1).to(device)
113 |         W = torch.mm(W,torch.inverse(D))
114 |         W = torch.mm(W,U.t())
115 |
116 |     if mode =='sphere':
117 |         X.div_(X.norm(p=2,dim=2)[:,:,None])
118 |         X_test.div_(X_test.norm(p=2,dim=2)[:,:,None])
119 |     elif mode == 'gamma':
120 |         # random gamma scaling for X
121 |         gamma_scales = np.random.gamma(shape=shape_k, scale=(10/shape_k)**(0.5), size=[B,N])
122 |         gamma_scales = torch.Tensor(gamma_scales).to(device)
123 |         gamma_scales = gamma_scales.sqrt()
124 |         # random gamma scaling for X_test
125 |         gamma_test_scales = np.random.gamma(shape=shape_k, scale=(10/shape_k)**(0.5), size=[B,1])
126 |         gamma_test_scales = torch.Tensor(gamma_test_scales).to(device)
127 |         gamma_test_scales = gamma_test_scales.sqrt()
128 |         # normalize to unit norm
129 |         X.div_(X.norm(p=2,dim=2)[:,:,None])
130 |         X_test.div_(X_test.norm(p=2,dim=2)[:,:,None])
131 |         # scale by gamma
132 |         X.mul_(gamma_scales[:,:,None])
133 |         X_test.mul_(gamma_test_scales[:,:,None])
134 |     elif mode =='normal':
135 |         pass
136 |     elif mode == 'relu':
137 |         return generate_data_relu(N=N, d=d, B=B, hidden_dim=d)
138 |     elif mode == 'mlp':
139 |         return generate_data_mlp(N=N, d=d, B=B, hidden_dim=d)
140 |     else:
141 |         raise ValueError('unsupported mode: {}'.format(mode))
142 |
143 |     if U is not None:
144 |         X = torch.einsum('ij, jk, BNk -> BNi', (U,D,X))
145 |         X_test = torch.einsum('ij, jk, BNk -> BNi', (U,D,X_test))
146 |
147 |     y = torch.einsum('bi,bni->bn', (W, X)).unsqueeze(2)
148 |     y_zero = torch.zeros(B,1,1).to(device)
149 |     y_test = torch.einsum('bi,bni->bn', (W, X_test)).squeeze(1)
150 |     X_comb= torch.cat([X,X_test],dim=1)
151 |     y_comb= torch.cat([y,y_zero],dim=1)
152 |     Z= torch.cat([X_comb,y_comb],dim=2)
153 |     return Z.to(device),y_test.to(device)
154 |
155 | def generate_data_inplace(Z, U=None, D=None):
156 |
157 |
158 |     B = Z.shape[0]
159 |     N = Z.shape[1]-1
160 |     d = Z.shape[2]-1
161 |     X = Z[:,:,0:-1]
162 |     X.normal_(0, 1)
163 |     W= torch.FloatTensor(B, d).normal_(0,1).to(device)
164 |     if U is not None:
165 |         U = U.to(device)
166 |         D = D.to(device)
167 |         W = torch.mm(W,torch.inverse(D))
168 |         W = torch.mm(W,U.t())
169 |         Z[:,:,0:-1] = torch.einsum('ij, jk, BNk -> BNi', (U,D,X))
170 |
171 |     Z[:,:,-1] = torch.einsum('bi,bni->bn', (W, Z[:,:,0:-1])) #y update
172 |     y_test = Z[:,-1,-1].detach().clone()
173 |     Z[:,-1,-1].zero_()
174 |     return Z.to(device),y_test.to(device)
175 |
176 | def generate_data_sine(N=10, B=1000):
177 |     # Sample amplitude a and phase p for each task
178 |     a = torch.FloatTensor(B).uniform_(0.1, 5).to(device)
179 |     p = torch.FloatTensor(B).uniform_(0, np.pi).to(device)  # np.pi: `math` is not imported in this module
180 |
181 |     X = torch.FloatTensor(B, N).uniform_(-5, 5).to(device)
182 |
183 |     Y = a.unsqueeze(1) * torch.sin(p.unsqueeze(1) + X)
184 |
185 |     X = X.unsqueeze(-1)
186 |     Y = Y.unsqueeze(-1)
187 |
188 |     return X, Y
189 |
190 | def generate_data_relu(mode='normal', N=20, d=1, B=1000, shape_k=0.1, U=None, D=None, hidden_dim=100):
191 |     # Generate random input data
192 |     X = torch.FloatTensor(B, N, d).normal_(0, 1).to(device)
193 |     X_test = torch.FloatTensor(B, 1, d).normal_(0, 1).to(device)
194 |
195 |     # Additional transformations if mode is 'sphere' or 'gamma' [Similar to the existing generate_data function]
196 |
197 |     # Define a 1-hidden layer ReLU network
198 |     model = nn.Sequential(
199 |         nn.Linear(d, hidden_dim),
200 |         nn.ReLU(),
201 |         nn.Linear(hidden_dim, 1)
202 |     ).to(device)
203 |     model[0].weight.data.normal_(0, 0.1)
204 |     model[2].weight.data.normal_(0, 0.1)
205 |
206 |     # Generate y values using the ReLU network
207 |     y = model(X.view(-1, d)).view(B, N, 1)
208 |     y_test = model(X_test.view(-1, d)).view(B, 1).squeeze(1)
209 |
210 |     y_zero = torch.zeros(B, 1, 1).to(device)
211 |     X_comb = torch.cat([X, X_test], dim=1)
212 |     y_comb = torch.cat([y, y_zero], dim=1)
213 |     Z = torch.cat([X_comb, y_comb], dim=2)
214 |
215 |     return Z, y_test
216 |
217 | def generate_data_mlp(N=20, d=1, B=1000, hidden_dim=100):
218 |     # Generate random input data
219 |     X = torch.FloatTensor(B, N, d).normal_(0, 1).to(device)
220 |     X_test = torch.FloatTensor(B, 1, d).normal_(0, 1).to(device)
221 |
222 |     # Additional transformations if mode is 'sphere' or 'gamma' [Similar to the existing generate_data function]
223 |
224 |     # Define a 1-hidden layer ReLU network
225 |     model = nn.Sequential(
226 |         nn.Linear(d, hidden_dim),
227 |         nn.ReLU(),
228 |         nn.Linear(hidden_dim, d)
229 |     ).to(device)
230 |     model[0].weight.data.normal_(0, 1)
231 |     model[2].weight.data.normal_(0, 1)
232 |
233 |     X_MLP = model(X.view(-1, d)).view(B, N, d)
234 |     X_test_MLP = model(X_test.view(-1, d)).view(B, 1, d)
235 |
236 |     W = torch.FloatTensor(B, d).normal_(0,1).to(device)
237 |     y = torch.einsum('bi,bni->bn', (W, X_MLP)).unsqueeze(2)
238 |     y_zero = torch.zeros(B,1,1).to(device)
239 |     y_test = torch.einsum('bi,bni->bn', (W, X_test_MLP)).squeeze(1)
240 |     X_comb= torch.cat([X_MLP,X_test_MLP],dim=1)
241 |     y_comb= torch.cat([y,y_zero],dim=1)
242 |     Z= torch.cat([X_comb,y_comb],dim=2)
243 |
244 |     return Z, y_test
245 |
--------------------------------------------------------------------------------
/rotation demonstration-Adam-p0.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 11,
6 | "id": "3fcfaf4d",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import torch\n",
11 | "from matplotlib import pyplot as plt\n",
12 | "import sys\n",
13 | "import time\n",
14 | "import os\n",
15 | "import numpy as np\n",
16 | "import math\n",
17 | "\n",
18 | "#####################################################\n",
19 | "# Same as rotation demonstration-Adam.ipynb, except\n",
20 | "# we additionally enforce that P=0 for each layer\n",
21 | "#####################################################\n",
22 | "\n",
23 | "#use cuda if available, else use cpu\n",
24 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
25 | "#torch.cuda.set_device(2)\n",
26 | "# import the model and some useful functions\n",
27 | "from linear_transformer import Transformer_F, attention, generate_data, in_context_loss, generate_data_inplace\n",
28 | "\n",
29 | "# set up some print options\n",
30 | "np.set_printoptions(precision = 2, suppress = True)\n",
31 | "torch.set_printoptions(precision=2)\n",
32 | "\n",
33 | "#begin logging\n",
34 | "cur_dir = 'log' \n",
35 | "os.makedirs(cur_dir, exist_ok=True)\n",
36 | "#f = open(cur_dir + '/rotation.log', \"a\", 1)\n",
37 | "#sys.stdout = f"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 18,
43 | "id": "9700bf1b",
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "# Set up problem parameters\n",
48 | "\n",
49 | "lr = 0.02\n",
50 | "clip_r = 0.01\n",
51 | "alg = 'adam'\n",
52 | "mode = 'normal'\n",
53 | "\n",
54 | "n_layer = 3 # number of layers of transformer\n",
55 | "N = 20 # context length\n",
56 | "d = 5 # dimension\n",
57 | "\n",
58 | "\n",
59 | "n_head = 1 # 1-headed attention\n",
60 | "B = 20000 # minibatch size\n",
61 | "var = 0.0001 # initialization scale of the transformer parameters\n",
62 | "shape_k = 0.1 # shape_k: parameter for Gamma distributed covariates\n",
63 | "max_iters = 30000 # Number of Iterations to run\n",
64 | "hist_stride = 1 # stride for saved model parameters in `train.ipynb`\n",
65 | "stride = 100\n",
66 | "\n",
67 | "# a convenience function for taking a step and clipping\n",
68 | "def clip_and_step(allparam, optimizer, clip_r = None):\n",
69 | " norm_p=None\n",
70 | " grad_all = allparam.grad\n",
71 | " if clip_r is not None:\n",
72 | " for l in range(grad_all.shape[0]):\n",
73 | " for h in range(grad_all.shape[1]):\n",
74 | " for t in range(grad_all.shape[2]):\n",
75 | " norm_p = grad_all[l,h,t,:,:].norm().item()\n",
76 | " if norm_p > clip_r:\n",
77 | " grad_all[l,h,t,:,:].mul_(clip_r/norm_p)\n",
78 | " optimizer.step()\n",
79 | " return norm_p"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "id": "e69d0ea4",
86 | "metadata": {
87 | "scrolled": false
88 | },
89 | "outputs": [],
90 | "source": [
91 | "filename_format = '/rotation_hist_adam_pnull_{}_{}_{}.pth'\n",
92 | "filename = filename_format.format(n_layer, N, d)\n",
93 | "filename = (cur_dir + filename)\n",
94 | "hist_dict = {}\n",
95 | "U_dict = {}\n",
96 | "D_dict = {}\n",
97 | "\n",
98 | "seeds = [0,1,2,3,4]\n",
99 | "keys = [(s,) for s in seeds]\n",
100 | "for key in keys:\n",
101 | " sd = key[0]\n",
102 | " \n",
103 | " prob_seed = sd\n",
104 | " opt_seed = sd\n",
105 | " \n",
106 | " hist_dict[key] = []\n",
107 | " \n",
108 | " #set seed and initialize model\n",
109 | " torch.manual_seed(opt_seed)\n",
110 | " \n",
111 | " model = Transformer_F(n_layer, 1, d, var)\n",
112 | " model.to(device)\n",
113 | " #initialize algorithm. Important: set beta2 = 0.9 for Adam; the default 0.999 is very slow\n",
114 | " optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.99, 0.9), weight_decay=0)\n",
115 | " \n",
116 | " # set seed\n",
117 | " # sample random rotation matrix\n",
118 | " # initialize initial training batch\n",
119 | " np.random.seed(prob_seed)\n",
120 | " torch.manual_seed(prob_seed)\n",
121 | " gaus = torch.FloatTensor(5,5).uniform_(-1,1).to(device)\n",
122 | " U = torch.linalg.svd(gaus)[0].to(device)\n",
123 | " D = torch.diag(torch.FloatTensor([1,1,1/2,1/4,1])).to(device)\n",
124 | " U_dict[key]=U\n",
125 | " D_dict[key]=D\n",
126 | " Z, y = generate_data(mode,N,d,B,shape_k, U, D)\n",
127 | " Z = Z.to(device)\n",
128 | " y = y.to(device)\n",
129 | " for t in range(max_iters):\n",
130 | " if t%2000==0 and t>1:# and t < 200001:\n",
131 | " optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] *0.5\n",
132 | " if t%100==0:\n",
133 | " Z,y = generate_data_inplace(Z, U=U, D=D)\n",
134 | " start = time.time()\n",
135 | " # save model parameters\n",
136 | " if t%stride ==0:\n",
137 | " hist_dict[key].append(model.allparam.clone().detach())\n",
138 | " loss = in_context_loss(model, Z, y)\n",
139 | " # compute gradient, take step\n",
140 | " loss.backward()\n",
141 | " norms = clip_and_step(model.allparam, optimizer, clip_r=clip_r)\n",
142 | " optimizer.zero_grad()\n",
143 | " \n",
144 | " #IMPORTANT: zero out the p matrices after each update! This enforces the P=0 constraint.\n",
145 | " model.zero_p()\n",
146 | "\n",
147 | " end=time.time()\n",
148 | " if t%100 ==0 or t<5:\n",
149 | " print('iter {} | Loss: {} time: {} gradnorm: {}'.format(t,loss.item(), end-start, norms))\n",
150 | "# save the parameter history and the sampled U, D matrices to disk\n",
151 | "torch.save({'hist_dict':hist_dict, 'U_dict':U_dict, 'D_dict':D_dict}, filename)"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": 54,
157 | "id": "83154b0e",
158 | "metadata": {},
159 | "outputs": [],
160 | "source": [
161 | "####################################\n",
162 | "# compute test loss\n",
163 | "####################################\n",
164 | "#hist_dict = torch.load(filename)['hist_dict']\n",
165 | "loss_dict = {}\n",
166 | "for key in hist_dict:\n",
167 | " sd = key[0]\n",
168 | " \n",
169 | " U = U_dict[key]\n",
170 | " D = D_dict[key]\n",
171 | " \n",
172 | " loss_dict[key] = torch.zeros(max_iters//stride)\n",
173 | " \n",
174 | " np.random.seed(99)\n",
175 | " torch.manual_seed(99)\n",
176 | " Z, y = generate_data(mode,N,d,B,shape_k,U,D)\n",
177 | " Z = Z.to(device)\n",
178 | " y = y.to(device)\n",
179 | " model = Transformer_F(n_layer, n_head, d, var).to(device)\n",
180 | " for t in range(0,max_iters,stride):\n",
181 | " with torch.no_grad():\n",
182 | " model.allparam.copy_(hist_dict[key][t//stride])\n",
183 | " loss_dict[key][t//stride] = in_context_loss(model, Z, y).item()"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": null,
189 | "id": "b7c41abb",
190 | "metadata": {},
191 | "outputs": [],
192 | "source": [
193 | "####################################\n",
194 | "# plot the test loss with error bars\n",
195 | "####################################\n",
196 | "\n",
197 | "fig_dir = 'figures' \n",
198 | "os.makedirs(fig_dir, exist_ok=True)\n",
199 | "\n",
200 | "fig, ax = plt.subplots(1, 1,figsize = (9, 7))\n",
201 | "\n",
202 | "losses = torch.zeros(len(seeds), max_iters//stride)\n",
203 | "keys = loss_dict.keys()\n",
204 | "for idx, key in enumerate(keys):\n",
205 | " losses[idx,:] = loss_dict[key].log()\n",
206 | "losses_mean = torch.mean(losses, axis=0)\n",
207 | "losses_std = torch.std(losses, axis=0)\n",
208 | "ax.plot(range(0,max_iters,stride), losses_mean, color = 'blue', lw = 3)#, label='Adam')\n",
209 | "ax.fill_between(range(0,max_iters,stride), losses_mean-losses_std, losses_mean+losses_std, color = 'black', alpha = 0.2)\n",
210 | "ax.set_xlabel('Iteration',fontsize=30)\n",
211 | "ax.set_ylabel('log(Loss)',fontsize=30)\n",
212 | "ax.tick_params(axis='both', which='major', labelsize=20, width = 3, length = 10)\n",
213 | "ax.tick_params(axis='both', which='minor', labelsize=20, width = 3, length = 5)\n",
214 | "\n",
215 | "\n",
216 | "plt.tight_layout()\n",
217 | "plt.savefig(fig_dir + '/rotation_demonstration_adam_pnull_loss_plot.pdf', dpi=600)"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": null,
223 | "id": "88adc8d9",
224 | "metadata": {},
225 | "outputs": [],
226 | "source": [
227 | "####################################\n",
228 | "# display the parameter matrices\n",
229 | "# image/font setting assumes d=5\n",
230 | "####################################\n",
231 | "\n",
232 | "key = (0,)\n",
233 | "\n",
234 | "U = U_dict[(0,)]\n",
235 | "D = D_dict[(0,)]\n",
236 | "UD = torch.mm(U,D) \n",
237 | "for l in range(n_layer):\n",
238 | " for h in range(n_head):\n",
239 | " fig, ax = plt.subplots(1, 1,figsize = (6, 6))\n",
240 | " matrix = hist_dict[key][-1][l,h,1,:,:]\n",
241 | " #rotate matrix by inverse of UD\n",
242 | " matrix = torch.mm(torch.mm(UD.t(), matrix), UD)\n",
243 | " # Create a heatmap using imshow\n",
244 | " im = ax.imshow(matrix.cpu(), cmap='gray_r')\n",
245 | " # Add the matrix values as text\n",
246 | " for i in range(matrix.shape[0]):\n",
247 | " for j in range(matrix.shape[1]):\n",
248 | " ax.text(j, i, format(matrix[i, j], '.2f'), ha='center', va='center', color='r')\n",
249 | " # Add a colorbar for reference\n",
250 | " fig.colorbar(im)\n",
251 | " #ax.set_title('$A_{}$'.format(l),fontsize=20)\n",
252 | " plt.savefig(fig_dir + '/rotation_demonstration_pnull_A{}.pdf'.format(l), dpi=600)\n",
253 | " "
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": 56,
259 | "id": "1a10c824",
260 | "metadata": {},
261 | "outputs": [],
262 | "source": [
263 | "########################################################\n",
264 | "# plot the distance-to-identity of each matrix with time\n",
265 | "########################################################\n",
266 | "\n",
267 | "# function for computing distance to identity\n",
268 | "def compute_dist_identity(M):\n",
269 | " scale = torch.sum(torch.diagonal(M))/M.shape[0]\n",
270 | " ideal_identity = scale* torch.eye(M.shape[0]).to(device)\n",
271 | " difference = M - ideal_identity\n",
272 | " err = (torch.norm(difference,p='fro')/torch.norm(M,p='fro'))\n",
273 | " return err\n",
274 | "\n",
275 | "########################################\n",
276 | "# compute distances (assume n_head = 1)\n",
277 | "########################################\n",
278 | "dist_dict = {}\n",
279 | "\n",
280 | "id_dist_dict={}\n",
281 | " \n",
282 | "for key in hist_dict:\n",
283 | " (sd,) = key\n",
284 | " dist_dict[key] = torch.zeros(n_layer, 2, max_iters//stride)\n",
285 | " id_dist_dict[key] = torch.zeros(n_layer, 2, max_iters//stride)\n",
286 | " U = U_dict[key]\n",
287 | " D = D_dict[key]\n",
288 | " UD = torch.mm(U,D) \n",
289 | " for t in range(0,max_iters,stride):\n",
290 | " with torch.no_grad():\n",
291 | " allparam = hist_dict[key][t//stride]\n",
292 | " for i in range(n_layer):\n",
293 | " for j in range(2):\n",
294 | " matrix = allparam[i,0,j,:,:]\n",
295 | " if j ==1:\n",
296 | " id_dist_dict[key][i,j,t//stride] = compute_dist_identity(matrix).item()\n",
297 | " matrix = torch.mm(torch.mm(UD.t(), matrix), UD)\n",
298 | " dist_dict[key][i,j,t//stride] = compute_dist_identity(matrix).item()\n",
299 | "####################################\n",
300 | "# plot distances\n",
301 | "####################################\n",
302 | "\n",
303 | "fig_dir = 'figures' \n",
304 | "os.makedirs(fig_dir, exist_ok=True)\n",
305 | "\n",
306 | "labels = ['$B_0$', '$B_1$', None, \n",
307 | " '$\\Sigma^{1/2} A_0 \\Sigma^{1/2}$', \n",
308 | " '$\\Sigma^{1/2} A_1 \\Sigma^{1/2}$', \n",
309 | " '$\\Sigma^{1/2} A_2 \\Sigma^{1/2}$']\n",
310 | "names = ['B0', 'B1', None, \n",
311 | " 'A0', \n",
312 | " 'A1', \n",
313 | " 'A2']\n",
314 | "colors = ['blue','blue',None, 'blue','blue','blue']\n",
315 | "\n",
316 | "for l in range(n_layer):\n",
317 | " for pq in range(2):\n",
318 | " if l==n_layer-1 and pq==0:\n",
319 | " continue\n",
320 | " if pq ==0:\n",
321 | " continue\n",
322 | " fig, ax = plt.subplots(1, 1,figsize = (9, 7))\n",
323 | " if pq==1:\n",
324 | " id_dist_p = torch.zeros(len(seeds), max_iters//stride)\n",
325 | " for idx, sd in enumerate(seeds):\n",
326 | " losses[idx,:] = id_dist_dict[(sd,)][l,pq,:]\n",
327 | " dist_mean = torch.mean(losses, axis=0)\n",
328 | " dist_std = torch.std(losses, axis=0)\n",
329 | " ax.plot(range(0,max_iters,stride), dist_mean, color = 'red', lw = 3, label='$A_{}$'.format(l))\n",
330 | " ax.fill_between(range(0,max_iters,stride), dist_mean-dist_std, dist_mean+dist_std, color = 'red', alpha = 0.2)\n",
331 | " \n",
332 | " dist_p = torch.zeros(len(seeds), max_iters//stride)\n",
333 | " for idx, sd in enumerate(seeds):\n",
334 | " losses[idx,:] = dist_dict[(sd,)][l,pq,:]\n",
335 | " dist_mean = torch.mean(losses, axis=0)\n",
336 | " dist_std = torch.std(losses, axis=0)\n",
337 | " \n",
338 | " style_id = l + 3*pq\n",
339 | " \n",
340 | " ax.plot(range(0,max_iters,stride), dist_mean, color = colors[style_id], lw = 3, label=labels[style_id])\n",
341 | " ax.fill_between(range(0,max_iters,stride), dist_mean-dist_std, dist_mean+dist_std, color = colors[style_id], alpha = 0.2)\n",
342 | " ax.tick_params(axis='both', which='major', labelsize=20, width = 3, length = 10)\n",
343 | " ax.tick_params(axis='both', which='minor', labelsize=20, width = 3, length = 5)\n",
344 | " \n",
345 | " ax.set_ylim([0,1])\n",
346 | " ax.set_xlabel('Iteration',fontsize=30)\n",
347 | " ax.set_ylabel('Distance to Id',fontsize=30)\n",
348 | " ax.legend(fontsize=30)\n",
349 | " \n",
350 | " plt.savefig(fig_dir + '/rotation_demonstration_dist_to_id_adam_pnull_{}.pdf'.format(names[style_id]), dpi=600)"
351 | ]
352 | }
353 | ],
354 | "metadata": {
355 | "kernelspec": {
356 | "display_name": "Python 3 (ipykernel)",
357 | "language": "python",
358 | "name": "python3"
359 | },
360 | "language_info": {
361 | "codemirror_mode": {
362 | "name": "ipython",
363 | "version": 3
364 | },
365 | "file_extension": ".py",
366 | "mimetype": "text/x-python",
367 | "name": "python",
368 | "nbconvert_exporter": "python",
369 | "pygments_lexer": "ipython3",
370 | "version": "3.9.12"
371 | }
372 | },
373 | "nbformat": 4,
374 | "nbformat_minor": 5
375 | }
376 |
--------------------------------------------------------------------------------
/variable_L_exp.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 94,
6 | "id": "3fcfaf4d",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import torch\n",
11 | "from matplotlib import pyplot as plt\n",
12 | "import sys\n",
13 | "import time\n",
14 | "import os\n",
15 | "import numpy as np\n",
16 | "import math\n",
17 | "\n",
18 | "##############################################################################################################\n",
19 | "# Trains a linear Transformer with 1,2,3,4 layers\n",
20 | "# Plots the test loss of trained Transformer against 1,2,3,4 steps of gradient descent (with and without preconditioning)\n",
21 | "##############################################################################################################\n",
22 | "\n",
23 | "#use cuda if available, else use cpu\n",
24 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
25 | "#torch.cuda.set_device(1)\n",
26 | "# import the model and some useful functions\n",
27 | "from linear_transformer import Transformer_F, attention, generate_data, in_context_loss, generate_data_inplace\n",
28 | "\n",
29 | "# set up some print options\n",
30 | "np.set_printoptions(precision = 2, suppress = True)\n",
31 | "torch.set_printoptions(precision=2)\n",
32 | "\n",
33 | "#begin logging\n",
34 | "cur_dir = 'log' \n",
35 | "os.makedirs(cur_dir, exist_ok=True)\n",
36 | "#f = open(cur_dir + '/rotation.log', \"a\", 1)\n",
37 | "#sys.stdout = f"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 95,
43 | "id": "9700bf1b",
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "# Set up problem parameters\n",
48 | "\n",
49 | "lr = 0.01\n",
50 | "clip_r = 0.01\n",
51 | "alg = 'adam'\n",
52 | "mode = 'normal'\n",
53 | "\n",
54 | "n_layer = 4 # number of layers of transformer\n",
55 | "N = 20 # context length\n",
56 | "d = 5 # dimension\n",
57 | "\n",
58 | "\n",
59 | "n_head = 1 # 1-headed attention\n",
60 |     "B = 20000  # minibatch size\n",
61 | "var = 0.0001 # initializations scale of transformer parameter\n",
62 | "shape_k = 0.1 # shape_k: parameter for Gamma distributed covariates\n",
63 | "max_iters = 20000 # Number of Iterations to run\n",
64 |     "hist_stride = 1  # stride for saved model parameters in 'train.ipynb'\n",
65 |     "stride = 100  # save model parameters every `stride` iterations\n",
66 | "\n",
67 | "# a convenience function for taking a step and clipping\n",
68 | "def clip_and_step(allparam, optimizer, clip_r = None):\n",
69 | " norm_p=None\n",
70 | " grad_all = allparam.grad\n",
71 | " if clip_r is not None:\n",
72 | " for l in range(grad_all.shape[0]):\n",
73 | " for h in range(grad_all.shape[1]):\n",
74 | " for t in range(grad_all.shape[2]):\n",
75 | " norm_p = grad_all[l,h,t,:,:].norm().item()\n",
76 | " if norm_p > clip_r:\n",
77 | " grad_all[l,h,t,:,:].mul_(clip_r/norm_p)\n",
78 | " optimizer.step()\n",
79 | " return norm_p\n",
80 | "\n",
81 | "#format for saving run data\n",
82 | "filename_format = '/variable_L_hist_{}_{}_{}.pth'\n",
83 | "n_layers = [1,2,3,4] # number of layers of transformer\n",
84 | "seeds=[0,1,2,3,4]\n",
85 | "keys = []\n",
86 | "for s in seeds:\n",
87 | " for n_layer in n_layers:\n",
88 | " keys.append((s,n_layer,))"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "id": "e69d0ea4",
95 | "metadata": {
96 | "scrolled": false
97 | },
98 | "outputs": [],
99 | "source": [
100 | "for key in keys:\n",
101 | " sd = key[0]\n",
102 | " n_layer = key[1]\n",
103 | " filename = cur_dir + filename_format.format(n_layer, N, sd)\n",
104 | " print(key)\n",
105 | " \n",
106 | " prob_seed = sd\n",
107 | " opt_seed = sd\n",
108 | " \n",
109 | " hist = []\n",
110 | " \n",
111 | " #set seed and initialize model\n",
112 | " torch.manual_seed(opt_seed)\n",
113 | " model = Transformer_F(n_layer, 1, d, var)\n",
114 | " model.to(device)\n",
115 |     "    #initialize optimizer. Important: use beta2 = 0.9 for Adam (second entry of betas); the default 0.999 is very slow\n",
116 | " optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.99, 0.9), weight_decay=0)\n",
117 | " \n",
118 | " # set seed\n",
119 | " # sample random rotation matrix\n",
120 | " # initialize initial training batch\n",
121 | " np.random.seed(prob_seed)\n",
122 | " torch.manual_seed(prob_seed)\n",
123 |     "    gaus = torch.FloatTensor(5,5).uniform_(-1,1).to(device)\n",
124 |     "    U = torch.linalg.svd(gaus)[0].to(device)\n",
125 |     "    D = torch.diag(torch.FloatTensor([1,1,1/2,1/4,1])).to(device)\n",
126 | " Z, y = generate_data(mode,N,d,B,shape_k, U, D)\n",
127 | " Z = Z.to(device)\n",
128 | " y = y.to(device)\n",
129 | " for t in range(max_iters):\n",
130 | " if t%4000==0 and t>1:\n",
131 | " optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] *0.5\n",
132 | " if t%100==0:\n",
133 | " Z,y = generate_data_inplace(Z, U=U, D=D)\n",
134 | " start = time.time()\n",
135 | " # save model parameters\n",
136 | " if t%stride ==0:\n",
137 | " hist.append(model.allparam.clone().detach())\n",
138 | " loss = in_context_loss(model, Z, y)\n",
139 | " # compute gradient, take step\n",
140 | " loss.backward()\n",
141 | " norms = clip_and_step(model.allparam, optimizer, clip_r=clip_r)\n",
142 | " optimizer.zero_grad()\n",
143 | " end=time.time()\n",
144 | " if t%100 ==0 or t<5:\n",
145 | " print('iter {} | Loss: {} time: {} gradnorm: {}'.format(t,loss.item(), end-start, norms))\n",
146 | " torch.save({'hist':hist, 'U':U, 'D':D}, filename)"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": 96,
152 | "id": "83154b0e",
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "########################################################\n",
157 | "# compute test loss for trained linear Transformers\n",
158 | "########################################################\n",
159 | "loss_dict = {}\n",
160 | "for sd in seeds:\n",
161 | " key = (sd,)\n",
162 | " loss_dict[key] = torch.zeros(4)\n",
163 | " for n_layer in n_layers:\n",
164 | " # load parameters for given n_layer and seed\n",
165 | " filename = cur_dir + filename_format.format(n_layer, N, sd)\n",
166 | " hist = torch.load(filename)['hist']\n",
167 | " U = torch.load(filename)['U']\n",
168 | " D = torch.load(filename)['D']\n",
169 | " \n",
170 |     "        # with relatively few training steps, some saved solutions may be unstable;\n",
171 |     "        # on a validation set (seed=999), pick the checkpoint with the best validation\n",
172 |     "        # loss among the last 20 saved checkpoints\n",
173 | " np.random.seed(999)\n",
174 | " torch.manual_seed(999)\n",
175 | " Z, y = generate_data(mode,N,d,B,shape_k,U,D)\n",
176 | " Z = Z.to(device)\n",
177 | " y = y.to(device)\n",
178 | " model = Transformer_F(n_layer, n_head, d, var).to(device)\n",
179 | " loss = 100\n",
180 | " bestmodel = None\n",
181 | " for t in range(len(hist)-20, len(hist)):\n",
182 | " with torch.no_grad():\n",
183 | " model.allparam.copy_(hist[t])\n",
184 | " newloss = in_context_loss(model, Z, y).item()\n",
185 | " if (newloss < loss):\n",
186 | " loss=newloss\n",
187 | " bestmodel = hist[t]\n",
188 | " with torch.no_grad():\n",
189 | " model.allparam.copy_(bestmodel)\n",
190 | " np.random.seed(99)\n",
191 | " torch.manual_seed(99)\n",
192 | " Z, y = generate_data(mode,N,d,B,shape_k,U,D)\n",
193 | " Z = Z.to(device)\n",
194 | " y = y.to(device) \n",
195 | " #compute loss\n",
196 | " loss_dict[key][n_layer-1] = in_context_loss(model, Z, y).log().item()"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": null,
202 | "id": "e9ca8fff",
203 | "metadata": {},
204 | "outputs": [],
205 | "source": [
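206 |     "# run `numstep` steps of gradient descent on the in-context least-squares loss, starting from w = 0\n",
207 |     "def do_gd(Z,eta,numstep):\n",
208 |     "    N = Z.shape[0]-1  # row N of Z is the test point (see eval_w_instance)\n",
209 |     "    X = Z[0:N-1,0:5]  # covariates of context rows 0..N-2\n",
210 |     "    Y = Z[0:N-1,5]  # labels of context rows 0..N-2\n",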
211 | " w = torch.zeros(X.shape[1]).to(device)\n",
212 | " for k in range(numstep):\n",
213 | " XTXw = torch.einsum('ik,ij,j->k',X,X,w)\n",
214 | " XTY = torch.einsum('ik,i->k',X,Y)\n",
215 | " grad = XTXw - XTY\n",
216 | " w = w - eta * grad\n",
217 | " return w\n",
218 | "\n",
219 | "def eval_w_instance(Z, Ytest, w):\n",
220 | " N = Z.shape[0]-1\n",
221 | " Xtest = Z[N,0:5]\n",
222 | " prediction = torch.einsum('i,i->',w,Xtest)\n",
223 | " return (Ytest - prediction)**2, prediction\n",
224 | "\n",
225 | "\n",
226 | "gd_loss_matrix = torch.zeros(len(seeds),4)\n",
227 | "\n",
228 | "\n",
229 | "for n_layer in n_layers:\n",
230 | " #first find best eta\n",
231 | " #load seed 1 for U,D matrices\n",
232 | " sd = 1\n",
233 | " best_loss = 10000\n",
234 | " best_eta = 0\n",
235 | " numstep = n_layer\n",
236 | " # load UD matrices\n",
237 | " filename = cur_dir + filename_format.format(n_layer, N, sd)\n",
238 | " U = torch.load(filename)['U']\n",
239 | " D = torch.load(filename)['D']\n",
240 | " #generate test data using seed 999\n",
241 | " np.random.seed(999)\n",
242 | " torch.manual_seed(999)\n",
243 | " Z, y = generate_data(mode,N,d,B,shape_k,U,D)\n",
244 | " Z = Z.to(device)\n",
245 | " y = y.to(device)\n",
246 | " #done generating data \n",
247 | " \n",
248 | " for eta in [0.008, 0.01, 0.02, 0.04, 0.08, 0.16]:\n",
249 | " ### start of evaluate mean loss ###\n",
250 | " total_loss = 0\n",
251 | " for i in range(5000):\n",
252 | " Zi = Z[i,:,:]\n",
253 | " Ytesti = y[i]\n",
254 | " w = do_gd(Zi,eta,numstep)\n",
255 | " gd_loss, gd_pred = eval_w_instance(Zi, Ytesti, w)\n",
256 | " total_loss = total_loss + gd_loss\n",
257 | " mean_loss = total_loss / 5000\n",
258 | " ### end of evaluate mean loss ###\n",
259 | " print('eta: {}, loss: {}'.format(eta, mean_loss))\n",
260 | " if (mean_loss < best_loss):\n",
261 | " best_eta = eta\n",
262 | " best_loss = mean_loss\n",
263 | " print('best eta: {} for n_layer={}'.format(best_eta, n_layer))\n",
264 | " \n",
265 | " #now do actual evaluation\n",
266 | " for sd in seeds:\n",
267 | " opt_seed = sd\n",
268 | " \n",
269 | " filename = cur_dir + filename_format.format(n_layer, N, sd)\n",
270 | " U = torch.load(filename)['U']\n",
271 | " D = torch.load(filename)['D']\n",
272 | " #generate test data\n",
273 | " torch.manual_seed(sd)\n",
274 | " Z, y = generate_data(mode,N,d,B,shape_k,U,D)\n",
275 | " Z = Z.to(device)\n",
276 | " y = y.to(device)\n",
277 | " #done generating data \n",
278 | " eta = best_eta\n",
279 | " ### start of evaluate mean loss ###\n",
280 | " total_loss = 0\n",
281 | " for i in range(Z.shape[0]):\n",
282 | " Zi = Z[i,:,:]\n",
283 | " Ytesti = y[i]\n",
284 | " w = do_gd(Zi,eta,numstep)\n",
285 | " gd_loss, gd_pred = eval_w_instance(Zi, Ytesti, w)\n",
286 | " total_loss = total_loss + gd_loss\n",
287 | " mean_loss = total_loss / Z.shape[0]\n",
288 | " gd_loss_matrix[sd,n_layer-1] = mean_loss\n",
289 | " \n",
290 | "#compute mean and std of log test loss for plotting\n",
291 | "gd_loss_mean = gd_loss_matrix.log().mean(dim=0)\n",
292 | "gd_loss_std = gd_loss_matrix.log().var(dim=0)**0.5"
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": null,
298 | "id": "98d80e86",
299 | "metadata": {},
300 | "outputs": [],
301 | "source": [
302 | "def do_preconditioned_gd(Z,eta,numstep,U,D):\n",
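303 |     "    N = Z.shape[0]-1  # row N of Z is the test point (see eval_w_instance_precon)\n",
304 |     "    X = Z[0:N-1,0:5]  # covariates of context rows 0..N-2\n",
305 |     "    Y = Z[0:N-1,5]  # labels of context rows 0..N-2\n",
306 |     "    w = torch.zeros(X.shape[1]).to(device)\n",
307 |     "    X = torch.einsum('ij, jk, Nk -> Ni', (torch.inverse(D),U.t(),X))  # precondition: x -> D^{-1} U^T x\n",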
308 | " for k in range(numstep):\n",
309 | " XTXw = torch.einsum('ik,ij,j->k',X,X,w)\n",
310 | " XTY = torch.einsum('ik,i->k',X,Y)\n",
311 | " grad = XTXw - XTY\n",
312 | " w = w - eta * grad\n",
313 | " return w\n",
314 | "\n",
315 | "def eval_w_instance_precon(Z, Ytest, w, U, D):\n",
316 | " N = Z.shape[0]-1\n",
317 | " Xtest = Z[N,0:5]\n",
318 | " Xtest = torch.einsum('ij, jk, k -> i', (torch.inverse(D),U.t(),Xtest))\n",
319 | " prediction = torch.einsum('i,i->',w,Xtest)\n",
320 | " return (Ytest - prediction)**2, prediction\n",
321 | "\n",
322 | "\n",
323 | "\n",
324 | "pgd_loss_matrix = torch.zeros(len(seeds),4)\n",
325 | "\n",
326 | "for n_layer in n_layers:\n",
327 | " #first find best eta\n",
328 | " #load seed 1 for U,D matrices\n",
329 | " sd = 1\n",
330 | " best_loss = 10000\n",
331 | " best_eta = 0\n",
332 | " numstep = n_layer\n",
333 | " # load UD matrices\n",
334 | " filename = cur_dir + filename_format.format(n_layer, N, sd)\n",
335 | " U = torch.load(filename)['U'].to(device)\n",
336 | " D = torch.load(filename)['D'].to(device)\n",
337 | " #generate test data using seed 999\n",
338 | " np.random.seed(999)\n",
339 | " torch.manual_seed(999)\n",
340 | " Z, y = generate_data(mode,N,d,B,shape_k,U,D)\n",
341 | " Z = Z.to(device)\n",
342 | " y = y.to(device)\n",
343 | " #done generating data \n",
344 | " \n",
345 | " for eta in [0.001, 0.002, 0.004, 0.008, 0.01, 0.02, 0.04, 0.08, 0.16]:\n",
346 | " ### start of evaluate mean loss ###\n",
347 | " total_loss = 0\n",
348 | " for i in range(5000):\n",
349 | " Zi = Z[i,:,:]\n",
350 | " Ytesti = y[i]\n",
351 | " w = do_preconditioned_gd(Zi,eta,numstep,U,D)\n",
352 | " pgd_loss, pgd_pred = eval_w_instance_precon(Zi, Ytesti, w, U, D)\n",
353 | " total_loss = total_loss + pgd_loss\n",
354 | " mean_loss = total_loss / 5000\n",
355 | " ### end of evaluate mean loss ###\n",
356 | " print('eta: {}, loss: {}'.format(eta, mean_loss))\n",
357 | " if (mean_loss < best_loss):\n",
358 | " best_eta = eta\n",
359 | " best_loss = mean_loss\n",
360 | " print('best eta: {} for n_layer={}'.format(best_eta, n_layer))\n",
361 | " \n",
362 | " #now do actual evaluation\n",
363 | " for sd in seeds:\n",
364 | " opt_seed = sd\n",
365 | " \n",
366 | " filename = cur_dir + filename_format.format(n_layer, N, sd)\n",
367 | " U = torch.load(filename)['U'].to(device)\n",
368 | " D = torch.load(filename)['D'].to(device)\n",
369 | " #generate test data\n",
370 | " torch.manual_seed(sd)\n",
371 | " Z, y = generate_data(mode,N,d,B,shape_k,U,D)\n",
372 | " Z = Z.to(device)\n",
373 | " y = y.to(device)\n",
374 | " #done generating data \n",
375 | " eta = best_eta\n",
376 | " ### start of evaluate mean loss ###\n",
377 | " total_loss = 0\n",
378 | " for i in range(5000):\n",
379 | " Zi = Z[i,:,:]\n",
380 | " Ytesti = y[i]\n",
381 | " w = do_preconditioned_gd(Zi,eta,numstep,U,D)\n",
382 | " pgd_loss, pgd_pred = eval_w_instance_precon(Zi, Ytesti, w, U, D)\n",
383 | " total_loss = total_loss + pgd_loss\n",
384 | " mean_loss = total_loss / 5000\n",
385 | " pgd_loss_matrix[sd,n_layer-1] = mean_loss\n",
386 | "\n",
387 | "#compute mean and std of log test loss for plotting\n",
388 | "pgd_loss_mean = pgd_loss_matrix.log().mean(dim=0)\n",
389 | "pgd_loss_std = pgd_loss_matrix.log().var(dim=0)**0.5\n",
390 | " "
391 | ]
392 | },
393 | {
394 | "cell_type": "code",
395 | "execution_count": null,
396 | "id": "5af8555c",
397 | "metadata": {},
398 | "outputs": [],
399 | "source": [
400 | "####################################\n",
401 |     "# plot final log test loss against number of layers/steps\n",
402 | "####################################\n",
403 | "\n",
404 | "fig_dir = 'figures' \n",
405 | "os.makedirs(fig_dir, exist_ok=True)\n",
406 | "\n",
407 | "fig, ax = plt.subplots(1, 1,figsize = (9, 9))\n",
408 | "\n",
409 | "losses = torch.zeros(len(seeds), len(n_layers))\n",
410 | "keys = loss_dict.keys()\n",
411 | "for idx, key in enumerate(keys):\n",
412 | " losses[idx,:] = loss_dict[key]\n",
413 | "losses_mean = torch.mean(losses, axis=0)\n",
414 | "losses_std = torch.std(losses, axis=0)\n",
415 | "\n",
416 | "plt.plot(n_layers, gd_loss_mean, color='blue', label='GD')\n",
417 | "plt.fill_between(n_layers, gd_loss_mean - gd_loss_std, gd_loss_mean + gd_loss_std, color='blue', alpha=0.2)\n",
418 | "plt.plot(n_layers, pgd_loss_mean, color='green', label='Preconditioned GD')\n",
419 | "plt.fill_between(n_layers, pgd_loss_mean - pgd_loss_std, pgd_loss_mean + pgd_loss_std, color='green', alpha=0.2)\n",
420 | "ax.plot(n_layers, losses_mean, color = 'red', lw = 3, label='Linear Transformer')\n",
421 | "ax.fill_between(n_layers, losses_mean-losses_std, losses_mean+losses_std, color = 'red', alpha = 0.2)\n",
422 | "\n",
423 | "plt.ylabel('log(Loss)',fontsize=30)\n",
424 | "plt.xlabel('Number of Layers/Steps',fontsize=30)\n",
425 | "ax.tick_params(axis='both', which='major', labelsize=30, width = 3, length = 10)\n",
426 | "ax.tick_params(axis='both', which='minor', labelsize=20, width = 3, length = 5)\n",
427 | "ax.legend(fontsize=24)\n",
428 | "#ax.set_yscale('log')\n",
429 | "\n",
430 | "\n",
431 | "plt.tight_layout()\n",
432 | "plt.savefig(fig_dir + '/variable-L-plot.pdf', dpi=600)"
433 | ]
434 | },
435 | {
436 | "cell_type": "code",
437 | "execution_count": null,
438 | "id": "78b77440",
439 | "metadata": {},
440 | "outputs": [],
441 | "source": []
442 | }
443 | ],
444 | "metadata": {
445 | "kernelspec": {
446 | "display_name": "Python 3 (ipykernel)",
447 | "language": "python",
448 | "name": "python3"
449 | },
450 | "language_info": {
451 | "codemirror_mode": {
452 | "name": "ipython",
453 | "version": 3
454 | },
455 | "file_extension": ".py",
456 | "mimetype": "text/x-python",
457 | "name": "python",
458 | "nbconvert_exporter": "python",
459 | "pygments_lexer": "ipython3",
460 | "version": "3.9.12"
461 | }
462 | },
463 | "nbformat": 4,
464 | "nbformat_minor": 5
465 | }
466 |
--------------------------------------------------------------------------------
/variable_N_exp.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "id": "3fcfaf4d",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import torch\n",
11 | "from matplotlib import pyplot as plt\n",
12 | "import sys\n",
13 | "import time\n",
14 | "import os\n",
15 | "import numpy as np\n",
16 | "import math\n",
17 | "\n",
18 |     "#####################################################\n",
19 |     "# Trains a 3-layer linear Transformer for a range of\n",
20 |     "# context lengths N (see `Ns` below)\n",
21 |     "#\n",
22 |     "# Setup:\n",
23 |     "# - covariate dimension 5\n",
24 |     "# - covariates are transformed by a random rotation U and a diagonal scaling D\n",
25 |     "# We compute\n",
26 |     "# - test loss of the trained Transformer for each N\n",
27 |     "# - test loss of 3-step gradient descent for each N\n",
28 |     "# - test loss of 3-step preconditioned gradient descent for each N\n",
29 |     "#####################################################\n",
30 | "\n",
31 | "#use cuda if available, else use cpu\n",
32 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
33 |     "if torch.cuda.is_available(): torch.cuda.set_device(0)\n",
34 | "# import the model and some useful functions\n",
35 | "from linear_transformer import Transformer_F, attention, generate_data, in_context_loss, generate_data_inplace\n",
36 | "\n",
37 | "# set up some print options\n",
38 | "np.set_printoptions(precision = 2, suppress = True)\n",
39 | "torch.set_printoptions(precision=2)\n",
40 | "\n",
41 | "#begin logging\n",
42 | "cur_dir = 'log' \n",
43 | "os.makedirs(cur_dir, exist_ok=True)\n",
44 | "#f = open(cur_dir + '/rotation.log', \"a\", 1)\n",
45 | "#sys.stdout = f"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 3,
51 | "id": "9700bf1b",
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "# Set up problem parameters\n",
56 | "\n",
57 |     "lr = 0  # placeholder; the learning rate is set per context length N in the training loop\n",
58 | "clip_r = 0.001\n",
59 | "alg = 'adam'\n",
60 | "mode = 'normal'\n",
61 | "\n",
62 | "n_layer = 3 # number of layers of transformer\n",
63 | "d = 5 # dimension\n",
64 | "\n",
65 | "\n",
66 | "n_head = 1 # 1-headed attention\n",
67 |     "B = 40000  # minibatch size\n",
68 | "var = 0.0001 # initializations scale of transformer parameter\n",
69 | "shape_k = 0.1 # shape_k: parameter for Gamma distributed covariates\n",
70 | "max_iters = 8000 # Number of Iterations to run\n",
71 |     "hist_stride = 1  # stride for saved model parameters in 'train.ipynb'\n",
72 |     "stride = 100  # save model parameters every `stride` iterations\n",
73 | "\n",
74 | "# a convenience function for taking a step and clipping\n",
75 | "def clip_and_step(allparam, optimizer, clip_r = None):\n",
76 | " norm_p=None\n",
77 | " grad_all = allparam.grad\n",
78 | " if clip_r is not None:\n",
79 | " for l in range(grad_all.shape[0]):\n",
80 | " for h in range(grad_all.shape[1]):\n",
81 | " for t in range(grad_all.shape[2]):\n",
82 | " norm_p = grad_all[l,h,t,:,:].norm().item()\n",
83 | " if norm_p > clip_r:\n",
84 | " grad_all[l,h,t,:,:].mul_(clip_r/norm_p)\n",
85 | " optimizer.step()\n",
86 | " return norm_p\n",
87 | "\n",
88 | "filename_format = '/variable_N_hist_{}_{}_{}.pth'\n",
89 |     "Ns = range(20,21,2) # context lengths to train here; the evaluation cells below use range(2,21,2)\n",
90 | "seeds=[0,1,2,3,4]\n",
91 | "keys = []\n",
92 | "for s in seeds:\n",
93 | " for N in Ns:\n",
94 | " keys.append((s,N,))"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "id": "e69d0ea4",
101 | "metadata": {
102 | "scrolled": false
103 | },
104 | "outputs": [],
105 | "source": [
106 | "for key in keys:\n",
107 | " sd = key[0]\n",
108 | " N = key[1]\n",
109 | " filename = cur_dir + filename_format.format(n_layer, N, sd)\n",
110 | " print(key)\n",
111 | " \n",
112 | " prob_seed = sd\n",
113 | " opt_seed = sd\n",
114 | " \n",
115 | " hist = []\n",
116 | " \n",
117 | " #set seed and initialize model\n",
118 | " torch.manual_seed(opt_seed)\n",
119 | " \n",
120 | " model = Transformer_F(n_layer, 1, d, var)\n",
121 | " model.to(device)\n",
122 |     "    #initialize optimizer. Important: use a beta2 below Adam's default; 0.999 is very slow\n",
123 | " \n",
124 | " \n",
125 | " if N < 5:\n",
126 | " lr = 0.001\n",
127 | " elif N < 15:\n",
128 | " lr = 0.01\n",
129 | " else:\n",
130 | " lr = 0.01\n",
131 | " \n",
132 | " optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.99, 0.99), weight_decay=0)\n",
133 | " \n",
134 | " # set seed\n",
135 | " # sample random rotation matrix\n",
136 | " # initialize initial training batch\n",
137 | " np.random.seed(prob_seed)\n",
138 | " torch.manual_seed(prob_seed)\n",
139 |     "    gaus = torch.FloatTensor(5,5).uniform_(-1,1).to(device)\n",
140 |     "    U = torch.linalg.svd(gaus)[0].to(device)\n",
141 |     "    D = torch.diag(torch.FloatTensor([1,1,1/2,1/4,1])).to(device)\n",
142 | " \n",
143 |     "    # generate the initial training batch; generate_data_inplace is called on it at every iteration below\n",
144 | " Z, y = generate_data(mode,N,d,B,shape_k, U, D)\n",
145 | " Z = Z.to(device)\n",
146 | " y = y.to(device)\n",
147 | " for t in range(max_iters):\n",
148 | " start = time.time()\n",
149 | " if t%2000==0 and t>1:\n",
150 | " optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] *0.5\n",
151 | " Z,y = generate_data_inplace(Z, U=U, D=D)\n",
152 | " start = time.time()\n",
153 | " # save model parameters\n",
154 | " if t%stride ==0:\n",
155 | " hist.append(model.allparam.clone().detach())\n",
156 | " loss = in_context_loss(model, Z, y)\n",
157 | " # compute gradient, take step\n",
158 | " loss.backward()\n",
159 | " norms = clip_and_step(model.allparam, optimizer, clip_r=clip_r)\n",
160 | " optimizer.zero_grad()\n",
161 | " end=time.time()\n",
162 | " if t%500 ==0 or t<5:\n",
163 | " print('iter {} | Loss: {} time: {} gradnorm: {}'.format(t,loss.item(), end-start, norms))\n",
164 | " torch.save({'hist':hist, 'U':U, 'D':D}, filename)"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": 8,
170 | "id": "83154b0e",
171 | "metadata": {},
172 | "outputs": [],
173 | "source": [
174 | "####################################\n",
175 | "# compute test loss\n",
176 | "####################################\n",
177 | "#hist_dict = torch.load(filename)['hist_dict']\n",
178 | "keys = []\n",
179 | "seeds = [0,1,2,3,4]\n",
180 | "Ns = range(2,21,2)\n",
181 | "for s in seeds:\n",
182 | " for N in Ns:\n",
183 | " keys.append((s,N,))\n",
184 | "loss_dict = {}\n",
185 | "for sd in seeds:\n",
186 | " key = (sd,)\n",
187 | " loss_dict[key] = torch.zeros(len(Ns))\n",
188 | " for N in Ns:\n",
189 | " filename = cur_dir + filename_format.format(n_layer, N, sd)\n",
190 | " hist = torch.load(filename)['hist']\n",
191 | " U = torch.load(filename)['U']\n",
192 | " D = torch.load(filename)['D']\n",
193 | " np.random.seed(999)\n",
194 | " torch.manual_seed(999)\n",
195 | " Z, y = generate_data(mode,N,d,B,shape_k,U,D)\n",
196 | " Z = Z.to(device)\n",
197 | " y = y.to(device)\n",
198 | " model = Transformer_F(n_layer, n_head, d, var).to(device)\n",
199 | " loss = 100\n",
200 | " bestmodel = None\n",
201 | " for t in range(len(hist)-10, len(hist)):\n",
202 | " with torch.no_grad():\n",
203 | " model.allparam.copy_(hist[t])\n",
204 | " newloss = in_context_loss(model, Z, y).item()\n",
205 | " if (newloss < loss):\n",
206 | " loss=newloss\n",
207 | " bestmodel = hist[t]\n",
208 | " with torch.no_grad():\n",
209 | " model.allparam.copy_(bestmodel)\n",
210 | " np.random.seed(99)\n",
211 | " torch.manual_seed(99)\n",
212 | " Z, y = generate_data(mode,N,d,B,shape_k,U,D)\n",
213 | " Z = Z.to(device)\n",
214 | " y = y.to(device) \n",
215 | " loss_dict[key][N//2-1] = in_context_loss(model, Z, y).item()"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": null,
221 | "id": "ff4227a3",
222 | "metadata": {},
223 | "outputs": [],
224 | "source": [
225 | "########################################################\n",
226 | "# plot log final test loss against N, for sanity check\n",
227 | "########################################################\n",
228 | "\n",
229 | "fig_dir = 'figures' \n",
230 | "os.makedirs(fig_dir, exist_ok=True)\n",
231 | "\n",
232 | "fig, ax = plt.subplots(1, 1,figsize = (9, 9))\n",
233 | "\n",
234 | "losses = torch.zeros(len(seeds), len(Ns))\n",
235 | "keys = loss_dict.keys()\n",
236 | "for idx, key in enumerate(keys):\n",
237 |     "    losses[idx,:] = torch.log(loss_dict[key])\n",
238 | "losses_mean = torch.mean(losses, axis=0)\n",
239 | "losses_std = torch.std(losses, axis=0)\n",
240 | "ax.plot(Ns, losses_mean, color = 'red', lw = 3, label='3-Layer Linear Transformer')\n",
241 | "ax.fill_between(Ns, losses_mean-losses_std, losses_mean+losses_std, color = 'red', alpha = 0.2)"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": null,
247 | "id": "e9ca8fff",
248 | "metadata": {},
249 | "outputs": [],
250 | "source": [
251 | "def do_gd(Z,eta,numstep):\n",
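252 |     "    N = Z.shape[0]-1  # row N of Z is the test point (see eval_w_instance)\n",
253 |     "    X = Z[0:N-1,0:5]  # covariates of context rows 0..N-2\n",
254 |     "    Y = Z[0:N-1,5]  # labels of context rows 0..N-2\n",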
255 | " w = torch.zeros(X.shape[1]).to(device)\n",
256 | " for k in range(numstep):\n",
257 | " XTXw = torch.einsum('ik,ij,j->k',X,X,w)\n",
258 | " XTY = torch.einsum('ik,i->k',X,Y)\n",
259 | " grad = XTXw - XTY\n",
260 | " w = w - eta * grad\n",
261 | " return w\n",
262 | "\n",
263 | "def eval_w_instance(Z, Ytest, w):\n",
264 | " N = Z.shape[0]-1\n",
265 | " Xtest = Z[N,0:5]\n",
266 | " prediction = torch.einsum('i,i->',w,Xtest)\n",
267 | " return (Ytest - prediction)**2, prediction\n",
268 | "\n",
269 | "\n",
270 | "## code for running 3-step GD loss\n",
271 | "gd_loss_matrix = torch.zeros(len(seeds),10)\n",
272 | "#for seed in seeds:\n",
273 | "# gd_loss_matrix.append([None]*10)\n",
274 | " \n",
275 | "for N in Ns:\n",
276 | " #find best eta\n",
277 | " #load seed 1 just to find eta\n",
278 | " sd = 1\n",
279 | " best_loss = 10000\n",
280 | " best_eta = 0\n",
281 | " numstep = 3\n",
282 | " \n",
283 | " filename = cur_dir + filename_format.format(n_layer, N, sd)\n",
284 | " U = torch.load(filename)['U']\n",
285 | " D = torch.load(filename)['D']\n",
286 | " #generate test data\n",
287 | " np.random.seed(999)\n",
288 | " torch.manual_seed(999)\n",
289 | " Z, y = generate_data(mode,N,d,B,shape_k,U,D)\n",
290 | " Z = Z.to(device)\n",
291 | " y = y.to(device)\n",
292 | " #done generating data \n",
293 | " \n",
294 | " for eta in [0.001, 0.002, 0.004, 0.008, 0.01, 0.02, 0.04, 0.08, 0.16]:\n",
295 | " ### start of evaluate mean loss ###\n",
296 | " total_loss = 0\n",
297 | " for i in range(5000):\n",
298 | " Zi = Z[i,:,:]\n",
299 | " Ytesti = y[i]\n",
300 | " w = do_gd(Zi,eta,numstep)\n",
301 | " gd_loss, gd_pred = eval_w_instance(Zi, Ytesti, w)\n",
302 | " total_loss = total_loss + gd_loss\n",
303 | " mean_loss = total_loss / 5000\n",
304 | " ### end of evaluate mean loss ###\n",
305 | " print('eta: {}, loss: {}'.format(eta, mean_loss))\n",
306 | " if (mean_loss < best_loss):\n",
307 | " best_eta = eta\n",
308 | " best_loss = mean_loss\n",
309 | " print('best eta: {} for N={}'.format(best_eta, N))\n",
310 | " \n",
311 | " #now do actual evaluation\n",
312 | " for sd in seeds:\n",
313 | " opt_seed = sd\n",
314 | " \n",
315 | " filename = cur_dir + filename_format.format(n_layer, N, sd)\n",
316 | " U = torch.load(filename)['U']\n",
317 | " D = torch.load(filename)['D']\n",
318 | " #generate test data\n",
319 | " torch.manual_seed(sd)\n",
320 | " Z, y = generate_data(mode,N,d,B,shape_k,U,D)\n",
321 | " Z = Z.to(device)\n",
322 | " y = y.to(device)\n",
323 | " #done generating data \n",
324 | " eta = best_eta\n",
325 | " ### start of evaluate mean loss ###\n",
326 | " total_loss = 0\n",
327 | " for i in range(5000):\n",
328 | " Zi = Z[i,:,:]\n",
329 | " Ytesti = y[i]\n",
330 | " w = do_gd(Zi,eta,numstep)\n",
331 | " gd_loss, gd_pred = eval_w_instance(Zi, Ytesti, w)\n",
332 | " total_loss = total_loss + gd_loss\n",
333 | " mean_loss = total_loss / 5000\n",
334 | " gd_loss_matrix[sd,int(N/2-1)] = mean_loss\n",
335 | " \n",
336 | "gd_loss_mean = gd_loss_matrix.mean(dim=0)\n",
337 | "gd_loss_std = gd_loss_matrix.var(dim=0)**0.5 "
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": null,
343 | "id": "4a2b85e8",
344 | "metadata": {},
345 | "outputs": [],
346 | "source": [
347 | "def do_preconditioned_gd(Z,eta,numstep,U,D):\n",
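348 |     "    N = Z.shape[0]-1  # row N of Z is the test point (see eval_w_instance_precon)\n",
349 |     "    X = Z[0:N-1,0:5]  # covariates of context rows 0..N-2\n",
350 |     "    Y = Z[0:N-1,5]  # labels of context rows 0..N-2\n",
351 |     "    w = torch.zeros(X.shape[1]).to(device)\n",
352 |     "    X = torch.einsum('ij, jk, Nk -> Ni', (torch.inverse(D),U.t(),X))  # precondition: x -> D^{-1} U^T x\n",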
353 | " for k in range(numstep):\n",
354 | " XTXw = torch.einsum('ik,ij,j->k',X,X,w)\n",
355 | " XTY = torch.einsum('ik,i->k',X,Y)\n",
356 | " grad = XTXw - XTY\n",
357 | " w = w - eta * grad\n",
358 | " return w\n",
359 | "\n",
360 | "def eval_w_instance_precon(Z, Ytest, w, U, D):\n",
361 | " N = Z.shape[0]-1\n",
362 | " Xtest = Z[N,0:5]\n",
363 | " Xtest = torch.einsum('ij, jk, k -> i', (torch.inverse(D),U.t(),Xtest))\n",
364 | " prediction = torch.einsum('i,i->',w,Xtest)\n",
365 | " return (Ytest - prediction)**2, prediction\n",
366 | "\n",
367 | "\n",
368 |     "## code for running 3-step preconditioned GD loss\n",
369 | "pgd_loss_matrix = torch.zeros(len(seeds),10)\n",
370 | "#for seed in seeds:\n",
371 | "# gd_loss_matrix.append([None]*10)\n",
372 | " \n",
373 | "for N in Ns:\n",
374 | " #find best eta\n",
375 | " #load seed 1 just to find eta\n",
376 | " best_loss = 10000\n",
377 | " best_eta = 0\n",
378 | " numstep = 3\n",
379 | " \n",
380 | " filename = cur_dir + filename_format.format(n_layer, N, sd)\n",
381 | " U = torch.load(filename)['U'].to(device)\n",
382 | " D = torch.load(filename)['D'].to(device)\n",
383 | " #generate test data\n",
384 | " np.random.seed(999)\n",
385 | " torch.manual_seed(999)\n",
386 | " Z, y = generate_data(mode,N,d,B,shape_k,U,D)\n",
387 | " Z = Z.to(device)\n",
388 | " y = y.to(device)\n",
389 | " #done generating data \n",
390 | " \n",
391 | " for eta in [0.008, 0.01, 0.02, 0.04, 0.08, 0.16]:\n",
392 | " ### start of evaluate mean loss ###\n",
393 | " total_loss = 0\n",
394 | " for i in range(5000):\n",
395 | " Zi = Z[i,:,:]\n",
396 | " Ytesti = y[i]\n",
397 | " w = do_preconditioned_gd(Zi,eta,numstep,U,D)\n",
398 | " pgd_loss, pgd_pred = eval_w_instance_precon(Zi, Ytesti, w, U, D)\n",
399 | " total_loss = total_loss + pgd_loss\n",
400 | " mean_loss = total_loss / 5000\n",
401 | " ### end of evaluate mean loss ###\n",
402 | " print('eta: {}, loss: {}'.format(eta, mean_loss))\n",
403 | " if (mean_loss < best_loss):\n",
404 | " best_eta = eta\n",
405 | " best_loss = mean_loss\n",
406 | " print('best eta: {} for N={}'.format(best_eta, N))\n",
407 | " \n",
408 | " #now do actual evaluation\n",
409 | " for sd in seeds:\n",
410 | " opt_seed = sd\n",
411 | " \n",
412 | " filename = cur_dir + filename_format.format(n_layer, N, sd)\n",
413 | " U = torch.load(filename)['U'].to(device)\n",
414 | " D = torch.load(filename)['D'].to(device)\n",
415 | " #generate test data\n",
416 | " torch.manual_seed(sd)\n",
417 | " Z, y = generate_data(mode,N,d,B,shape_k,U,D)\n",
418 | " Z = Z.to(device)\n",
419 | " y = y.to(device)\n",
420 | " #done generating data \n",
421 | " eta = best_eta\n",
422 | " ### start of evaluate mean loss ###\n",
423 | " total_loss = 0\n",
424 | " for i in range(5000):\n",
425 | " Zi = Z[i,:,:]\n",
426 | " Ytesti = y[i]\n",
427 | " w = do_preconditioned_gd(Zi,eta,numstep,U,D)\n",
428 | " pgd_loss, pgd_pred = eval_w_instance_precon(Zi, Ytesti, w, U, D)\n",
429 | " total_loss = total_loss + pgd_loss\n",
430 | " mean_loss = total_loss / 5000\n",
431 | " pgd_loss_matrix[sd,int(N/2-1)] = mean_loss\n",
432 | " \n",
433 | "pgd_loss_mean = pgd_loss_matrix.mean(dim=0)\n",
434 | "pgd_loss_std = pgd_loss_matrix.var(dim=0)**0.5"
435 | ]
436 | },
437 | {
438 | "cell_type": "code",
439 | "execution_count": null,
440 | "id": "e2644ed4",
441 | "metadata": {},
442 | "outputs": [],
443 | "source": [
444 | "def do_OLS(Z,eta,numstep,U,D):\n",
445 | " N = Z.shape[0]-1\n",
446 | " X = Z[0:N-1,0:5]\n",
447 | " Y = Z[0:N-1,5]\n",
448 | " w = torch.zeros(X.shape[1])\n",
449 | " X = torch.einsum('ij, jk, Nk -> Ni', (torch.inverse(D),U.t(),X))\n",
450 | " for k in range(numstep):\n",
451 | " XTX = torch.einsum('ik,ij->kj',X,X)\n",
452 | " XTY = torch.einsum('ik,i->k',X,Y)\n",
453 | " w = torch.einsum('ik,k->i', torch.linalg.pinv(XTX) , XTY)\n",
454 | " return w\n",
455 | "\n",
456 | "def eval_w_instance_OLS(Z, Ytest, w,U,D):\n",
457 | " N = Z.shape[0]-1\n",
458 | " Xtest = Z[N,0:5]\n",
459 | " Xtest = torch.einsum('ij, jk, k -> i', (torch.inverse(D),U.t(),Xtest))\n",
460 | " prediction = torch.einsum('i,i->',w,Xtest)\n",
461 | " return (Ytest - prediction)**2, prediction\n",
462 | "\n",
463 | "\n",
464 | "## code for running OLS loss\n",
465 | "OLS_loss_matrix = torch.zeros(len(seeds),10)\n",
466 | "#for seed in seeds:\n",
467 | "# gd_loss_matrix.append([None]*10)\n",
468 | " \n",
469 | "for N in Ns:\n",
470 | " #now do actual evaluation\n",
471 | " for sd in seeds:\n",
472 | " opt_seed = sd\n",
473 | " filename = cur_dir + filename_format.format(n_layer, N, sd)\n",
474 | " U = torch.load(filename)['U'].to(device)\n",
475 | " D = torch.load(filename)['D'].to(device)\n",
476 | " #generate test data\n",
477 | " torch.manual_seed(sd)\n",
478 | " Z, y = generate_data(mode,N,d,B,shape_k,U,D)\n",
479 | " Z = Z.to(device)\n",
480 | " y = y.to(device)\n",
481 | " #done generating data \n",
482 | " eta = best_eta\n",
483 | " ### start of evaluate mean loss ###\n",
484 | " total_loss = 0\n",
485 | " for i in range(5000):\n",
486 | " Zi = Z[i,:,:]\n",
487 | " Ytesti = y[i]\n",
488 | " w = do_OLS(Zi,eta,numstep,U,D)\n",
489 | " OLS_loss, OLS_pred = eval_w_instance_OLS(Zi, Ytesti, w, U, D)\n",
490 | " total_loss = total_loss + OLS_loss\n",
491 | " mean_loss = total_loss / 5000\n",
492 | " OLS_loss_matrix[sd,int(N/2-1)] = mean_loss\n",
493 | " print('N={}, loss={}'.format(N,mean_loss))\n",
494 | " \n",
495 | "ols_loss_mean = OLS_loss_matrix.mean(dim=0)\n",
496 | "ols_loss_std = OLS_loss_matrix.var(dim=0)**0.5"
497 | ]
498 | },
499 | {
500 | "cell_type": "code",
501 | "execution_count": null,
502 | "id": "5af8555c",
503 | "metadata": {},
504 | "outputs": [],
505 | "source": [
506 | "####################################\n",
507 | "# plot final test loss against N\n",
508 | "####################################\n",
509 | "\n",
510 | "fig_dir = 'figures' \n",
511 | "os.makedirs(fig_dir, exist_ok=True)\n",
512 | "\n",
513 | "fig, ax = plt.subplots(1, 1,figsize = (9, 9))\n",
514 | "\n",
515 | "losses = torch.zeros(len(seeds), len(Ns))\n",
516 | "keys = loss_dict.keys()\n",
517 | "for idx, key in enumerate(keys):\n",
518 | " losses[idx,:] = loss_dict[key]\n",
519 | "losses_mean = torch.mean(losses, axis=0)\n",
520 | "losses_std = torch.std(losses, axis=0)\n",
521 | "\n",
522 | "plt.plot(Ns, gd_loss_mean, color='blue', label='3-Step GD')\n",
523 | "plt.fill_between(Ns, gd_loss_mean - gd_loss_std, gd_loss_mean + gd_loss_std, color='blue', alpha=0.2)\n",
524 | "plt.plot(Ns, pgd_loss_mean, color='green', label='3-Step Preconditioned GD')\n",
525 | "plt.fill_between(Ns, pgd_loss_mean - pgd_loss_std, pgd_loss_mean + pgd_loss_std, color='green', alpha=0.2)\n",
526 | "ax.plot(Ns, losses_mean, color = 'red', lw = 3, label='3-Layer Linear Transformer')\n",
527 | "ax.fill_between(Ns, losses_mean-losses_std, losses_mean+losses_std, color = 'red', alpha = 0.2)\n",
528 | "plt.plot(Ns, ols_loss_mean, color='purple', label='OLS')\n",
529 | "plt.fill_between(Ns, ols_loss_mean - ols_loss_std, ols_loss_mean + ols_loss_std, color='purple', alpha=0.2)\n",
530 | "\n",
531 | "plt.ylabel('Loss',fontsize=30)\n",
532 | "plt.xlabel('Number of ICL Examples',fontsize=30)\n",
533 | "ax.tick_params(axis='both', which='major', labelsize=30, width = 3, length = 10)\n",
534 | "ax.tick_params(axis='both', which='minor', labelsize=20, width = 3, length = 5)\n",
535 | "plt.xticks(np.arange(2,24, 4))\n",
536 | "ax.legend(fontsize=24)\n",
537 | "#ax.set_yscale('log')\n",
538 | "\n",
539 | "\n",
540 | "plt.tight_layout()\n",
541 | "plt.savefig(fig_dir + '/3-step-variable-N-plot.pdf', dpi=600)"
542 | ]
543 | }
544 | ],
545 | "metadata": {
546 | "kernelspec": {
547 | "display_name": "Python 3 (ipykernel)",
548 | "language": "python",
549 | "name": "python3"
550 | },
551 | "language_info": {
552 | "codemirror_mode": {
553 | "name": "ipython",
554 | "version": 3
555 | },
556 | "file_extension": ".py",
557 | "mimetype": "text/x-python",
558 | "name": "python",
559 | "nbconvert_exporter": "python",
560 | "pygments_lexer": "ipython3",
561 | "version": "3.9.12"
562 | }
563 | },
564 | "nbformat": 4,
565 | "nbformat_minor": 5
566 | }
567 |
--------------------------------------------------------------------------------
/plot_stochastic_noise.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "%matplotlib inline\n",
11 | "from matplotlib import pyplot as plt\n",
12 | "import math\n",
13 | "import torch.nn.functional as F\n",
14 | "from torch.nn.functional import relu\n",
15 | "from torch import nn\n",
16 | "import torch.optim as optim\n",
17 | "import torch.optim.lr_scheduler as lr_scheduler\n",
18 | "import random\n",
19 | "import numpy as np\n",
20 | "import gc\n",
21 | "from pylab import *\n",
22 | "import os\n",
23 | "import random\n",
24 | "import json\n",
25 | "import pandas as pd\n",
26 | "from scipy.stats import norm\n",
27 | "pd.set_option('display.float_format', lambda x: '%.5f' % x)\n",
28 | "import sys\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "import time\n",
31 | "\n",
32 | "from linear_transformer import Transformer_F, attention, generate_data, in_context_loss, generate_data_inplace\n",
33 | "\n",
34 | "np.set_printoptions(precision = 4, suppress = True)\n",
35 | "torch.set_printoptions(precision=2)\n",
36 | "device = torch.device(\"cuda\")\n",
37 | "torch.cuda.set_device(0)"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 2,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "# Set Hyperparameters\n",
47 | "\n",
48 | "# Fixed\n",
49 | "n_head = 1\n",
50 | "d = 5\n",
51 | "B = 1000\n",
52 | "ma = 1\n",
53 | "var = 0.05\n",
54 | "shape_k = 0.1\n",
55 | "\n",
56 | "# We vary the following parameters\n",
57 | "n_layer = 3\n",
58 | "mode = 'normal'\n",
59 | "N = 20\n",
60 | "n_sample = 10000 # number of stochastic gradients to sample\n",
61 | "seed = 1\n",
62 | "\n",
63 | "log_dir = 'log' \n",
64 | "os.makedirs(log_dir, exist_ok=True)"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "np.random.seed(seed)\n",
74 | "torch.manual_seed(seed)\n",
75 | "\n",
76 | "model = Transformer_F(n_layer, n_head, d, var)\n",
77 | "model.to(device)\n",
78 | "\n",
79 | "# compute estimated true gradient using large batch\n",
80 | "B_large = B*100\n",
81 | "Z, y = generate_data(mode,N,d,B_large,shape_k)\n",
82 | "Z = Z.cuda()\n",
83 | "y = y.cuda()\n",
84 | "\n",
85 | "# redefine loss using newly sampled Z and y\n",
86 | "def eval_loss():\n",
87 | " output = model(Z)\n",
88 | " N= Z.shape[1]-1\n",
89 | " diff = output[:,N,d]+y\n",
90 | " loss = ((diff)**2).mean() \n",
91 | " loss = loss \n",
92 | " return loss\n",
93 | "\n",
94 | "loss = eval_loss()\n",
95 | "loss.backward()\n",
96 | "gradient = model.allparam.grad.data.clone().detach()\n",
97 | "model.allparam.grad.zero_()\n",
98 | "\n",
99 | "noiseList = []\n",
100 | "for _ in range(n_sample):\n",
101 | " # compute stochastic gradient\n",
102 | " Z, y = generate_data(mode,N,d,B,shape_k)\n",
103 | " Z = Z.cuda()\n",
104 | " y = y.cuda()\n",
105 | " \n",
106 | " loss = in_context_loss(model, Z, y)\n",
107 | " loss.backward()\n",
108 | " stochastic_gradient = model.allparam.grad.data.clone().detach()\n",
109 | " model.allparam.grad.zero_()\n",
110 | "\n",
111 | " noise = torch.norm(stochastic_gradient - gradient)\n",
112 | " noiseList.append(noise.item())\n",
113 | "\n",
114 | "filename = log_dir + '/stochastic_gradient_noise_layer{}_N{}_{}_sd{}.pth'.format(n_layer,N,mode,seed)\n",
115 | "torch.save({'noiseList':noiseList}, filename)"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 6,
121 | "metadata": {},
122 | "outputs": [
123 | {
124 | "data": {
125 | "image/png": "",
126 | "text/plain": [
127 | ""
128 | ]
129 | },
130 | "metadata": {},
131 | "output_type": "display_data"
132 | }
133 | ],
134 | "source": [
135 | "# Plot stochastic gradient noise\n",
136 | "\n",
137 | "import torch\n",
138 | "%matplotlib inline\n",
139 | "from matplotlib import pyplot as plt\n",
140 | "import numpy as np\n",
141 | "\n",
142 | "fig_dir = 'figures' \n",
143 | "os.makedirs(fig_dir, exist_ok=True)\n",
144 | "\n",
145 | "# hyperparameters\n",
146 | "mode = 'normal'\n",
147 | "N = 20\n",
148 | "seed = 1\n",
149 | "n_layer = 3\n",
150 | "\n",
151 | "filename = log_dir + '/stochastic_gradient_noise_layer{}_N{}_{}_sd{}.pth'.format(n_layer,N,mode,seed)\n",
152 | "loaded_dict = torch.load(filename)\n",
153 | "noiseList = loaded_dict['noiseList']\n",
154 | "noiseArray = np.array(noiseList)\n",
155 | "\n",
156 | "fig, ax = plt.subplots(1, 1,figsize = (6, 6))\n",
157 | "ax.hist(noiseList, bins=100, density=True, alpha=1.0, edgecolor = 'black', linewidth = 0.001)\n",
158 | "ax.tick_params(axis='both', which='major', labelsize=30, width = 3, length = 10)\n",
159 | "ax.tick_params(axis='both', which='minor', labelsize=30, width = 3, length = 5)\n",
160 | "\n",
161 | "ax.spines[['right', 'top']].set_visible(False)\n",
162 | "ax.spines['left'].set_linewidth(3)\n",
163 | "ax.spines['bottom'].set_linewidth(3)\n",
164 | "ax.set_ylabel('Density',fontsize=40)\n",
165 | "ax.set_xlabel('Gradient error',fontsize=40)\n",
166 | "\n",
167 | "plt.tight_layout()\n",
168 | "plt.savefig(fig_dir + '/heavy_tail_noise_layer{}_N{}_{}.pdf'.format(n_layer, N, mode), dpi=600)"
169 | ]
170 | }
171 | ],
172 | "metadata": {
173 | "kernelspec": {
174 | "display_name": "pytorch",
175 | "language": "python",
176 | "name": "python3"
177 | },
178 | "language_info": {
179 | "codemirror_mode": {
180 | "name": "ipython",
181 | "version": 3
182 | },
183 | "file_extension": ".py",
184 | "mimetype": "text/x-python",
185 | "name": "python",
186 | "nbconvert_exporter": "python",
187 | "pygments_lexer": "ipython3",
188 | "version": "3.11.3"
189 | },
190 | "orig_nbformat": 4
191 | },
192 | "nbformat": 4,
193 | "nbformat_minor": 2
194 | }
195 |
--------------------------------------------------------------------------------
/plot_loss.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "%matplotlib inline\n",
11 | "from matplotlib import pyplot as plt\n",
12 | "import math\n",
13 | "import torch.nn.functional as F\n",
14 | "from torch.nn.functional import relu\n",
15 | "from torch import nn\n",
16 | "import torch.optim as optim\n",
17 | "import torch.optim.lr_scheduler as lr_scheduler\n",
18 | "import random\n",
19 | "import numpy as np\n",
20 | "import gc\n",
21 | "from pylab import *\n",
22 | "import os\n",
23 | "import random\n",
24 | "import json\n",
25 | "import pandas as pd\n",
26 | "from scipy.stats import norm\n",
27 | "pd.set_option('display.float_format', lambda x: '%.5f' % x)\n",
28 | "import sys\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "import time\n",
31 | "\n",
32 | "from linear_transformer import Transformer_F, attention, generate_data, in_context_loss, generate_data_inplace\n",
33 | "\n",
34 | "np.set_printoptions(precision = 4, suppress = True)\n",
35 | "torch.set_printoptions(precision=2)\n",
36 | "device = torch.device(\"cuda\")\n",
37 | "torch.cuda.set_device(0)"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 2,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "# Set Hyperparameters\n",
47 | "\n",
48 | "# Fixed\n",
49 | "n_head = 1\n",
50 | "d = 5\n",
51 | "B = 1000\n",
52 | "ma = 1\n",
53 | "var = 0.05\n",
54 | "shape_k = 0.1\n",
55 | "\n",
56 | "# Number of Iterations to run\n",
57 | "max_iters = 10000\n",
58 | "hist_stride = 1 # stride for saved model parameters in 'train.ipynb'\n",
59 | "stride = 50 # stride for computing loss\n",
60 | "\n",
61 | "# We vary the following parameters\n",
62 | "n_layer = 3\n",
63 | "mode = 'normal'\n",
64 | "N = 20\n",
65 | "seeds = [0,1,2,3,4,5]\n"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 3,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "log_dir = 'log'\n",
75 | "loss_plots = {}\n",
76 | "\n",
77 | "for sd in seeds:\n",
78 | " for (alg, toclip, lr) in [('sgd', True, 0.02),('adam', True, 0.02)]: \n",
79 | " filename = log_dir + '/train_layer{}_N{}_{}_{}_{}_lr{}_sd{}.pth'.format(n_layer, N, mode, alg, toclip, lr, sd)\n",
80 | " loaded_dict = torch.load(filename)\n",
81 | " hist_list = loaded_dict['hist_list']\n",
82 | "\n",
83 | " np.random.seed(99)\n",
84 | " torch.manual_seed(99)\n",
85 | " Z, y = generate_data(mode,N,d,B,shape_k)\n",
86 | " Z = Z.to(device)\n",
87 | " y = y.to(device)\n",
88 | "\n",
89 | " model = Transformer_F(n_layer, n_head, d, var)\n",
90 | " model = model.to(device)\n",
91 | " \n",
92 | " test_losses = torch.zeros(max_iters//stride)\n",
93 | "\n",
94 | " for t in range(0,max_iters,stride):\n",
95 | " allparam_loaded = hist_list[t]\n",
96 | " with torch.no_grad():\n",
97 | " model.allparam.copy_(allparam_loaded)\n",
98 | " test_loss = in_context_loss(model, Z, y)\n",
99 | " test_losses[t//stride] = test_loss.item()\n",
100 | "\n",
101 | " loss_plots[(alg, toclip, lr, sd)] = test_losses"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": 4,
107 | "metadata": {},
108 | "outputs": [
109 | {
110 | "data": {
111 | "image/png": "",
112 | "text/plain": [
113 | ""
114 | ]
115 | },
116 | "metadata": {},
117 | "output_type": "display_data"
118 | }
119 | ],
120 | "source": [
121 | "fig_dir = 'figures' \n",
122 | "os.makedirs(fig_dir, exist_ok=True)\n",
123 | "\n",
124 | "fig, ax = plt.subplots(1, 1,figsize = (6, 6))\n",
125 | "\n",
126 | "for (alg, toclip, lr) in [('sgd', True, 0.02),('adam', True, 0.02)]:\n",
127 | " losses = torch.zeros(len(seeds), int(max_iters/stride))\n",
128 | " for idx, sd in enumerate(seeds):\n",
129 | " losses[idx] = loss_plots[(alg, toclip, lr, sd)]\n",
130 | " losses_mean = torch.mean(losses, axis=0)\n",
131 | " losses_std = torch.std(losses, axis=0)\n",
132 | " if alg == 'sgd':\n",
133 | " ax.plot(range(0,max_iters,stride), losses_mean, color = 'black', lw = 3,label='SGDM')\n",
134 | " ax.fill_between(range(0,max_iters,stride), losses_mean-losses_std/4, losses_mean+losses_std/4, color = 'black', alpha = 0.1)\n",
135 | " elif alg == 'adam':\n",
136 | " ax.plot(range(0,max_iters,stride), losses_mean, color = 'red', lw = 3, label='Adam')\n",
137 | " ax.fill_between(range(0,max_iters,stride), losses_mean-losses_std/4, losses_mean+losses_std/4, color = 'red', alpha = 0.1)\n",
138 | "\n",
139 | " ax.set_xlabel('Iteration',fontsize=40)\n",
140 | " ax.tick_params(axis='both', which='major', labelsize=30, width = 3, length = 10)\n",
141 | " ax.tick_params(axis='both', which='minor', labelsize=20, width = 3, length = 5)\n",
142 | " ax.legend(fontsize=30)\n",
143 | " ax.spines[['right', 'top']].set_visible(False)\n",
144 | " ax.spines['left'].set_linewidth(3)\n",
145 | " ax.spines['bottom'].set_linewidth(3)\n",
146 | " ax.set_yscale('log')\n",
147 | " \n",
148 | " plt.tight_layout()\n",
149 | " plt.savefig(fig_dir + '/loss_layer{}_N{}_{}.pdf'.format(n_layer, N, mode), dpi=600)"
150 | ]
151 | }
152 | ],
153 | "metadata": {
154 | "kernelspec": {
155 | "display_name": "pytorch",
156 | "language": "python",
157 | "name": "python3"
158 | },
159 | "language_info": {
160 | "codemirror_mode": {
161 | "name": "ipython",
162 | "version": 3
163 | },
164 | "file_extension": ".py",
165 | "mimetype": "text/x-python",
166 | "name": "python",
167 | "nbconvert_exporter": "python",
168 | "pygments_lexer": "ipython3",
169 | "version": "3.11.3"
170 | },
171 | "orig_nbformat": 4
172 | },
173 | "nbformat": 4,
174 | "nbformat_minor": 2
175 | }
176 |
--------------------------------------------------------------------------------