├── README.md
├── main.py
├── models.py
├── run.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# SGD Tail-Index Analysis
On the importance of noise in SGD

Please use run.py as a template to run batches of experiments and reproduce the results in the paper. Make sure to replace the launcher command with one specific to your server.

If you find this code useful, please cite the following work:

> @inproceedings{Simsekli2019Tail,
  author = {U. Simsekli and L. Sagun and M. Gurbuzbalaban},
  title = {A Tail-Index Analysis of Stochastic Gradient Noise in Deep Neural Networks},
  booktitle = {Proceedings of the 36th International Conference on Machine Learning (ICML), 2019},
  address = {Long Beach, CA, USA},
  month = jun,
  year = {2019},
}

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import math
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from models import alexnet, fc
from utils import get_data, accuracy
from utils import get_grads, alpha_estimator, alpha_estimator2
from utils import linear_hinge_loss, get_layerWise_norms


def eval(eval_loader, net, crit, opt, args, test=True):

    net.eval()

    # run over both test and train set
    total_size = 0
    total_loss = 0
    total_acc = 0
    grads = []
    outputs = []

    P = 0  # number of batches (num samples / batch size)
    for x, y in eval_loader:
        P += 1
        # loop over dataset
        x, y = x.to(args.device), y.to(args.device)
        opt.zero_grad()
        out = net(x)

        outputs.append(out)

        loss = crit(out, y)
        prec = accuracy(out, y)
        bs = x.size(0)

        loss.backward()
        grad = get_grads(net).cpu()
        grads.append(grad)

        total_size += int(bs)
        total_loss += float(loss) * bs
        total_acc += float(prec) * bs

    M = len(grads[0])  # total number of parameters
    grads = torch.cat(grads).view(-1, M)
    mean_grad = grads.sum(0) / P
    noise_norm = (grads - mean_grad).norm(dim=1)

    N = M * P

    # pick the largest divisor of N that is at most sqrt(N)
    for i in range(1, 1 + int(math.sqrt(N))):
        if N % i == 0:
            m = i
    alpha = alpha_estimator(m, (grads - mean_grad).view(-1, 1))

    del grads
    del mean_grad

    hist = [
        total_loss / total_size,
        total_acc / total_size,
        alpha.item()
    ]

    print(hist)

    return hist, outputs, noise_norm


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--iterations', default=100000, type=int)
    parser.add_argument('--batch_size_train', default=100, type=int)
    parser.add_argument('--batch_size_eval', default=100, type=int,
                        help='must be equal to training batch size')
    parser.add_argument('--lr', default=0.1, type=float)
    parser.add_argument('--mom', default=0, type=float)
    parser.add_argument('--wd', default=0, type=float)
    parser.add_argument('--print_freq', default=100, type=int)
    parser.add_argument('--eval_freq', default=100, type=int)
    parser.add_argument('--dataset', default='mnist', type=str,
                        help='mnist | cifar10 | cifar100')
    parser.add_argument('--path', default='./data', type=str)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--model', default='fc', type=str)
    parser.add_argument('--criterion', default='NLL', type=str,
                        help='NLL | linear_hinge')
    parser.add_argument('--scale', default=64, type=int,
                        help='scale of the number of convolutional filters')
    parser.add_argument('--depth', default=3, type=int)
    parser.add_argument('--width', default=100, type=int,
                        help='width of fully connected layers')
    parser.add_argument('--save_dir', default='results/', type=str)
    parser.add_argument('--verbose', action='store_true', default=False)
    parser.add_argument('--double', action='store_true', default=False)
    parser.add_argument('--no_cuda', action='store_true', default=False)
    parser.add_argument('--lr_schedule', action='store_true', default=False)
    args = parser.parse_args()

    # initial setup
    if args.double:
        torch.set_default_tensor_type('torch.DoubleTensor')
    args.use_cuda = not args.no_cuda and torch.cuda.is_available()
    args.device = torch.device('cuda' if args.use_cuda else 'cpu')
    torch.manual_seed(args.seed)

    print(args)

    # training setup
    train_loader, test_loader_eval, train_loader_eval, num_classes = get_data(args)

    if args.model == 'fc':
        if args.dataset == 'mnist':
            net = fc(width=args.width, depth=args.depth, num_classes=num_classes).to(args.device)
        elif args.dataset == 'cifar10':
            net = fc(width=args.width, depth=args.depth, num_classes=num_classes, input_dim=3*32*32).to(args.device)
    elif args.model == 'alexnet':
        net = alexnet(ch=args.scale, num_classes=num_classes).to(args.device)

    print(net)

    opt = optim.SGD(
        net.parameters(),
        lr=args.lr,
        momentum=args.mom,
        weight_decay=args.wd
    )

    if args.lr_schedule:
        milestone = int(args.iterations / 3)
        scheduler = optim.lr_scheduler.MultiStepLR(opt,
                                                   milestones=[milestone, 2*milestone],
                                                   gamma=0.5)

    if args.criterion == 'NLL':
        crit = nn.CrossEntropyLoss().to(args.device)
    elif args.criterion == 'linear_hinge':
        crit = linear_hinge_loss

    def cycle_loader(dataloader):
        while 1:
            for data in dataloader:
                yield data

    circ_train_loader = cycle_loader(train_loader)

    # training logs per iteration
    training_history = []
    weight_grad_history = []

    # eval logs less frequently
    evaluation_history_TEST = []
    evaluation_history_TRAIN = []
    noise_norm_history_TEST = []
    noise_norm_history_TRAIN = []

    STOP = False

    for i, (x, y) in enumerate(circ_train_loader):

        if i % args.eval_freq == 0:
            # first record is at the initial point
            te_hist, te_outputs, te_noise_norm = eval(test_loader_eval, net, crit, opt, args)
            tr_hist, tr_outputs, tr_noise_norm = eval(train_loader_eval, net, crit, opt, args, test=False)
            evaluation_history_TEST.append([i, *te_hist])
            evaluation_history_TRAIN.append([i, *tr_hist])
            noise_norm_history_TEST.append(te_noise_norm)
            noise_norm_history_TRAIN.append(tr_noise_norm)
            if int(tr_hist[1]) == 100:
                print('yaaay all training data is correctly classified!!!')
                STOP = True

        net.train()

        x, y = x.to(args.device), y.to(args.device)

        opt.zero_grad()
        out = net(x)
        loss = crit(out, y)
        # calculate the gradients
        loss.backward()

        # record training history (starts at initial point)
        training_history.append([i, loss.item(), accuracy(out, y).item()])
        weight_grad_history.append([i, *get_layerWise_norms(net)])

        # take the step
        opt.step()

        if i % args.print_freq == 0:
            print(training_history[-1])

        if args.lr_schedule:
            scheduler.step(i)

        if i > args.iterations:
            STOP = True

        if STOP:
            # final evaluation and saving results
            print('eval time {}'.format(i))
            te_hist, te_outputs, te_noise_norm = eval(test_loader_eval, net, crit, opt, args)
            tr_hist, tr_outputs, tr_noise_norm = eval(train_loader_eval, net, crit, opt, args, test=False)
            evaluation_history_TEST.append([i + 1, *te_hist])
            evaluation_history_TRAIN.append([i + 1, *tr_hist])
            noise_norm_history_TEST.append(te_noise_norm)
            noise_norm_history_TRAIN.append(tr_noise_norm)

            if not os.path.exists(args.save_dir):
                os.makedirs(args.save_dir)
            else:
                print('Folder already exists, beware of overriding old data!')

            # save the setup
            torch.save(args, args.save_dir + '/args.info')
            # save the outputs
            torch.save(te_outputs, args.save_dir + '/te_outputs.pyT')
            torch.save(tr_outputs, args.save_dir + '/tr_outputs.pyT')
            # save the model
            torch.save(net, args.save_dir + '/net.pyT')
            # save the logs
            torch.save(training_history, args.save_dir + '/training_history.hist')
            torch.save(weight_grad_history, args.save_dir + '/weight_history.hist')
            torch.save(evaluation_history_TEST, args.save_dir + '/evaluation_history_TEST.hist')
            torch.save(evaluation_history_TRAIN, args.save_dir + '/evaluation_history_TRAIN.hist')
            torch.save(noise_norm_history_TEST, args.save_dir + '/noise_norm_history_TEST.hist')
            torch.save(noise_norm_history_TRAIN, args.save_dir + '/noise_norm_history_TRAIN.hist')

            break

--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
# Model definitions: a fully connected network and a small AlexNet variant
import torch
import torch.nn as nn
import copy

class FullyConnected(nn.Module):

    def __init__(self, input_dim=28*28, width=50, depth=3, num_classes=10):
        super(FullyConnected, self).__init__()
        self.input_dim = input_dim
        self.width = width
        self.depth = depth
        self.num_classes = num_classes

        layers = self.get_layers()

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, self.width, bias=False),
            nn.ReLU(inplace=True),
            *layers,
            nn.Linear(self.width, self.num_classes, bias=False),
        )

    def get_layers(self):
        layers = []
        for i in range(self.depth - 2):
            layers.append(nn.Linear(self.width, self.width, bias=False))
            layers.append(nn.ReLU())
        return layers

    def forward(self, x):
        x = x.view(x.size(0), self.input_dim)
        x = self.fc(x)
        return x


# Adapted from publicly available AlexNet implementations
class AlexNet(nn.Module):

    def __init__(self, input_height=32, input_width=32, input_channels=3, ch=64, num_classes=1000):
        # ch is the scale factor for number of channels
        super(AlexNet, self).__init__()

        self.input_height = input_height
        self.input_width = input_width
        self.input_channels = input_channels
        self.features = nn.Sequential(
            # use the configured number of input channels rather than hard-coding 3
            nn.Conv2d(self.input_channels, out_channels=ch, kernel_size=4, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(ch, ch, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        self.size = self.get_size()
        print(self.size)
        a = torch.tensor(self.size).float()
        b = torch.tensor(2).float()
        self.width = int(a) * int(1 + torch.log(a) / torch.log(b))

        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(self.size, self.width),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(self.width, self.width),
            nn.ReLU(inplace=True),
            nn.Linear(self.width, num_classes),
        )

    def get_size(self):
        # hack to get the size for the FC layer...
        x = torch.randn(1, self.input_channels, self.input_height, self.input_width)
        y = self.features(x)
        print(y.size())
        return y.view(-1).size(0)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


def alexnet(**kwargs):
    return AlexNet(**kwargs)


def fc(**kwargs):
    return FullyConnected(**kwargs)


if __name__ == '__main__':
    # testing

    x = torch.randn(5, 1, 32, 32)
    net = FullyConnected(input_dim=32*32, width=123)
    print(net(x))

    x = torch.randn(5, 3, 32, 32).cuda()
    net = AlexNet(ch=128).cuda()
    print(net(x))

--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import os
import subprocess
import itertools
import time

# folder to save
base_path = 'results_FC_scales'

if not os.path.exists(base_path):
    os.makedirs(base_path)

# server setup
launcher = "srun --nodes=1 --gres=gpu:1 --time=40:00:00 --mem=60G"  # THIS IS AN EXAMPLE!!!

# experimental setup
width = [4, 8, 16, 32, 64, 128, 256, 512, 1024]
depth = [2, 3, 4, 5, 6, 7, 8, 9, 10]
seeds = list(range(3))
dataset = ['mnist', 'cifar10']
loss = ['NLL', 'linear_hinge']
model = ['fc']

grid = itertools.product(width, depth, seeds, dataset, loss, model)

processes = []
for w, dep, s, d, l, m in grid:

    save_dir = base_path + '/{}_{:04d}_{:02d}_{}_{}_{}'.format(dep, w, s, d, l, m)
    if os.path.exists(save_dir):
        # folder created only at the end when all is done!
        print('folder already exists, skipping')
        continue

    cmd = launcher + ' '
    cmd += 'python main.py '
    cmd += '--save_dir {} '.format(save_dir)
    cmd += '--width {} '.format(w)
    cmd += '--depth {} '.format(dep)
    cmd += '--seed {} '.format(s)
    cmd += '--dataset {} '.format(d)
    cmd += '--model {} '.format(m)
    cmd += '--criterion {} '.format(l)  # pass the loss chosen in the grid
    cmd += '--lr {} '.format('0.1')
    cmd += '--lr_schedule '
    cmd += '--iterations {} '.format(65000)
    # cmd += '--print_freq {} '.format(1), # dbg
    # cmd += '--verbose '

    # print(cmd)

    f = open(save_dir + '.log', 'w')
    p = subprocess.Popen(cmd.split(), stdout=f, stderr=f)  # .wait() to run jobs sequentially
    processes.append(p)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# some useful functions

import math
import torch
from torchvision import datasets, transforms
import numpy as np


def get_layerWise_norms(net):
    w = []
    g = []
    for p in net.parameters():
        if p.requires_grad:
            w.append(p.view(-1).norm())
            g.append(p.grad.view(-1).norm())
    return w, g

def linear_hinge_loss(output, target):
    binary_target = output.new_empty(*output.size()).fill_(-1)
    for i in range(len(target)):
        binary_target[i, target[i]] = 1
    delta = 1 - binary_target * output
    delta[delta <= 0] = 0
    return delta.mean()

def get_grads(model):
    # wrt data at the current step
    res = []
    for p in model.parameters():
        if p.requires_grad:
            res.append(p.grad.view(-1))
    grad_flat = torch.cat(res)
    return grad_flat

# Corollary 2.4 in Mohammadi 2014
def alpha_estimator(m, X):
    # X is N by d matrix
    N = len(X)
    n = int(N/m)  # must be an integer
    Y = torch.sum(X.view(n, m, -1), 1)
    eps = np.spacing(1)
    Y_log_norm = torch.log(Y.norm(dim=1) + eps).mean()
    X_log_norm = torch.log(X.norm(dim=1) + eps).mean()
    diff = (Y_log_norm - X_log_norm) / math.log(m)
    return 1 / diff

# Corollary 2.2 in Mohammadi 2014
def alpha_estimator2(m, k, X):
    # X is N by d matrix
    N = len(X)
    n = int(N/m)  # must be an integer
    Y = torch.sum(X.view(n, m, -1), 1)
    eps = np.spacing(1)
    Y_log_norm = torch.log(Y.norm(dim=1) + eps)
    X_log_norm = torch.log(X.norm(dim=1) + eps)

    # This can be implemented more efficiently by using
    # the np.partition function, which currently doesn't
    # exist in pytorch: may consider passing the tensor to np

    Yk = torch.sort(Y_log_norm)[0][k-1]
    Xk = torch.sort(X_log_norm)[0][m*k-1]
    diff = (Yk - Xk) / math.log(m)
    return 1 / diff

def accuracy(out, y):
    _, pred = out.max(1)
    correct = pred.eq(y)
    return 100 * correct.sum().float() / y.size(0)

def get_data(args):

    # mean/std stats
    if args.dataset == 'cifar10':
        data_class = 'CIFAR10'
        num_classes = 10
        stats = {
            'mean': [0.491, 0.482, 0.447],
            'std': [0.247, 0.243, 0.262]
        }
    elif args.dataset == 'cifar100':
        data_class = 'CIFAR100'
        num_classes = 100
        stats = {
            'mean': [0.5071, 0.4867, 0.4408],
            'std': [0.2675, 0.2565, 0.2761]
        }
    elif args.dataset == 'mnist':
        data_class = 'MNIST'
        num_classes = 10
        stats = {
            'mean': [0.1307],
            'std': [0.3081]
        }
    else:
        raise ValueError("unknown dataset")
    # input transformation w/o preprocessing for now

    trans = [
        transforms.ToTensor(),
        lambda t: t.type(torch.get_default_dtype()),
        transforms.Normalize(**stats)
    ]

    # get tr and te data with the same normalization
    tr_data = getattr(datasets, data_class)(
        root=args.path,
        train=True,
        download=True,
        transform=transforms.Compose(trans)
    )

    te_data = getattr(datasets, data_class)(
        root=args.path,
        train=False,
        download=True,
        transform=transforms.Compose(trans)
    )

    # get tr_loader for train/eval and te_loader for eval
    train_loader = torch.utils.data.DataLoader(
        dataset=tr_data,
        batch_size=args.batch_size_train,
        shuffle=False,
    )

    train_loader_eval = torch.utils.data.DataLoader(
        dataset=tr_data,
        batch_size=args.batch_size_eval,
        shuffle=False,
    )

    test_loader_eval = torch.utils.data.DataLoader(
        dataset=te_data,
        batch_size=args.batch_size_eval,
        shuffle=False,
    )

    return train_loader, test_loader_eval, train_loader_eval, num_classes

--------------------------------------------------------------------------------
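
The tail-index estimator in utils.py follows Corollary 2.4 of Mohammadi 2014: the noise vectors are grouped into blocks of size m, and alpha is recovered from how the average log-norm grows after summing within each block. A quick way to convince yourself the estimator behaves as expected is to run it on synthetic data whose true index is known. The snippet below is a minimal sketch, not part of the original repository; it only assumes that utils.py is importable, and uses the facts that Gaussian samples are alpha-stable with alpha = 2 and Cauchy samples with alpha = 1.

    # sanity check for alpha_estimator on synthetic data (sketch, not part of the original repo)
    import torch
    from utils import alpha_estimator

    torch.manual_seed(0)
    N, d, m = 10000, 1, 100  # m must divide N

    # Gaussian samples: true tail index alpha = 2
    gaussian = torch.randn(N, d)
    print(alpha_estimator(m, gaussian))  # should be close to 2

    # Cauchy samples: true tail index alpha = 1
    cauchy = torch.distributions.Cauchy(0.0, 1.0).sample((N, d))
    print(alpha_estimator(m, cauchy))    # should be close to 1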
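
main.py stores its logs with torch.save; each entry of the evaluation histories is a list [iteration, loss, accuracy, estimated alpha]. The following sketch (again not part of the original code, and using a hypothetical results folder name) shows how a finished run can be inspected.

    # inspect a finished run (sketch; 'results/example_run' is a hypothetical --save_dir)
    import torch

    run_dir = 'results/example_run'
    eval_train = torch.load(run_dir + '/evaluation_history_TRAIN.hist')

    for it, loss, acc, alpha in eval_train:
        print('iter {:6d}  train loss {:.4f}  train acc {:.2f}  alpha {:.3f}'.format(it, loss, acc, alpha))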