├── cnn
│   ├── script.sh
│   ├── lazy_utils.py
│   ├── models.py
│   ├── extract_kernel.py
│   └── train.py
├── README.md
└── shallow-nn
    ├── populationSGD.jl
    ├── test_vs_scale.jl
    ├── illustration.jl
    ├── test_vs_m.jl
    └── lazy.ipynb

/cnn/script.sh:
--------------------------------------------------------------------------------
# Reproduce experiments to demonstrate an effective linearization as alpha grows

for LR in 1.0 0.1 0.01 0.001
do
    for ALPHA in 10000000.0 1000000.0 100000.0 10000.0 1000.0 100.0 10.0 5.0 1.0 0.5 0.1 0.01
    do
        python train.py --scaling_factor $ALPHA --lr $LR --gain 1.0 --schedule 'b' --loss 'mse' --length 100 --precision 'double'
    done
done

# Obtain the SVD of the tangent kernel for cifar and random features

python extract_kernel.py --bs 9 --data 'random' --subset 495
python extract_kernel.py --bs 9 --subset 495

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# lazy-training-code

This code is based on https://github.com/kuangliu/pytorch-cifar .

## Reproducing CNN experiments

To reproduce the CNN accuracy and loss results from the paper, simply run:

```
cd cnn
sh script.sh
```

The __double__ precision experiments require a Tesla or Volta GPU to handle this numerical precision at a reasonable speed.

## Reproducing shallow experiments

All the code necessary to reproduce the shallow-network results from the paper is located in `shallow-nn`.

## Contributions

All contributions are welcome.

--------------------------------------------------------------------------------
/shallow-nn/populationSGD.jl:
--------------------------------------------------------------------------------
using PyPlot, ProgressMeter, Random   # packages required by this script

d = 100       # dimension of the supervised learning problem (our d-1)
m0 = 3        # number of neurons of generating data
niter = 20000 # put 20000
@assert niter>1999
m = 50

#scales = 10 .^ (-2.5:0.1:1)
scales = cat([0.01,0.02,0.04],10 .^ (-1.0:0.1:0),[2,4,8],dims=1)
nscales = length(scales)
ntrials = 10

Fs = zeros(ntrials,nscales)

batchsize = 200
stepsize = 10

p = Progress(ntrials*nscales)
Random.seed!(1)
for i = 1:ntrials
    θ0 = randn(m0,d) # random ground truth
    θ0 = θ0 ./ sqrt.(sum(θ0.^2,dims=2))
    w0 = sign.(randn(m0))

    for j=1:nscales
        scale = scales[j]
        stepsize = min(0.25/scale^2,25)
        ws,θs,val = populationSGDfor2NN(m,w0,θ0,stepsize,batchsize,scale,niter)
        Fs[i,j] = sum(val[end-1999:end])
        ProgressMeter.next!(p)
    end
end

figure(figsize=[4,4])
mea = sum(Fs,dims=1)'/ntrials
stdr = sqrt.(sum((Fs' .- mea).^2, dims=2)/(ntrials-1))
ss = 1/maximum(mea)
semilogx(scales,ss*mea,"k",linewidth=2)
fill_between(scales,ss*(mea+stdr)[:],ss*(mea-stdr)[:],color=[0.85,0.85,0.85])
ylabel("Population loss at convergence")
xlabel(L"\tau")

#vlines([0.15; 0.5],[0 ;0],[4; 4],linestyle=":")
#fill_betweenx([0; 4],[0.15 ;0.15],[0.5; 0.5],hatch="//",facecolor="None",edgecolor="k",linestyle=":",label="not yet converged")
#legend(loc="upper left")
#savefig("lazySGD_tau_sans.pdf",bbox_inches="tight")
--------------------------------------------------------------------------------
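Note that `populationSGDfor2NN`, called by the script above, is not part of this listing (it is presumably defined in `lazy.ipynb` along with the other shallow-network helpers). The block below is only a minimal Julia sketch of the assumed interface: a student f(x) = scale · Σ_j w_j relu(θ_j · x) trained by SGD on the population squared loss against the teacher (w0, θ0), estimated on a fresh minibatch at every step. The exact parameterization, data distribution, and loss normalization of the paper's implementation may differ.

```julia
# HYPOTHETICAL sketch of the missing helper -- not the paper's code.
function populationSGDfor2NN(m, w0, θ0, stepsize, batchsize, scale, niter)
    d = size(θ0, 2)
    θ = randn(m, d); θ = θ ./ sqrt.(sum(θ.^2, dims=2))   # student hidden units
    w = randn(m)                                          # student output weights
    val = zeros(niter)                                    # minibatch estimate of the population loss
    for t = 1:niter
        X = randn(batchsize, d)
        X = X ./ sqrt.(sum(X.^2, dims=2))                 # inputs on the sphere, as in the other scripts
        a = max.(θ * X', 0.0)                             # m × batchsize activations
        y = vec(sum(w0 .* max.(θ0 * X', 0.0), dims=1))    # teacher outputs
        r = scale .* vec(w' * a) .- y                     # residuals
        val[t] = sum(r.^2) / (2 * batchsize)
        # gradients of the minibatch squared loss
        gw = scale .* (a * r) ./ batchsize
        gθ = scale .* ((w .* float.(a .> 0)) * (X .* r)) ./ batchsize
        w .-= stepsize .* gw
        θ .-= stepsize .* gθ
    end
    return w, θ, val
end
```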
/cnn/lazy_utils.py:
--------------------------------------------------------------------------------
import os
import sys
import time

_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time

def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f
--------------------------------------------------------------------------------
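`progress_bar` is meant to be called once per minibatch inside a loop, as `extract_kernel.py` does further below. A tiny standalone demo (hypothetical; it assumes a real terminal, since the bar width is read via `stty size`):

```python
# Hypothetical demo of lazy_utils.progress_bar -- not part of the repository.
import time
from lazy_utils import progress_bar

n_batches = 50
for batch_idx in range(n_batches):
    time.sleep(0.02)  # stand-in for one training/extraction step
    progress_bar(batch_idx, n_batches, 'Loss: %.3f' % (1.0 / (batch_idx + 1)))
```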
/shallow-nn/test_vs_scale.jl:
--------------------------------------------------------------------------------
using PyPlot, ProgressMeter   # packages required by this script

d = 100        # dimension of the supervised learning problem
n_train = 1000 # size of train set
n_test = 1000  # size of test set
m0 = 3         # nb of neurons teacher
m = 50         # nb of neurons student


scaling = 1 # we change the initialization instead of the scaling (it is equivalent, up to a square)
niter = 10000
scales = 10 .^ (-2.2:0.1:1) # scales of init
ntrials = 10 # repetition with different random data/teacher/init
ltrains = zeros(niter,length(scales),ntrials)
ltests = zeros(niter,length(scales),ntrials)
test_err_tangent = zeros(niter,ntrials)

p = Progress(length(scales)*ntrials) # progress bar
for k = 1:ntrials
    # random teacher
    w1 = randn(m0,d)
    w1 = w1 ./ sqrt.(sum(w1.^2, dims=2))
    w2 = sign.(randn(m0))
    f(X) = sum( w2 .* max.( w1 * X', 0.0), dims=1)

    # data sets
    X_train = randn(n_train, d)
    X_train = X_train ./ sqrt.(sum(X_train.^2, dims=2))
    Y_train = f(X_train) #randn(1,n_train)
    X_test = randn(n_test, d)
    X_test = X_test ./ sqrt.(sum(X_test.^2, dims=2))
    Y_test = f(X_test);

    # initialization
    W_init = randn(m, d+1)
    # symmetrization
    W_init[1:div(m,2),end] = abs.(W_init[1:div(m,2),end])
    W_init[(div(m,2)+1):end,end] = - W_init[1:div(m,2),end]
    W_init[(div(m,2)+1):end,1:end-1] = W_init[1:div(m,2),1:end-1]
    W_init0 = W_init;

    for i=1:length(scales)
        W_init = scales[i]*W_init0 # both layers are multiplied so scale ~ alpha^2
        # the linear scaling of the step-size works for large scales only
        stepsize = min(10,0.1/scales[i].^2)
        Ws, loss_train, loss_test = GDfor2NN(X_train, X_test, Y_train, Y_test, W_init, scaling, stepsize, niter)
        ltrains[:,i,k] = loss_train
        ltests[:,i,k] = loss_test
        ProgressMeter.next!(p)
    end
end

# Compute mean and std
meana = sum(ltests[end,:,:],dims=2)/ntrials
meanb = sum(minimum(ltests,dims=1),dims=3)[:]/ntrials
stda = sqrt.(sum((ltests[end,:,:] .- meana).^2,dims=2)/(ntrials-1))
stdb = sqrt.(sum((minimum(ltests,dims=1) .- meanb').^2,dims=3)[:]/(ntrials-1))

# Plot
figure(figsize=[4,4])
ss = 1000 # for nicer yticks
fill_between(scales, ss*(meana+stda)[:],ss*(meana-stda)'[:],color=[0.85,0.85,0.85])
fill_between(scales, ss*(meanb+stdb)[:],ss*(meanb-stdb)'[:],color=[0.85,0.85,0.85])
semilogx(scales, ss*meana,"k",alpha=1,linewidth=3,label="end of training")
semilogx(scales, ss*sum(minimum(ltests,dims=1),dims=3)[:]/ntrials,":k",alpha=1,linewidth=3,label="best throughout training")
ylabel("Test loss")
xlabel(L"\tau")
legend()
#savefig("test_loss_tau.pdf",bbox_inches="tight")
--------------------------------------------------------------------------------
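`GDfor2NN`, used here and in the two scripts that follow, is also not included in this listing (presumably defined in `lazy.ipynb`). Below is a hedged Julia sketch of the assumed interface: full-batch gradient descent on the squared loss of a two-layer ReLU network whose j-th hidden unit has input weights `W[j,1:d]` and output weight `W[j,end]`, with the whole output multiplied by `scaling`. It stores the full parameter trajectory, which is what `illustration.jl` expects (the original code presumably stores less for the larger sweeps); the parameterization and loss normalization are guesses from how the function is called, not the paper's actual code.

```julia
# HYPOTHETICAL sketch of the missing helper -- not the paper's code.
# Model: f_W(x) = scaling * sum_j W[j,end] * relu(W[j,1:d] ⋅ x)
function GDfor2NN(X_train, X_test, Y_train, Y_test, W_init, scaling, stepsize, niter)
    n, d = size(X_train)
    m = size(W_init, 1)
    W = copy(W_init)
    Ws = zeros(m, d + 1, niter)   # full trajectory (large for big m/niter)
    loss_train = zeros(niter)
    loss_test = zeros(niter)

    predict(W, X) = scaling .* vec((W[:, end])' * max.(W[:, 1:d] * X', 0.0))
    mse(W, X, Y) = sum((predict(W, X) .- vec(Y)).^2) / (2 * size(X, 1))

    for t = 1:niter
        Ws[:, :, t] = W
        loss_train[t] = mse(W, X_train, Y_train)
        loss_test[t] = mse(W, X_test, Y_test)

        # full-batch gradient of the training loss
        a = max.(W[:, 1:d] * X_train', 0.0)                    # m × n activations
        r = scaling .* vec((W[:, end])' * a) .- vec(Y_train)   # residuals
        gout = scaling .* (a * r) ./ n                         # grad w.r.t. output weights
        gin  = scaling .* ((W[:, end] .* float.(a .> 0)) * (X_train .* r)) ./ n
        W[:, 1:d] .-= stepsize .* gin
        W[:, end] .-= stepsize .* gout
    end
    return Ws, loss_train, loss_test
end
```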
/shallow-nn/illustration.jl:
--------------------------------------------------------------------------------
using PyPlot   # package required by this script

# generate the data
d = 2 # dimension of input

# random teacher 2-NN
m0 = 3 # nb of neurons teacher
w1 = randn(m0,d)
w1 = w1 ./ sqrt.(sum(w1.^2, dims=2))
w2 = sign.(randn(m0))
f(X) = sum( w2 .* max.( w1 * X', 0.0), dims=1)

# data sets
n_train = 15 # size train set (15)
n_test = 20 # size test set
X_train = randn(n_train, d)
X_train = X_train ./ sqrt.(sum(X_train.^2, dims=2))
Y_train = f(X_train)
X_test = randn(n_test, d)
X_test = X_test ./ sqrt.(sum(X_test.^2, dims=2))
Y_test = f(X_test);

# initialize and train
m = 16 # nb of neurons student
scaling = 1
niter = 10^5
stepsize = 0.005

# initialization
W_init = randn(m, 2)
W_init = W_init ./ sqrt.(sum(W_init.^2, dims=2))
W_init = cat(W_init, rand(m),dims=2)

# symmetrization to set initial output to zero (optional)
W_init[(div(m,2)+1):end,end] = - W_init[1:div(m,2),end]
W_init[(div(m,2)+1):end,1:end-1] = W_init[1:div(m,2),1:end-1]

# choose scale of init (0.1 not lazy / 2 lazy)
W_init = 0.2*W_init

@time Ws, loss_train, loss_test = GDfor2NN(X_train, X_test, Y_train, Y_test, W_init, scaling, stepsize, niter);


figure(figsize=[8,4])

subplot(121)
semilogy(loss_train,label="train loss")
semilogy(loss_test,label="test loss")
legend();title("Convergence");


# things to plot
iters = Int.(floor.(exp.(range(0, stop = log(niter), length = 100)))) #cat(1:20,21:4:100,110:15:500,500:100:10000,20000:1000:niter,dims=1)
mid = div(m,2)
finalsign = sign.(Ws[:,3,end])
pxs = Ws[finalsign.>0,1,iters] .* Ws[finalsign.>0,3,iters]
pys = Ws[finalsign.>0,2,iters] .* Ws[finalsign.>0,3,iters]
pxsm = Ws[finalsign.<0,1,iters] .* abs.(Ws[finalsign.<0,3,iters])
pysm = Ws[finalsign.<0,2,iters] .* abs.(Ws[finalsign.<0,3,iters])
px0 = w1[:,1] #.* w2
py0 = w1[:,2] #.* w2

subplot(122)
r = 1
plot(r*cos.(0.0:0.01:2π),r*sin.(0.0:0.01:2π),":",color="k",label="circle of radius $(r)")

arrow(0,0,px0[1],py0[1],head_width=0.06,length_includes_head=true,facecolor="C3")
arrow(0,0,px0[2],py0[2],head_width=0.06,length_includes_head=true,facecolor="C0")
arrow(0,0,px0[3],py0[3],head_width=0.06,length_includes_head=true,facecolor="C3",label="teacher")

plot(pxs',pys',linewidth=1.0,"C3");
plot(pxs[1,:],pys[1,:],linewidth=0.5,"C3",label="gradient flow (+)")
scatter(pxs[:,end],pys[:,end],30,color="C3")
plot(pxsm',pysm',linewidth=1.0,"C0");
plot(pxsm[1,:],pysm[1,:],linewidth=0.5,"C0",label="gradient flow (-)")
scatter(pxsm[:,end],pysm[:,end],30,color="C0")

bx = max(max(maximum(abs.(pxs)), maximum(abs.(pys)))*1.1,1.1)
axis([-bx,bx,-bx,bx]);
axis("off")

#legend(loc=3)
#savefig("cover_lazy_leg.pdf",bbox_inches="tight")
#savefig("gf_doubling_1.png")
--------------------------------------------------------------------------------
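The symmetrization above pairs each hidden unit with a copy whose output weight has the opposite sign, so the student's output at initialization is exactly zero whatever the scale of `W_init`; only the initialization scale, not a spurious non-zero initial output, then distinguishes the lazy and non-lazy runs. A quick numerical check, assuming the parameterization used in the `GDfor2NN` sketch above:

```julia
# Quick check (hypothetical, uses the parameterization assumed in the GDfor2NN
# sketch): with the symmetrized initialization the two halves of the network
# cancel exactly, so the student output is identically zero at initialization.
x = randn(d); x = x ./ sqrt(sum(x.^2))
out0 = scaling * sum(W_init[:, end] .* max.(W_init[:, 1:end-1] * x, 0.0))
@assert abs(out0) < 1e-12
```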
/shallow-nn/test_vs_m.jl:
--------------------------------------------------------------------------------
using PyPlot, ProgressMeter   # packages required by this script

d = 100        # dimension of the supervised learning problem
n_train = 1000 # nb of data points
n_test = 1000
m0 = 3         # nb of neurons of ground truth

scaling = 1
niter = 25000
ms = [2,3,4,6,8,12,16,24,32,64,128,256,512]
ntrials = 10

# compute with alpha = 1/sqrt(m)
m_ltrains = zeros(niter,length(ms),ntrials)
m_ltests = zeros(niter,length(ms),ntrials)

# compute with alpha = 1/m
m_ltrains2 = zeros(niter,length(ms),ntrials)
m_ltests2 = zeros(niter,length(ms),ntrials)

p = Progress(length(ms)*ntrials*2)
for k = 1:ntrials
    # ground truth
    w1 = randn(m0,d)
    w1 = w1 ./ sqrt.(sum(w1.^2, dims=2))
    w2 = sign.(randn(m0))
    f(X) = sum( w2 .* max.( w1 * X', 0.0), dims=1)*100 # neurons

    # data sets
    X_train = randn(n_train, d)
    X_train = X_train ./ sqrt.(sum(X_train.^2, dims=2))
    Y_train = f(X_train) #randn(1,n_train)
    X_test = randn(n_test, d)
    X_test = X_test ./ sqrt.(sum(X_test.^2, dims=2))
    Y_test = f(X_test)

    # compute with alpha = 1/sqrt(m)
    for i=1:length(ms)
        m = ms[i]
        W_init = randn(m, d+1)
        scaling = 1/sqrt(m)
        stepsize = 1/m
        Ws, loss_train, loss_test = GDfor2NN(X_train, X_test, Y_train, Y_test, W_init, scaling, stepsize, niter);
        m_ltrains[:,i,k] = loss_train
        m_ltests[:,i,k] = loss_test
        ProgressMeter.next!(p)
    end

    # compute with alpha = 1/m
    for i=1:length(ms)
        m = ms[i]
        W_init = randn(m, d+1)
        scaling = 1/m
        stepsize = 0.05/m
        Ws, loss_train, loss_test = GDfor2NN(X_train, X_test, Y_train, Y_test, W_init, scaling, stepsize, niter)
        m_ltrains2[:,i,k] = loss_train
        m_ltests2[:,i,k] = loss_test
        ProgressMeter.next!(p)
    end
end

# Prepare the plots
sa = 1
sb = length(ms)
ss = .1
endtraining = permutedims(minimum(m_ltests[:,sa:sb,:],dims=1),[2 3 1])[:,:,1]   # used by confint_* below
meana = sum(m_ltests[end,sa:sb,:],dims=2)/ntrials
meanb = sum(minimum(m_ltests[:,sa:sb,:],dims=1),dims=3)[:]/ntrials              # used by stdb below
meana2 = sum(m_ltests2[end,sa:sb,:],dims=2)/ntrials
#meanb2 = sum(minimum(m_ltests[:,sa:sb,:],dims=1),dims=3)[:]/ntrials
stda = sqrt.(sum((m_ltests[end,sa:sb,:] .- meana).^2,dims=2)/(ntrials-1))
stdb = sqrt.(sum((minimum(m_ltests[:,sa:sb,:],dims=1) .- meanb').^2,dims=3)[:]/(ntrials-1))
confint_low = sort(endtraining,dims=2)[:,1]
confint_up = sort(endtraining,dims=2)[:,end]


figure(figsize=[4,4])

#fill_between(ms[sa:sb],ss*(meana+stda)[:],ss*(meana-stda)'[:],color=[0.85,0.85,0.85])
#fill_between(ms[sa:sb],ss*(meana2+stdb)[:],ss*(meanb-stdb)'[:],color=[0.85,0.85,0.85])
#fill_between(ms[sa:sb],ss*confint_low[:],ss*confint_up[:],color=[0.85,0.85,0.85])
semilogx(ms[sa:sb],ss*m_ltests[end,:,:],"ok",markersize=1);
semilogx(ms[sa:sb],ss*m_ltests2[end,:,:],"o",color=[0.5,0.5,0.5],markersize=1);

semilogx(ms[sa:sb],ss*meana,"k",alpha=1,linewidth=3,label=L"scaling $1/\sqrt{m}$")
semilogx(ms[sa:sb],ss*meana2,color=[0.5,0.5,0.5],linewidth=3,label=L"scaling $1/m$")

ylabel("Test loss")
xlabel(L"m")
xticks([1, 10, 100, 1000])
yticks([0,1])
legend()
#savefig("test_mcomp_dots.pdf",bbox_inches="tight")
--------------------------------------------------------------------------------
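For reference, the two sweeps above compare the test loss as a function of the width m under the two output normalizations used in the labels of the plot; up to the parameterization details of `GDfor2NN`, the trained model is

```latex
f_{\alpha}(x) \;=\; \alpha(m) \sum_{j=1}^{m} w_j \,\max\big(\langle \theta_j, x\rangle, 0\big),
\qquad \alpha(m) = \frac{1}{\sqrt{m}} \quad\text{or}\quad \alpha(m) = \frac{1}{m}.
```

Roughly, the 1/√m normalization keeps the output at initialization of order one and pushes training towards the lazy (kernel-like) regime as m grows, while 1/m is the mean-field normalization under which the hidden units keep moving; the plot contrasts the resulting test losses.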
/cnn/models.py:
--------------------------------------------------------------------------------
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F

cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M', 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride))  # , bias=False),
            # nn.BatchNorm2d(self.expansion*planes)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.conv2(out)
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, k, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64*k
        layers = []
        layers += [nn.Conv2d(3, 64*k, kernel_size=3, stride=1, padding=1), nn.ReLU()]
        #self.bn1 = nn.BatchNorm2d(64)
        a = self._make_layer(block, 64 * k, num_blocks[0], stride=1)
        layers += [*a]
        a = self._make_layer(block, 128 * k, num_blocks[1], stride=2)
        layers += [*a]
        a = self._make_layer(block, 256 * k, num_blocks[2], stride=2)
        layers += [*a]
        a = self._make_layer(block, 512 * k, num_blocks[3], stride=2)
        layers += [*a]
        layers += [nn.AvgPool2d(kernel_size=4)]
        self.features = nn.Sequential(*layers)
        #self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers += [block(self.in_planes, planes, stride)]
            self.in_planes = planes * block.expansion
        return layers

    def forward(self, x):
        out = self.features(x)  #F.relu(self.bn1(self.conv1(x)))
        #out = self.layer1(out)
        #out = self.layer2(out)
        #out = self.layer3(out)
        #out = self.layer4(out)
        #out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        #out = self.linear(out)
        return out


def ResNet18(k):
    return ResNet(BasicBlock, [2,2,2,2], k)


class VGG(nn.Module):
    def __init__(self, vgg_name, k):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name], k)
        #self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        #out = self.classifier(out)
        return out

    def _make_layers(self, cfg, k):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x*k, kernel_size=3, padding=1),
                           # nn.BatchNorm2d(x),
                           nn.ReLU(inplace=False)]
                in_channels = x*k
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)
--------------------------------------------------------------------------------
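Both models return flattened convolutional features rather than class scores (the final linear layers are commented out); the 10-way linear head is created separately in `extract_kernel.py` and `train.py`, which makes it easy to form the centred model f(x) = FC1(h_W(x)) − FC2(h_{W0}(x)) used below. A small shape check (hypothetical, CPU-only):

```python
# Hypothetical shape check -- not part of the repository.
import torch
from models import ResNet18

k = 1                            # widening factor
net = ResNet18(k)
x = torch.randn(4, 3, 32, 32)    # a CIFAR-10 sized batch
feats = net(x)
print(feats.shape)               # torch.Size([4, 512]) for k = 1; these features feed the external linear head
```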
/cnn/extract_kernel.py:
--------------------------------------------------------------------------------
'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function

import torch
torch.manual_seed(58)
import numpy as np
np.random.seed(58)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import copy
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import argparse
from models import *
from lazy_utils import progress_bar

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--model', default='vgg', type=str, help='model type')
parser.add_argument('--widening_factor', default=1, type=int, help='widening factor')
parser.add_argument('--bs', default=10, type=int, help='batch size')
parser.add_argument('--gain', default=2.0, type=float, help='gain at init')
parser.add_argument('--subset', default=500, type=int, help='subset of data')
parser.add_argument('--precision', default='float', type=str, help='precision...')
parser.add_argument('--data', default='cifar10', type=str, help='which dataset?...')


args = parser.parse_args()

if args.precision=='float':
    torch.set_default_dtype(torch.float32)
elif args.precision=='double':
    torch.set_default_dtype(torch.float64)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
if args.data == 'random':
    trainset.train_data = np.random.randint(256, size=(500000, 32, 32, 3), dtype=np.uint8)  # be careful, there was a recent
    # modification of torch, you might have to switch 'train_data' to 'data'
    print('randomized')
trainset = torch.utils.data.Subset(trainset, range(args.subset))

trainloader = torch.utils.data.DataLoader(trainset, shuffle=False, batch_size=args.bs, num_workers=2)
trainloader2 = torch.utils.data.DataLoader(trainset, shuffle=False, batch_size=args.bs, num_workers=2)



k = args.widening_factor
# Model
print('==> Building model..')
net = None
if args.model=='vgg':
    net = VGG('VGG11', k)
elif args.model=='resnet':
    net = ResNet18(k)
net = net.to(device)
net = nn.DataParallel(net.to(device))


from torch.nn.init import xavier_normal_ as xavier
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        xavier(m.weight.data, gain=args.gain)
        m.bias.data.zero_()

net.apply(weights_init)

net2 = copy.deepcopy(net)

FC1 = nn.Linear(512*k, 10).cuda()
FC2 = nn.Linear(512*k, 10).cuda()

xavier(FC1.weight.data, gain=args.gain)


FC1.bias.data.zero_()
FC2.weight.data.copy_(FC1.weight.data)
FC2.bias.data.copy_(FC1.bias.data)



def linearized_outputs(inputs):
    net_parameters = list(net.parameters()) + list(FC1.parameters())
    params = sum([torch.numel(p) for p in net_parameters])

    output_linearized = torch.zeros(inputs.size(0), 10, params).cuda()
    output1 = net(inputs)
    output2 = net2(inputs)
    output = FC1(output1) - FC2(output2)
    for n in range(inputs.size(0)):
        for i in range(10):
            output[n, i].backward(retain_graph=True)
            p_idx = 0
            for p in range(len(net_parameters)):
                output_linearized[n, i, p_idx:p_idx+net_parameters[p].numel()] = net_parameters[p].grad.data.view(-1)
                p_idx = p_idx + net_parameters[p].numel()
            for p in range(len(net_parameters)):
                net_parameters[p].grad.data.zero_()

    output_linearized = output_linearized.view(inputs.size(0)*10, params)
    return output_linearized



def extract_features(epoch):
    print('\nEpoch: %d' % epoch)
    K = torch.zeros([args.subset*10, args.subset*10], dtype=torch.float64)
    idx = 0
    idx2 = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        if args.precision == 'double':
            inputs = inputs.double()
        out = linearized_outputs(inputs)

        progress_bar(batch_idx, len(trainloader), 'bar 1')
        for batch_idx2, (inputs2, targets2) in enumerate(trainloader2):
            if(batch_idx2