├── .gitignore
├── README.md
├── diagnose.py
├── neurips_poster.pdf
├── src
│   ├── data.py
│   ├── linalg.py
│   ├── models
│   │   ├── mnist.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   ├── trainer.py
│   └── utils.py
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
checkpoints/*
*.pkl

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

## Core latex/pdflatex auxiliary files:
*.aux
*.lof
*.log
*.lot
*.fls
*.out
*.toc
*.fmt
*.fot
*.cb
*.cb2

## Intermediate documents:
*.dvi
*.xdv
*-converted-to.*
# these rules might exclude image files for figures etc.
# *.ps
# *.eps
# *.pdf

## Generated if empty string is given at "Please type another file name for output:"
.pdf

## Bibliography auxiliary files (bibtex/biblatex/biber):
*.bbl
*.bcf
*.blg
*-blx.aux
*-blx.bib
*.run.xml

## Build tool auxiliary files:
*.fdb_latexmk
*.synctex
*.synctex(busy)
*.synctex.gz
*.synctex.gz(busy)
*.pdfsync

## Auxiliary and intermediate files from other packages:
# algorithms
*.alg
*.loa

# achemso
acs-*.bib

# amsthm
*.thm

# beamer
*.nav
*.pre
*.snm
*.vrb

# changes
*.soc

# cprotect
*.cpt

# elsarticle (documentclass of Elsevier journals)
*.spl

# endnotes
*.ent

# fixme
*.lox

# feynmf/feynmp
*.mf
*.mp
*.t[1-9]
*.t[1-9][0-9]
*.tfm

#(r)(e)ledmac/(r)(e)ledpar
*.end
*.?end
*.[1-9]
*.[1-9][0-9]
*.[1-9][0-9][0-9]
*.[1-9]R
*.[1-9][0-9]R
*.[1-9][0-9][0-9]R
*.eledsec[1-9]
*.eledsec[1-9]R
*.eledsec[1-9][0-9]
*.eledsec[1-9][0-9]R
*.eledsec[1-9][0-9][0-9]
*.eledsec[1-9][0-9][0-9]R

# glossaries
*.acn
*.acr
*.glg
*.glo
*.gls
*.glsdefs

# gnuplottex
*-gnuplottex-*

# gregoriotex
*.gaux
*.gtex

# hyperref
*.brf

# knitr
*-concordance.tex
# TODO Comment the next line if you want to keep your tikz graphics files
*.tikz
*-tikzDictionary

# listings
*.lol

# makeidx
*.idx
*.ilg
*.ind
*.ist

# minitoc
*.maf
*.mlf
*.mlt
*.mtc[0-9]*
*.slf[0-9]*
*.slt[0-9]*
*.stc[0-9]*

# minted
_minted*
*.pyg

# morewrites
*.mw

# nomencl
*.nlo

# pax
*.pax

# pdfpcnotes
*.pdfpc

# sagetex
*.sagetex.sage
*.sagetex.py
*.sagetex.scmd

# scrwfile
*.wrt

# sympy
*.sout
*.sympy
sympy-plots-for-*.tex/

# pdfcomment
*.upa
*.upb

# pythontex
*.pytxcode
pythontex-files-*/

# thmtools
*.loe

# TikZ & PGF
*.dpth
*.md5
*.auxlock

# todonotes
*.tdo

# easy-todo
*.lod

# xindy
*.xdy

# xypic precompiled matrices
*.xyc

# endfloat
*.ttt
*.fff

# Latexian
TSWLatexianTemp*

## Editors:
# WinEdt
*.bak
*.sav

# Texpad
.texpadtmp

# Kile
*.backup

# KBibTeX
*~[0-9]*

# auto folder when using emacs and auctex
/auto/*

# expex forward references with \gathertags
*-tags.tex

experiments/
slurm/
data/*
nohup.out
*.PNG
*.DS_Store
*._*

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## How SGD Selects the Global Minima in Over-parameterized Learning: A Dynamical Stability Perspective
by Lei Wu, Chao Ma, Weinan E


### Training
```
python train.py --dataset fashionmnist --n_samples 1000 --model_file net.pkl
```

### Computing sharpness and non-uniformity
```
python diagnose.py --dataset fashionmnist --n_samples 1000 --model_file net.pkl
```

### Dependencies
- pytorch >= 0.4

### Citation

    @inproceedings{leiwu2018,
      title={How SGD Selects Global Minima in Over-parameterized Learning: A Dynamical Stability Perspective},
      author={Wu, Lei and Ma, Chao and E, Weinan},
      booktitle={Advances in Neural Information Processing Systems},
      year={2018}
    }
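As implemented in `src/linalg.py`, sharpness is the largest eigenvalue of the training-loss Hessian `H`, and non-uniformity is the square root of the largest eigenvalue of the variance operator `(1/n) * sum_i H_i^2 - H^2` over the per-example Hessians `H_i`; both are estimated by power iteration using Hessian-vector products. The same quantities can be computed programmatically. A minimal sketch using this repo's helpers (assumes a CUDA device and an existing checkpoint `net.pkl`):

```
import torch
from src.utils import load_net, load_data, get_sharpness, get_nonuniformity

net = load_net('fashionmnist')
net.load_state_dict(torch.load('net.pkl'))
train_loader, _ = load_data('fashionmnist', training_size=1000, batch_size=1000)
criterion = torch.nn.MSELoss().cuda()

print('sharpness:', get_sharpness(net, criterion, train_loader))           # largest Hessian eigenvalue
print('non-uniformity:', get_nonuniformity(net, criterion, train_loader))  # sqrt of top eigenvalue of the variance operator
```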
--------------------------------------------------------------------------------
/diagnose.py:
--------------------------------------------------------------------------------
import os
import argparse
import json
import torch

from src.utils import load_net, load_data, \
                      get_sharpness, get_nonuniformity, \
                      eval_accuracy


def get_args():
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument('--gpuid', default='0,',
                           help='gpu id, [0]')
    argparser.add_argument('--dataset', default='fashionmnist',
                           help='dataset, [fashionmnist] | cifar10')
    argparser.add_argument('--n_samples', type=int,
                           default=1000, help='training set size, [1000]')
    argparser.add_argument('--batch_size', type=int,
                           default=1000, help='batch size, [1000]')
    argparser.add_argument('--model_file', default='fnn.pkl',
                           help='file name of the pretrained model')
    args = argparser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpuid

    print('===> Config:')
    print(json.dumps(vars(args), indent=2))
    return args


def main():
    args = get_args()

    # load data and the pretrained model
    criterion = torch.nn.MSELoss().cuda()
    train_loader, test_loader = load_data(args.dataset,
                                          training_size=args.n_samples,
                                          batch_size=args.batch_size)
    net = load_net(args.dataset)
    net.load_state_dict(torch.load(args.model_file))

    # evaluate the model
    train_loss, train_accuracy = eval_accuracy(net, criterion, train_loader)
    test_loss, test_accuracy = eval_accuracy(net, criterion, test_loader)

    print('===> Basic information of the given model: ')
    print('\t train loss: %.2e, acc: %.2f' % (train_loss, train_accuracy))
    print('\t test loss: %.2e, acc: %.2f' % (test_loss, test_accuracy))

    print('===> Compute sharpness:')
    sharpness = get_sharpness(net, criterion, train_loader,
                              n_iters=10, verbose=True, tol=1e-4)
    print('Sharpness is %.2e\n' % sharpness)

    print('===> Compute non-uniformity:')
    non_uniformity = get_nonuniformity(net, criterion, train_loader,
                                       n_iters=10, verbose=True, tol=1e-4)
    print('Non-uniformity is %.2e\n' % non_uniformity)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/neurips_poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leiwu0/sgd.stability/f7d650b81c999462200357d62474697da03beace/neurips_poster.pdf

--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
import torch
import torchvision.datasets as dsets


class DataLoader:
    """A simple in-memory loader that reshuffles the data at each epoch."""

    def __init__(self, X, y, batch_size):
        self.X, self.y = X, y
        self.batch_size = batch_size
        self.n_samples = len(y)
        self.idx = 0

    def __len__(self):
        # number of batches per epoch (the last one may be smaller)
        length = self.n_samples // self.batch_size
        if self.n_samples > length * self.batch_size:
            length += 1
        return length

    def __iter__(self):
        return self

    def __next__(self):
        if self.idx >= self.n_samples:
            # epoch finished: reshuffle and start over
            self.idx = 0
            rnd_idx = torch.randperm(self.n_samples)
            self.X = self.X[rnd_idx]
            self.y = self.y[rnd_idx]

        idx_end = min(self.idx + self.batch_size, self.n_samples)
        batch_X = self.X[self.idx:idx_end]
        batch_y = self.y[self.idx:idx_end]
        self.idx = idx_end

        return batch_X, batch_y


def load_fmnist(training_size, batch_size=100):
    train_set = dsets.FashionMNIST('data/fashionmnist', train=True, download=True)
    train_X, train_y = train_set.data[0:training_size].float()/255, \
                       to_one_hot(train_set.targets[0:training_size])
    train_loader = DataLoader(train_X, train_y, batch_size)

    test_set = dsets.FashionMNIST('data/fashionmnist', train=False, download=True)
    test_X, test_y = test_set.data.float()/255, \
                     to_one_hot(test_set.targets)
    test_loader = DataLoader(test_X, test_y, batch_size)

    return train_loader, test_loader


def load_cifar10(training_size, batch_size=100):
    """
    Load the CIFAR-10 dataset. Note that only the examples with
    label 0 or 1 are used, so training_size is at most 10000.
    """
    train_set = dsets.CIFAR10('data/cifar10', train=True, download=True)
    train_X, train_y = modify_cifar_data(train_set.data, train_set.targets, training_size)
    train_loader = DataLoader(train_X, train_y, batch_size)

    test_set = dsets.CIFAR10('data/cifar10', train=False, download=True)
    test_X, test_y = modify_cifar_data(test_set.data, test_set.targets)
    test_loader = DataLoader(test_X, test_y, batch_size)

    return train_loader, test_loader


def modify_cifar_data(X, y, n_samples=-1):
    # keep only the two-class subset (labels 0 and 1)
    X = torch.from_numpy(X.transpose([0, 3, 1, 2]))
    y = torch.LongTensor(y)

    X_t = torch.Tensor(50000, 3, 32, 32)
    y_t = torch.LongTensor(50000)
    idx = 0
    for i in range(len(y)):
        if y[i] == 0 or y[i] == 1:
            y_t[idx] = y[i]
            X_t[idx, :, :, :] = X[i, :, :, :]
            idx += 1
    X = X_t[0:idx]
    y = y_t[0:idx]

    if n_samples > 1:
        X = X[0:n_samples]
        y = y[0:n_samples]

    # preprocess the data
    X = X.float()/255.0
    y = to_one_hot(y)

    return X, y


def to_one_hot(labels):
    if labels.ndimension() == 1:
        labels.unsqueeze_(1)
    n_samples = labels.shape[0]
    n_classes = labels.max() + 1

    one_hot_labels = torch.FloatTensor(n_samples, n_classes)
    one_hot_labels.zero_()
    one_hot_labels.scatter_(1, labels, 1)

    return one_hot_labels


if __name__ == '__main__':
    train_loader, test_loader = load_cifar10(training_size=10000, batch_size=500)
    for i in range(30):
        batch_x, batch_y = next(train_loader)
        print(i, batch_x.shape, batch_y.shape)

    for i in range(4):
        batch_x, batch_y = next(test_loader)
        print(i, batch_x.shape, batch_y.shape)
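# Illustration of to_one_hot on a hypothetical label vector (editorial note,
# not part of the original code): the labels are scattered into an
# (n_samples, n_classes) float matrix, which matches the MSELoss criterion
# used by train.py and diagnose.py.
#   >>> to_one_hot(torch.LongTensor([0, 2, 1]))
#   tensor([[1., 0., 0.],
#           [0., 0., 1.],
#           [0., 1., 0.]])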
--------------------------------------------------------------------------------
/src/linalg.py:
--------------------------------------------------------------------------------
import time
import torch
import torch.autograd as autograd


def eigen_variance(net, criterion, dataloader, n_iters=10, tol=1e-2, verbose=False):
    # largest eigenvalue of the variance operator (1/n) sum_i H_i^2 - H^2
    n_parameters = num_parameters(net)
    v0 = torch.randn(n_parameters)

    Av_func = lambda v: variance_vec_prod(net, criterion, dataloader, v)
    mu = power_method(v0, Av_func, n_iters, tol, verbose)
    return mu


def eigen_hessian(net, criterion, dataloader, n_iters=10, tol=1e-2, verbose=False):
    # largest eigenvalue of the full-batch Hessian H
    n_parameters = num_parameters(net)
    v0 = torch.randn(n_parameters)

    Av_func = lambda v: hessian_vec_prod(net, criterion, dataloader, v)
    mu = power_method(v0, Av_func, n_iters, tol, verbose)
    return mu


def variance_vec_prod(net, criterion, dataloader, v):
    # computes [(1/n) sum_i H_i^2 - H^2] v, where H_i is the Hessian of the
    # loss on the i-th example and H is the full-batch Hessian
    X, y = dataloader.X, dataloader.y
    Av, Hv, n_samples = 0, 0, len(y)

    for i in range(n_samples):
        bx, by = X[i:i+1].cuda(), y[i:i+1].cuda()
        Hv_i = Hv_batch(net, criterion, bx, by, v)
        Av_i = Hv_batch(net, criterion, bx, by, Hv_i)
        Av += Av_i
        Hv += Hv_i
    Av /= n_samples
    Hv /= n_samples
    H2v = hessian_vec_prod(net, criterion, dataloader, Hv)
    return Av - H2v


def hessian_vec_prod(net, criterion, dataloader, v):
    Hv_t = 0
    n_batchs = len(dataloader)
    dataloader.idx = 0
    for _ in range(n_batchs):
        bx, by = next(dataloader)
        Hv_t += Hv_batch(net, criterion, bx.cuda(), by.cuda(), v)

    return Hv_t/n_batchs


def Hv_batch(net, criterion, batch_x, batch_y, v):
    """
    Hessian-vector product on a single batch, computed by differentiating
    the inner product <grad(loss), v> a second time.
    """
    net.eval()
    logits = net(batch_x)
    loss = criterion(logits, batch_y)

    grads = autograd.grad(loss, net.parameters(), create_graph=True, retain_graph=True)
    idx, res = 0, 0
    for grad_i in grads:
        ng = torch.numel(grad_i)
        v_i = v[idx:idx+ng].cuda()
        res += torch.dot(v_i, grad_i.view(-1))
        idx += ng

    Hv = autograd.grad(res, net.parameters())
    Hv = [t.data.cpu().view(-1) for t in Hv]
    Hv = torch.cat(Hv)
    return Hv


def power_method(v0, Av_func, n_iters=10, tol=1e-3, verbose=False):
    mu = 0
    v = v0/v0.norm()
    for i in range(n_iters):
        time_start = time.time()

        Av = Av_func(v)
        mu_pre = mu
        mu = torch.dot(Av, v).item()
        v = Av/Av.norm()

        if verbose:
            print('%d-th step takes %.0f seconds, \t %.2e' % (i+1, time.time()-time_start, mu))
        # stop once the relative change of the eigenvalue estimate is below tol
        # (written multiplicatively to avoid dividing by a zero estimate)
        if abs(mu - mu_pre) < tol*abs(mu):
            break
    return mu


def num_parameters(net):
    """
    Return the number of parameters of the given model.
    """
    n_parameters = 0
    for para in net.parameters():
        n_parameters += para.data.numel()

    return n_parameters
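# A small self-check of the power iteration (an illustrative sketch added by
# the editor, not part of the original experiments; assumes a PyTorch version
# with torch.linalg available):
if __name__ == '__main__':
    A = torch.randn(50, 50)
    A = (A + A.t()) / 2        # symmetric, so power iteration targets the dominant eigenvalue
    mu = power_method(torch.randn(50), lambda v: A @ v, n_iters=200, tol=1e-6)
    print('power method estimate: %.4f' % abs(mu))
    print('exact |lambda|_max:    %.4f' % torch.linalg.eigvalsh(A).abs().max().item())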
--------------------------------------------------------------------------------
/src/models/mnist.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, stride=1)   # 28-5+1=24
        self.conv2 = nn.Conv2d(6, 16, 5, stride=1)  # 12-5+1=8
        self.fc1 = nn.Linear(4*4*16, 200)
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
        if x.ndimension() == 3:
            x = x.unsqueeze(0)
        o = F.relu(self.conv1(x))
        o = F.avg_pool2d(o, 2, 2)

        o = F.relu(self.conv2(o))
        o = F.avg_pool2d(o, 2, 2)

        o = o.view(o.shape[0], -1)
        o = self.fc1(o)
        o = F.relu(o)
        o = self.fc2(o)
        return o


class FNN(nn.Module):
    def __init__(self):
        super(FNN, self).__init__()
        self.net = nn.Sequential(nn.Linear(784, 500),
                                 nn.ReLU(),
                                 nn.Linear(500, 500),
                                 nn.ReLU(),
                                 nn.Linear(500, 500),
                                 nn.ReLU(),
                                 nn.Linear(500, 10))

    def forward(self, x):
        x = x.view(x.shape[0], -1)
        o = self.net(x)
        return o


def lenet():
    return LeNet()


def fnn():
    return FNN()
--------------------------------------------------------------------------------
/src/models/resnet.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.init as init


def get_cifar_mean_std():
    # channel-wise mean/std of CIFAR-10, rescaled to [0, 1]
    mean = torch.Tensor([x/255 for x in [125.3, 123.0, 113.9]]).view(1, 3, 1, 1)
    std = torch.Tensor([x/255 for x in [63.0, 62.1, 66.7]]).view(1, 3, 1, 1)
    mean_var = mean.cuda()
    std_var = std.cuda()
    return mean_var, std_var


#====================================
# ResNet
#====================================
def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=1, bias=True)


class short_cut(nn.Module):
    def __init__(self, in_channels, out_channels, type='A'):
        super(short_cut, self).__init__()
        self.type = 'D' if in_channels == out_channels else type
        if self.type == 'C':
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                  padding=0, stride=2, bias=False)
            self.bn = nn.BatchNorm2d(out_channels)
        elif self.type == 'A':
            self.avg = nn.AvgPool2d(kernel_size=1, stride=2)

    def forward(self, x):
        if self.type == 'A':
            # option A: subsample, then zero-pad the extra channels
            x = self.avg(x)
            return torch.cat((x, x.mul(0)), 1)
        elif self.type == 'C':
            # option C: 1x1 convolution projection
            x = self.conv(x)
            x = self.bn(x)
            return x
        elif self.type == 'D':
            # identity shortcut
            return x


class residual_block(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, shortcutType='D'):
        super(residual_block, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.shortcut = short_cut(in_channels, out_channels, type=shortcutType)

    def forward(self, x):
        o = self.conv1(x)
        o = self.bn1(o)
        o = self.relu(o)
        o = self.conv2(o)
        o = self.bn2(o)
        o += self.shortcut(x)

        o = self.relu(o)
        return o


class ResNet(nn.Module):
    def __init__(self, block, depth, multiple=2, num_classes=10, shortcutType='A'):
        super(ResNet, self).__init__()
        assert (depth-2) % 6 == 0, 'depth should be 6*m + 2, like 20 32 44 56 110'
        num_blocks = (depth-2)//6
        print('resnet: depth: %d, # of blocks at each stage: %d' % (depth, num_blocks))

        self.in_channels = 8*multiple
        self.conv = conv3x3(3, self.in_channels)
        self.bn = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)

        # thread shortcutType through (it was previously ignored in favor of
        # _make_layer's own default); spatial sizes are for 32x32 inputs
        self.stage1 = self._make_layer(block, 8*multiple, num_blocks, 1, shortcutType)   # 32x32x(8*multiple)
        self.stage2 = self._make_layer(block, 16*multiple, num_blocks, 2, shortcutType)  # 16x16x(16*multiple)
        self.stage3 = self._make_layer(block, 32*multiple, num_blocks, 2, shortcutType)  # 8x8x(32*multiple)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(32*multiple, num_classes)
        self.name = 'resnet'

        # initialization by Kaiming strategy
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fin = m.kernel_size[0]*m.kernel_size[1]*m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0/fin))
                #m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight)
                m.bias.data.zero_()

        self.mean_var, self.std_var = get_cifar_mean_std()

    def forward(self, x):
        x = (x - self.mean_var)/self.std_var
        o = self.conv(x)
        o = self.bn(o)
        o = self.relu(o)

        o = self.stage1(o)
        o = self.stage2(o)
        o = self.stage3(o)

        o = self.avg_pool(o)
        o = o.view(o.size(0), -1)
        o = self.fc(o)
        return o

    def _make_layer(self, block, out_channels, num_blocks, stride=1, shortcutType='A'):
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, shortcutType=shortcutType))
        self.in_channels = out_channels
        for i in range(1, num_blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)


#====================================
# API
#====================================
def resnet(width=2, depth=14, num_classes=10):
    if (depth-2) % 6 != 0:
        raise ValueError('depth: %d is not a legal depth for resnet' % depth)
    return ResNet(residual_block, depth, width, num_classes)
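# Worked example of the depth arithmetic above (an illustrative sketch added
# by the editor; assumes a CUDA device, since the model keeps its CIFAR
# normalization constants on the GPU): with the defaults width=2 and depth=14,
# each stage has (14-2)/6 = 2 residual blocks with 16/32/64 channels, and the
# classifier is Linear(64, num_classes).
if __name__ == '__main__':
    net = resnet(width=2, depth=14, num_classes=10)
    print(sum(p.numel() for p in net.parameters()), 'parameters')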
--------------------------------------------------------------------------------
/src/models/vgg.py:
--------------------------------------------------------------------------------
'''
Modified from https://github.com/pytorch/vision.git
'''
import math

import torch.nn as nn
import torch.nn.init as init

__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]


class VGG(nn.Module):
    '''
    VGG model
    '''
    def __init__(self, features, feature_size=512, num_classes=10):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(feature_size, 128),
            nn.ReLU(True),
            nn.Linear(128, num_classes),
        )
        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                m.bias.data.zero_()

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


cfg = {
    'A': [16, 'M', 16, 'M', 32, 'M', 64, 'M', 64, 'M'],
    'A1': [16, 'M', 32, 'M', 32, 32, 'M', 64, 64, 'M', 128, 128, 'M'],
    'A2': [32, 'M', 64, 'M', 64, 64, 'M', 128, 128, 'M', 256, 256, 'M'],
    'A3': [64, 'M', 128, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M'],
    'A4': [128, 'M', 256, 'M', 256, 256, 'M', 512, 512, 'M', 1024, 1024, 'M'],
    'B': [16, 16, 'M', 32, 32, 'M', 64, 64, 'M', 128, 128, 'M', 128, 128, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
          512, 512, 512, 512, 'M'],
}


def vgg11(num_classes=10):
    """VGG 11-layer model (configuration "A")"""
    return VGG(make_layers(cfg['A']), feature_size=64, num_classes=num_classes)


def vgg11_big(num_classes=10):
    """VGG 11-layer model (configuration "A3")"""
    return VGG(make_layers(cfg['A3']), cfg['A3'][-2], num_classes)


def vgg11_bn(num_classes=10):
    """VGG 11-layer model (configuration "A") with batch normalization"""
    # bug fix: num_classes and feature_size were previously dropped
    return VGG(make_layers(cfg['A'], batch_norm=True), feature_size=64, num_classes=num_classes)


def vgg13(num_classes=10):
    """VGG 13-layer model (configuration "B")"""
    # bug fix: num_classes was previously passed positionally as feature_size;
    # configuration "B" ends with 128 channels at 1x1 for 32x32 inputs
    return VGG(make_layers(cfg['B']), feature_size=128, num_classes=num_classes)


def vgg13_bn(num_classes=10):
    """VGG 13-layer model (configuration "B") with batch normalization"""
    return VGG(make_layers(cfg['B'], batch_norm=True), feature_size=128, num_classes=num_classes)


def vgg16(num_classes=10):
    """VGG 16-layer model (configuration "D")"""
    return VGG(make_layers(cfg['D']), num_classes=num_classes)


def vgg16_bn(num_classes=10):
    """VGG 16-layer model (configuration "D") with batch normalization"""
    return VGG(make_layers(cfg['D'], batch_norm=True), num_classes=num_classes)


def vgg19(num_classes=10):
    """VGG 19-layer model (configuration "E")"""
    return VGG(make_layers(cfg['E']), num_classes=num_classes)


def vgg19_bn(num_classes=10):
    """VGG 19-layer model (configuration 'E') with batch normalization"""
    return VGG(make_layers(cfg['E'], batch_norm=True), num_classes=num_classes)
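# Note on feature_size (an editorial sketch of the arithmetic): for 32x32
# CIFAR inputs, each 'M' entry halves the spatial resolution, so a config with
# five 'M' entries ends at 1x1 and feature_size equals the last conv width,
# e.g. 64 for 'A', 128 for 'B', and 512 for 'D'/'E'. A quick shape check:
if __name__ == '__main__':
    import torch
    net = vgg11(num_classes=2)                # as built by src/utils.load_net for the two-class CIFAR-10 subset
    x = torch.randn(1, 3, 32, 32)
    print(net.features(x).view(1, -1).shape)  # -> torch.Size([1, 64])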
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
import time
import torch


def train(model, criterion, optimizer, dataloader, batch_size, n_iters=50000, verbose=True):
    model.train()
    acc_avg, loss_avg = 0, 0

    since = time.time()
    for iter_now in range(n_iters):
        optimizer.zero_grad()
        loss, acc = compute_minibatch_gradient(model, criterion, dataloader, batch_size)
        optimizer.step()

        # exponential moving averages of the training statistics
        acc_avg = 0.9 * acc_avg + 0.1 * acc if acc_avg > 0 else acc
        loss_avg = 0.9 * loss_avg + 0.1 * loss if loss_avg > 0 else loss

        if iter_now % 200 == 0 and verbose:
            now = time.time()
            print('%d/%d, took %.0f seconds, train_loss: %.1e, train_acc: %.2f' % (
                iter_now+1, n_iters, now-since, loss_avg, acc_avg))
            since = time.time()


def compute_minibatch_gradient(model, criterion, dataloader, batch_size):
    # accumulate gradients over n_loads small loads (of dataloader.batch_size
    # examples each) to form one SGD minibatch of size batch_size
    loss, acc = 0, 0
    n_loads = batch_size // dataloader.batch_size

    for i in range(n_loads):
        inputs, targets = next(dataloader)
        inputs, targets = inputs.cuda(), targets.cuda()

        logits = model(inputs)
        E = criterion(logits, targets)
        E.backward()

        loss += E.item()
        acc += accuracy(logits.data, targets)

    # rescale the accumulated gradients to the minibatch average
    for p in model.parameters():
        p.grad.data /= n_loads

    return loss/n_loads, acc/n_loads


def accuracy(logits, targets):
    n = logits.shape[0]
    if targets.ndimension() == 2:
        # one-hot targets: recover the class indices
        _, y_trues = torch.max(targets, 1)
    else:
        y_trues = targets
    _, y_preds = torch.max(logits, 1)

    acc = (y_trues == y_preds).float().sum()*100.0/n
    return acc
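# Worked example of the accumulation above (editorial note): with
# batch_size=1000 and a dataloader batch_size (--load_size) of 100,
# n_loads = 10, so ten forward/backward passes are accumulated and the
# gradients divided by 10 -- equivalent to one SGD step on a minibatch of
# 1000 examples, but with roughly one tenth of the peak GPU memory.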
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
import math
from .models.vgg import vgg11
from .models.mnist import fnn
from .data import load_fmnist, load_cifar10
from .trainer import accuracy
from .linalg import eigen_variance, eigen_hessian


def load_net(dataset):
    if dataset == 'fashionmnist':
        return fnn().cuda()
    elif dataset == 'cifar10':
        return vgg11(num_classes=2).cuda()
    else:
        raise ValueError('Dataset %s is not supported' % dataset)


def load_data(dataset, training_size, batch_size):
    if dataset == 'fashionmnist':
        return load_fmnist(training_size, batch_size)
    elif dataset == 'cifar10':
        return load_cifar10(training_size, batch_size)
    else:
        raise ValueError('Dataset %s is not supported' % dataset)


def get_sharpness(net, criterion, dataloader, n_iters=10, tol=1e-2, verbose=False):
    # sharpness: the largest eigenvalue of the training-loss Hessian
    v = eigen_hessian(net, criterion, dataloader,
                      n_iters=n_iters, tol=tol, verbose=verbose)
    return v


def get_nonuniformity(net, criterion, dataloader, n_iters=10, tol=1e-2, verbose=False):
    # non-uniformity: sqrt of the largest eigenvalue of the variance operator
    v = eigen_variance(net, criterion, dataloader,
                       n_iters=n_iters, tol=tol, verbose=verbose)
    return math.sqrt(v)


def eval_accuracy(model, criterion, dataloader):
    model.eval()
    n_batchs = len(dataloader)
    dataloader.idx = 0

    loss_t, acc_t = 0.0, 0.0
    for i in range(n_batchs):
        inputs, targets = next(dataloader)
        inputs, targets = inputs.cuda(), targets.cuda()

        logits = model(inputs)
        loss_t += criterion(logits, targets).item()
        acc_t += accuracy(logits.data, targets)

    return loss_t/n_batchs, acc_t/n_batchs

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import argparse
import json
import torch

from src.trainer import train
from src.utils import load_net, load_data, eval_accuracy


def get_args():
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument('--gpuid',
                           default='0,', help='gpu id, [0]')
    argparser.add_argument('--dataset',
                           default='fashionmnist', help='dataset, [fashionmnist] | cifar10')
    argparser.add_argument('--n_samples', type=int,
                           default=1000, help='training set size, [1000]')
    argparser.add_argument('--load_size', type=int,
                           default=1000, help='load size for dataset, [1000]')
    argparser.add_argument('--optimizer',
                           default='sgd', help='optimizer, [sgd] | adam')
    argparser.add_argument('--n_iters', type=int,
                           default=10000, help='number of iterations used to train nets, [10000]')
    argparser.add_argument('--batch_size', type=int,
                           default=1000, help='batch size, [1000]')
    argparser.add_argument('--learning_rate', type=float,
                           default=1e-1, help='learning rate, [0.1]')
    argparser.add_argument('--momentum', type=float,
                           default=0.0, help='momentum, [0.0]')
    argparser.add_argument('--model_file',
                           default='fnn.pkl', help='filename to save the net, [fnn.pkl]')

    args = argparser.parse_args()
    if args.load_size > args.batch_size:
        raise ValueError('load size should not be larger than batch size')
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpuid

    print('===> Config:')
    print(json.dumps(vars(args), indent=2))
    return args


def get_optimizer(net, args):
    if args.optimizer == 'sgd':
        return torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=args.momentum)
    elif args.optimizer == 'adam':
        return torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    else:
        raise ValueError('optimizer %s is not supported' % args.optimizer)


def main():
    args = get_args()

    criterion = torch.nn.MSELoss().cuda()
    train_loader, test_loader = load_data(args.dataset,
                                          training_size=args.n_samples,
                                          batch_size=args.load_size)
    net = load_net(args.dataset)
    optimizer = get_optimizer(net, args)
    print(optimizer)

    print('===> Architecture:')
    print(net)

    print('===> Start training')
    train(net, criterion, optimizer, train_loader, args.batch_size, args.n_iters, verbose=True)

    train_loss, train_accuracy = eval_accuracy(net, criterion, train_loader)
    test_loss, test_accuracy = eval_accuracy(net, criterion, test_loader)
    print('===> Solution: ')
    print('\t train loss: %.2e, acc: %.2f' % (train_loss, train_accuracy))
    print('\t test loss: %.2e, acc: %.2f' % (test_loss, test_accuracy))

    torch.save(net.state_dict(), args.model_file)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------