├── curvature ├── models │ ├── toy_network.py │ ├── __init__.py │ ├── densenet.py │ ├── wide_resnet.py │ ├── vgg.py │ ├── resnext.py │ ├── all_cnn.py │ └── preresnet.py ├── methods │ ├── __init__.py │ ├── shrinkageopt.py │ ├── swag.py │ └── subspaces.py ├── __init__.py ├── losses.py ├── data.py ├── imagenet32.py └── utils.py ├── REQUIREMENTS.txt ├── visualise ├── __init__.py ├── plot_loss_landscape.py ├── plot_training.py └── plot_spectrum.py ├── core ├── __init__.py ├── loss_landscape.py ├── loss_stats.py ├── spectrum.py └── train_network.py ├── optimizers ├── __init__.py ├── adam.py ├── swats.py ├── ekfac.py ├── hessianfree.py └── kfac.py ├── utils ├── network_utils.py ├── data_utils.py └── kfac_utils.py ├── example.py └── README.md /curvature/models/toy_network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | __all__ = ['toy_network'] -------------------------------------------------------------------------------- /curvature/methods/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | subspaces, 3 | swag, 4 | shrinkageopt, 5 | ) 6 | -------------------------------------------------------------------------------- /REQUIREMENTS.txt: -------------------------------------------------------------------------------- 1 | pytorch > 1.1.0 2 | gpytorch 3 | tabulate 4 | argparse 5 | numpy 6 | matplotlib 7 | seaborn 8 | -------------------------------------------------------------------------------- /visualise/__init__.py: -------------------------------------------------------------------------------- 1 | from .plot_spectrum import plot_spectrum 2 | from .plot_loss_landscape import plot_loss_landscape 3 | from .plot_training import plot_training 4 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss_stats import compute_loss_stats 2 | from .spectrum import compute_eigenspectrum 3 | from .train_network import train_network 4 | from .loss_landscape import build_loss_landscape -------------------------------------------------------------------------------- /curvature/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .preresnet import * 2 | from .vgg import * 3 | from .wide_resnet import * 4 | 5 | # Added by Xingchen Wan 20 Oct 6 | from .resnext import * 7 | from .densenet import * 8 | 9 | # Added by Xingchen Wan 2 Dec 10 | from .all_cnn import * 11 | -------------------------------------------------------------------------------- /optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .kfac import KFACOptimizer 4 | from .ekfac import EKFACOptimizer 5 | from .swats import SWATS 6 | 7 | # We use custom-built Adam that integrates Adam and AdamW 8 | from .adam import Adam 9 | 10 | # For SGD, we use the inbuilt pytorch SGD optimiser 11 | from torch.optim import SGD 12 | 
-------------------------------------------------------------------------------- /curvature/__init__.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | #if torchvision.__version__ == "0.2.1": 3 | # print("Older torchvision found") 4 | # from . import data 5 | #else: 6 | from. import data as data 7 | 8 | from . import ( 9 | methods, 10 | models, 11 | losses, 12 | utils, 13 | ) 14 | 15 | __all__ = [ 16 | 'methods', 17 | 'models', 18 | 'data', 19 | 'losses', 20 | 'utils', 21 | ] 22 | -------------------------------------------------------------------------------- /utils/network_utils.py: -------------------------------------------------------------------------------- 1 | from curvature.models.cifar import (alexnet, densenet, resnet, vgg11, vgg11_bn, 2 | vgg16_bn, vgg19_bn, 3 | wrn) 4 | 5 | 6 | def get_network(network, **kwargs): 7 | networks = { 8 | 'alexnet': alexnet, 9 | 'densenet': densenet, 10 | 'resnet': resnet, 11 | 'vgg11': vgg11, 12 | 'vgg11_bn': vgg11_bn, 13 | 'vgg16_bn': vgg16_bn, 14 | 'vgg19_bn': vgg19_bn, 15 | 'wrn': wrn 16 | 17 | } 18 | 19 | return networks[network](**kwargs) 20 | 21 | -------------------------------------------------------------------------------- /curvature/losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | import backpack 4 | 5 | 6 | def cross_entropy(model, input, target, backpacked_model=False): 7 | """ 8 | Evaluate the cross entropy loss. 
9 | :param model: 10 | :param input: 11 | :param target: 12 | :param backpacked_model: if the model uses backpack facility, this toggle will backpack.extend() the 13 | loss function for the additional functionalities 14 | :return: 15 | """ 16 | output = model(input) 17 | if backpacked_model: 18 | lossfunc = torch.nn.CrossEntropyLoss() 19 | lossfunc = backpack.extend(lossfunc) 20 | loss = lossfunc(output, target) 21 | else: 22 | loss = F.cross_entropy(output, target) 23 | return loss, output, {} 24 | 25 | 26 | def cross_entropy_func(model, input, target): 27 | return lambda: model(input), lambda pred: F.cross_entropy(pred, target) -------------------------------------------------------------------------------- /visualise/plot_loss_landscape.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | sns.set_style('whitegrid') 5 | 6 | 7 | def plot_loss_landscape( 8 | path: str, 9 | **plot_kwargs, 10 | ): 11 | """ 12 | Visualise the effect of perturbation on the eigen-directions on the performance of the network 13 | :param path: the path to the losslandscape- files generated by loss_landscape.py 14 | :param plot_kwargs: any plotting keyword argument. Note this will be applied to all plotting functions. 
15 | :return: 16 | """ 17 | a = np.load(path) 18 | n = a['train_acc'].shape[0] 19 | 20 | plt.subplot(221) 21 | for i in range(n): 22 | plt.plot(a['ts'], a['train_acc'][i, :], ".-", 23 | **plot_kwargs) 24 | plt.xlabel('Perturbation') 25 | plt.ylabel('Train Accuracy') 26 | 27 | plt.subplot(222) 28 | for i in range(n): 29 | plt.plot(a['ts'], a['test_acc'][i, :], ".-", **plot_kwargs) 30 | plt.xlabel('Perturbation') 31 | plt.ylabel('Test Accuracy') 32 | 33 | plt.subplot(223) 34 | for i in range(n): 35 | plt.plot(a['ts'], a['train_loss'][i, :], ".-", **plot_kwargs) 36 | plt.xlabel('Perturbation') 37 | plt.ylabel('Train Loss') 38 | 39 | plt.subplot(224) 40 | for i in range(n): 41 | plt.plot(a['ts'], a['test_loss'][i, :], ".-", label='$\lambda = $' + str(a['eigvals'][a['idx'][i]]), 42 | **plot_kwargs) 43 | plt.xlabel('Perturbation') 44 | plt.ylabel('Test Loss') 45 | 46 | plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 47 | 48 | 49 | -------------------------------------------------------------------------------- /visualise/plot_training.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | sns.set_style('whitegrid') 4 | import numpy as np 5 | import pandas as pd 6 | 7 | 8 | def plot_training(dir: str, max_epoch=100, prefix='stats-', swag=False, show_top_5=False): 9 | """ 10 | Visualise the training process including train/test top1/top5 accuracy + loss 11 | :param dir: directories where the statistics files are saved 12 | :param max_epoch: maximum epoch allocated 13 | :param prefix: the prefix to the stats files (default: 'stats-') 14 | :param swag: whether SWAG is enabled 15 | :param show_top_5: whether show Top 5 accuracy in addition to the Top 1 accuracy 16 | :return: 17 | """ 18 | stats = [ 19 | 'train_accuracy', 20 | 'train_top5_accuracy', 21 | 'test_accuracy', 22 | 'test_top5_accuracy', 23 | 'train_loss', 24 | 'test_loss' 25 | ] 26 | x = np.arange(max_epoch) 
27 | df = pd.DataFrame(np.nan, index=x, columns=stats) 28 | for i in range(max_epoch): 29 | a = np.load(dir + prefix + str(i) + ".npz", allow_pickle=True) 30 | for col in stats: 31 | if a[col] != [None]: 32 | df.loc[i, col] = a[col] 33 | else: 34 | df.loc[i, col] = np.nan 35 | if swag: 36 | try: 37 | df.loc[i, 'test_loss'] = a['swag_loss'] 38 | df.loc[i, 'test_accuracy'] = a['swag_accuracy'] 39 | df.loc[i, 'test_top5_accuracy'] = a['top5_accuracy'] 40 | except KeyError: 41 | pass 42 | plt.subplot(121) 43 | sns.lineplot(x, df['train_accuracy'], label='Train Accuracy') 44 | sns.lineplot(x, df['test_accuracy'], label='Test Accuracy') 45 | if show_top_5: 46 | sns.lineplot(x, df['train_top5_accuracy'], label='Train Top5 Accuracy') 47 | sns.lineplot(x, df['test_top5_accuracy'], label='Test Top5 Accuracy') 48 | plt.xlabel('Number of epochs') 49 | plt.ylabel('Accuracy') 50 | plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 51 | 52 | plt.subplot(122) 53 | sns.lineplot(x, df['train_loss'], label='Train loss') 54 | sns.lineplot(x, df['test_loss'], label='Test loss') 55 | plt.xlabel('Number of epochs') 56 | plt.ylabel('Loss') 57 | plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 58 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | # Here we provide an example usage of how to use the MLRG DeepCurvature Package 2 | # ---# Script Format #--- 3 | 4 | from core import * 5 | from visualise import * 6 | import matplotlib.pyplot as plt 7 | 8 | # 1. Train a VGG16 network on CIFAR 100. 
Let's train for 100 epochs (this will take a while - on our test computer with 9 | # an NVIDIA GeForce RTX 2080 Ti, each epoch of training takes ~ 10 seconds) 10 | train_network( 11 | dir='result/VGG16-CIFAR100/', 12 | dataset='CIFAR100', 13 | data_path='data/', 14 | epochs=100, 15 | model='VGG16', 16 | optimizer='SGD', 17 | optimizer_kwargs={ 18 | 'lr': 0.03, 19 | 'momentum': 0.9, 20 | 'weight_decay': 5e-4 21 | } 22 | ) 23 | 24 | # 2. After this step, you should have a bunch of stats- and checkpoint files under the chosen dir. In this case, they 25 | # are stored under ./result/VGG16-CIFAR100. The stats files contain the key training and testing (if 26 | # that epoch is scheduled for testing) information, whereas the checkpoint-00XXX.pt files contain the state_dict of the model 27 | # and the optimizer that we need for later analyses. Let's first visualise the training process 28 | plot_training( 29 | dir='result/VGG16-CIFAR100/', 30 | show_top_5=True 31 | ) 32 | plt.show() 33 | # 3. Let's consider the spectrum on the 100th epoch (last training epoch) 34 | 35 | # Let's first use the Lanczos estimation on the Generalised Gauss-Newton matrix - as a preliminary example, we run 20 36 | # Lanczos iterations 37 | 38 | lanc = compute_eigenspectrum( 39 | dataset='CIFAR100', 40 | data_path='data/', 41 | model='VGG16', 42 | checkpoint_path='result/VGG16-CIFAR100/checkpoint-00100.pt', 43 | save_spectrum_path='result/VGG16-CIFAR100/spectra/spectrum-00100-ggn_lanczos', 44 | save_eigvec=True, 45 | lanczos_iters=20, 46 | curvature_matrix='ggn_lanczos', 47 | ) 48 | 49 | 50 | # 4. Visualise the result using a stem plot 51 | plot_spectrum('lanczos', path='result/VGG16-CIFAR100/spectra/spectrum-00100-ggn_lanczos.npz') 52 | plt.show() 53 | 54 | # 5. 
Visualise loss landscape 55 | build_loss_landscape( 56 | dataset='CIFAR100', 57 | data_path='data/', 58 | model='VGG16', 59 | spectrum_path='result/VGG16-CIFAR100/spectra/spectrum-00100-ggn_lanczos', 60 | checkpoint_path='result/VGG16-CIFAR100/checkpoint-00100.pt', 61 | save_path='result/VGG16-CIFAR100/losslandscape-00100.npz' 62 | ) 63 | 64 | plot_loss_landscape('result/VGG16-CIFAR100/losslandscape-00100.npz') 65 | plt.show() -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | 5 | 6 | def get_transforms(dataset): 7 | transform_train = None 8 | transform_test = None 9 | if dataset == 'cifar10': 10 | transform_train = transforms.Compose([ 11 | transforms.RandomCrop(32, padding=4), 12 | transforms.RandomHorizontalFlip(), 13 | transforms.ToTensor(), 14 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 15 | ]) 16 | 17 | transform_test = transforms.Compose([ 18 | transforms.ToTensor(), 19 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 20 | ]) 21 | 22 | if dataset == 'cifar100': 23 | transform_train = transforms.Compose([ 24 | transforms.RandomCrop(32, padding=4), 25 | transforms.RandomHorizontalFlip(), 26 | transforms.ToTensor(), 27 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 28 | ]) 29 | 30 | transform_test = transforms.Compose([ 31 | transforms.ToTensor(), 32 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 33 | ]) 34 | 35 | assert transform_test is not None and transform_train is not None, 'Error, no dataset %s' % dataset 36 | return transform_train, transform_test 37 | 38 | 39 | def get_dataloader(dataset, train_batch_size, test_batch_size, num_workers=2, root='../data'): 40 | transform_train, transform_test = get_transforms(dataset) 
import torch
from torch.optim.optimizer import Optimizer, required


class ShrinkageOpt(Optimizer):
    """SGD-style optimiser with an extra shrinkage term on the gradient.

    On top of plain SGD with optional momentum, dampening and Nesterov
    momentum, each gradient ``g`` of a parameter ``p`` is combined with the
    shift ``s = p - origin`` (or ``s = p`` when ``origin`` is ``None``):

    * ``wd_mode=True``:  ``g + (1 - alpha) / alpha * mu * s``
    * ``wd_mode=False``: ``alpha * g + (1 - alpha) * mu * s``

    ``alpha`` is clipped from below by ``clip_alpha`` before use, and the
    shrinkage term is skipped entirely when the clipped ``alpha`` equals 1.

    Arguments:
        params (iterable): parameters to optimise or dicts defining parameter groups
        lr (float): learning rate (required)
        momentum (float, optional): momentum factor (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        alpha (float, optional): shrinkage coefficient (default: 1.0, i.e. no shrinkage)
        mu (float, optional): scale of the shrinkage term (default: 0.0)
        clip_alpha (float, optional): lower bound applied to ``alpha`` (default: 0.01)
        origin (Tensor, optional): point the parameters are shrunk towards
            (default: None, i.e. the zero vector)
        wd_mode (bool, optional): selects between the two formulas above (default: True)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    NOTE(review): ``origin`` is stored once per parameter group and subtracted
    from every parameter of that group, so it has to broadcast against each of
    them -- confirm this is intended when a group holds several tensors.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 alpha=1.0, mu=0.0, clip_alpha=0.01, origin=None, wd_mode=True, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        alpha=alpha, mu=mu, nesterov=nesterov, origin=origin)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)
        # These two settings are optimiser-wide rather than per parameter group.
        self.wd_mode = wd_mode
        self.clip_alpha = clip_alpha

    def __setstate__(self, state):
        super().__setstate__(state)
        # Restored param groups that lack the key fall back to plain momentum.
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        The stored gradients (``p.grad``) are left untouched: the shrinkage
        term is added to a copy, so calling ``step`` again on the same
        gradients does not compound the regulariser.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            alpha = max(group['alpha'], self.clip_alpha)
            mu = group['mu']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            origin = group['origin']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if alpha != 1.0:
                    shift = p.data if origin is None else p.data - origin.data
                    # Out-of-place ops: d_p must stop aliasing p.grad before it is modified.
                    if self.wd_mode:
                        d_p = d_p.add(shift, alpha=(1.0 - alpha) / alpha * mu)
                    else:
                        d_p = d_p.mul(alpha).add_(shift, alpha=(1.0 - alpha) * mu)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        # First step: the buffer starts as a copy of the current update.
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.data.add_(d_p, alpha=-group['lr'])

        return loss
6 | 7 | If you find our package useful for your research, please consider citing the following: 8 | 9 | MLRG Deep Curvature. Diego Granziol*, Xingchen Wan*, Timur Garipov*. In arXiv preprint: arXiv:1912.09656. 2019. 10 | 11 | ## Network training and evaluation 12 | 13 | The package provides a range of pre-built, popular modern neural network structures, such as VGG [3] and variants of ResNets [4], and various optimisation schemes in addition to the ones already present in the PyTorch framework, such as K-FAC [5] and SWATS [6]. These facilitate faster training and evaluation of the networks (although it is worth noting that any PyTorch-compatible optimisers or architectures can be easily integrated into its analysis framework). 14 | 15 | ## Eigenspectrum analysis of the curvature matrices 16 | 17 | Powered by the Lanczos technique, the package starts from a single random vector and uses the Pearlmutter matrix-vector product trick to quickly estimate the eigenvalues and eigenvectors of the common curvature matrices of deep neural networks. In addition to the standard Hessian matrix, the Generalised Gauss-Newton (GGN) matrix is also supported. 18 | 19 | ## Advanced Statistics of Networks 20 | 21 | In addition to the commonly used statistics to evaluate network training and performance such as the training and testing losses and accuracy, the package supports computations of more advanced statistics, such as squared mean and variance of gradients and Hessians (and GGN), squared norms of Hessian and GGN, L2 and L-inf norms of the network weights, etc. These statistics are useful and relevant for a wide range of purposes such as the design of second-order optimisers and network architectures. 22 | 23 | ## Visualisations 24 | 25 | For all main features above, accompanying visualisation tools are included. 
In addition, using the eigen-information obtained, the package supports visualisation of the loss landscape by studying the sensitivity of the neural network to perturbations of its weights. One key difference is that, instead of random directions as featured in some other packages, this package perturbs the weights explicitly along the eigenvector directions. 26 | 27 | For an illustrated example of its use, please see example.ipynb. 28 | 29 | 30 | 31 | 32 | References: 33 | 34 | 1. Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L. and Desmaison, A., 2019. PyTorch: An imperative style, high-performance deep learning library. In Advances in Neural Information Processing Systems (pp. 8024-8035). 35 | 36 | 37 | 2. Gardner, J., Pleiss, G., Weinberger, K.Q., Bindel, D. and Wilson, A.G., 2018. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Advances in Neural Information Processing Systems (pp. 7576-7586). 38 | 39 | 40 | 3. Simonyan, K. and Zisserman, A., 2014. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556. 41 | 42 | 4. He, K., Zhang, X., Ren, S. and Sun, J., 2016. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778). 43 | 44 | 5. Martens, J. and Grosse, R., 2015, June. Optimizing neural networks with Kronecker-factored approximate curvature. In International conference on machine learning (pp. 2408-2417). 45 | 46 | 6. Keskar, N.S. and Socher, R., 2017. Improving generalization performance by switching from Adam to SGD. arXiv preprint arXiv:1712.07628.
import numpy as np
import matplotlib.pyplot as plt


def plot_spectrum(mode, *args, **kwargs):
    """
    Plot an eigenspectrum with the routine selected by ``mode``.
    :param mode: which kind of spectrum result to plot; only 'lanczos' is implemented
    :param args: positional arguments forwarded to the mode-specific plotting function
    :param kwargs: keyword arguments forwarded to the mode-specific plotting function
    :raises ValueError: if ``mode`` is not recognised
    """
    if mode == 'lanczos':
        plot_spectrum_lanczos(*args, **kwargs)
    else:
        raise ValueError('Mode ' + str(mode) + " is not understood.")


def plot_spectrum_lanczos(
        result: dict = None,
        path: str = None,
        display_spectrum_stats: bool = True,
):
    """
    Generate a stem plot of the eigenspectrum, if we are using Lanczos
    Parameters
    ----------
    result: dict: the return values from core/spectrum; must provide 'eigvals' (indexed as
        [i, 0]) and 'gammas' (indexed as [i]). Takes precedence over ``path``.
    path: str: the path string to the saved spectrum result, loaded with ``np.load``
    display_spectrum_stats: if True, a set of on-screen statistics of the eigenspectrum will be displayed

    Returns
    -------
    None. The plot is drawn on the current matplotlib figure.

    Raises
    ------
    ValueError: if neither ``result`` nor ``path`` is given, or if the spectrum is empty
    """
    if result is not None:
        a = result
    elif path is not None:
        a = np.load(path)
    else:
        raise ValueError('Either result or path needs to be non-empty.')

    # Fetch each entry once: indexing a loaded .npz archive re-reads the array on every access.
    eigvals, gammas = a['eigvals'], a['gammas']
    n_ritz = len(eigvals)
    if n_ritz == 0:
        raise ValueError('The spectrum contains no eigenvalues to plot.')
    eig = [eigvals[i, 0] for i in range(n_ritz)]
    weight = [gammas[i] for i in range(n_ritz)]

    # The line format is given once, by keyword; the original additionally passed a positional
    # '-' format string, which is ambiguous next to the keyword.
    markerline, stemlines, baseline = plt.stem(eig, weight, linefmt='black')
    plt.xlabel('Eigenvalue Size')
    plt.ylabel('Spectral Density')

    # setting property of baseline with color red and linewidth 2
    plt.yscale('log')
    # plt.xscale('symlog')
    plt.xscale('linear')
    plt.setp(baseline, color='r', linewidth=2)
    # NOTE(review): these rcParams are changed after the stem plot has been drawn and are not
    # restored afterwards -- presumably they are meant for subsequent figures; confirm.
    plt.rcParams["figure.figsize"] = (10, 3)
    plt.rcParams.update({'font.size': 16})
    plt.rc('axes', titlesize=16)
    plt.rc('xtick', labelsize=16)

    # Four ticks from the smallest to the largest Ritz value. A degenerate spectrum (all values
    # equal) gets a single tick instead of asking np.arange for a zero step.
    lo, hi = min(eig), max(eig)
    if hi > lo:
        plt.xticks(np.linspace(lo, hi, 4))
    else:
        plt.xticks([lo])
    # Booleans, not the strings 'on'/'off': a non-empty string such as 'off' is truthy.
    plt.tick_params(labelbottom=True, labeltop=False)

    if display_spectrum_stats:
        _print_spectrum_stats(eig, weight)


def _print_spectrum_stats(eig, weight):
    """Print summary statistics of the Ritz values ``eig`` and their weights ``weight``."""
    eig = np.asarray(eig)
    weight = np.asarray(weight)

    print('\n Spectral Statistics')
    print('Maximum Value is ' + str(eig.max()))
    print('Minimum Value is ' + str(eig.min()))
    # NOTE(review): the label says "Mean" but the value printed is the median -- confirm intent.
    print('Mean of Bulk is ' + str(np.median(eig)))
    print('number of negative eigenvalues')

    # NOTE(review): the original indentation was not recoverable; following the printed label,
    # only negative Ritz values with weight below 0.6 contribute to the summed weight.
    negeigs = sum(1 for e in eig if e < 0)
    negweight = sum(w for e, w in zip(eig, weight) if e < 0 and w < 0.6)
    print('number of negative Ritz values')
    print(negeigs)
    print('weight of negative Ritz values')
    print(negweight)
    print('pseudo log determinant')
    print(np.sum(np.log(np.abs(eig)) * weight))
    print('trace')
    print(np.dot(eig, weight))
    print('\n')

    # Up to three Ritz values with the largest weights; the stable sort resolves ties by
    # position, like a repeated argmax would.
    for idx in np.argsort(-weight, kind='stable')[:3]:
        print('degeneracy value = ' + str(weight[idx]) + ' at eigenvalue ' + str(eig[idx]))

    # Extremes are taken over the full spectrum so they agree with the maximum and minimum
    # values reported above.
    top, bottom = np.argmax(eig), np.argmin(eig)
    print('degeneracy of largest eigenvalue = ' + str(weight[top]) + ' value = ' + str(eig[top]))
    print('degeneracy of smallest eigenvalue = ' + str(weight[bottom]) + ' value = ' + str(eig[bottom]))
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Bottleneck(nn.Module):
    """Dense layer: BN-ReLU-Conv1x1 followed by BN-ReLU-Conv3x3.

    The ``growth_rate`` feature maps it produces are concatenated with its
    input along the channel axis, so the output has
    ``in_planes + growth_rate`` channels and the same spatial size.
    """

    def __init__(self, in_planes, growth_rate):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(4*growth_rate)
        self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        out = torch.cat([out, x], 1)
        return out


class Transition(nn.Module):
    """BN-ReLU-Conv1x1 that maps to ``out_planes`` channels, then 2x2 average pooling."""

    def __init__(self, in_planes, out_planes):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_planes)
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.conv(F.relu(self.bn(x)))
        out = F.avg_pool2d(out, 2)
        return out


class DenseNet(nn.Module):
    """DenseNet with four dense stages separated by three transitions.

    :param block: layer class used inside the dense stages, called as ``block(in_planes, growth_rate)``
    :param nblocks: sequence of four ints, the number of layers in each dense stage
    :param growth_rate: number of channels every layer adds
    :param reduction: factor applied (with flooring) to the channel count at each transition
    :param num_classes: size of the final linear layer's output
    """

    def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):
        super().__init__()
        self.growth_rate = growth_rate

        num_planes = 2*growth_rate
        self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)

        self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
        num_planes += nblocks[0]*growth_rate
        out_planes = int(math.floor(num_planes*reduction))
        self.trans1 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
        num_planes += nblocks[1]*growth_rate
        out_planes = int(math.floor(num_planes*reduction))
        self.trans2 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
        num_planes += nblocks[2]*growth_rate
        out_planes = int(math.floor(num_planes*reduction))
        self.trans3 = Transition(num_planes, out_planes)
        num_planes = out_planes

        # The last dense stage is followed by the classifier head, not by a transition.
        self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
        num_planes += nblocks[3]*growth_rate

        self.bn = nn.BatchNorm2d(num_planes)
        self.linear = nn.Linear(num_planes, num_classes)

    def _make_dense_layers(self, block, in_planes, nblock):
        """Stack ``nblock`` layers; each one sees ``growth_rate`` more input channels than the last."""
        layers = []
        for _ in range(nblock):
            layers.append(block(in_planes, self.growth_rate))
            in_planes += self.growth_rate
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.trans3(self.dense3(out))
        out = self.dense4(out)
        # The three transitions reduce a 32x32 input to 4x4, which the 4x4 pooling window
        # collapses to 1x1 so the flattened features match the linear layer.
        out = F.avg_pool2d(F.relu(self.bn(out)), 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def DenseNet121(num_classes=10):
    """DenseNet-121 layout (6, 12, 24, 16 layers) with growth rate 32."""
    return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=32, num_classes=num_classes)


def DenseNet169(num_classes=10):
    """DenseNet-169 layout (6, 12, 32, 32 layers) with growth rate 32."""
    return DenseNet(Bottleneck, [6, 12, 32, 32], growth_rate=32, num_classes=num_classes)


def DenseNet201(num_classes=10):
    """DenseNet-201 layout (6, 12, 48, 32 layers) with growth rate 32."""
    return DenseNet(Bottleneck, [6, 12, 48, 32], growth_rate=32, num_classes=num_classes)


def DenseNet161(num_classes=10):
    """DenseNet-161 layout (6, 12, 36, 24 layers) with growth rate 48."""
    return DenseNet(Bottleneck, [6, 12, 36, 24], growth_rate=48, num_classes=num_classes)


def densenet_cifar(num_classes=10):
    """Smaller variant: DenseNet-121 layout with growth rate 12."""
    return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=12, num_classes=num_classes)


def test():
    """Smoke test: run one random 32x32 RGB image through ``densenet_cifar`` and print the output."""
    net = densenet_cifar()
    x = torch.randn(1, 3, 32, 32)
    y = net(x)
    print(y)

# test()
| __all__ = ['WideResNet28x10'] 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 17 | 18 | 19 | def conv_init(m): 20 | classname = m.__class__.__name__ 21 | if classname.find('Conv') != -1: 22 | init.xavier_uniform(m.weight, gain=math.sqrt(2)) 23 | init.constant(m.bias, 0) 24 | elif classname.find('BatchNorm') != -1: 25 | init.constant(m.weight, 1) 26 | init.constant(m.bias, 0) 27 | 28 | 29 | class WideBasic(nn.Module): 30 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 31 | super(WideBasic, self).__init__() 32 | self.bn1 = nn.BatchNorm2d(in_planes) 33 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 34 | self.dropout = nn.Dropout(p=dropout_rate) 35 | self.bn2 = nn.BatchNorm2d(planes) 36 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 37 | 38 | self.shortcut = nn.Sequential() 39 | if stride != 1 or in_planes != planes: 40 | self.shortcut = nn.Sequential( 41 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 42 | ) 43 | 44 | def forward(self, x): 45 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 46 | out = self.conv2(F.relu(self.bn2(out))) 47 | out += self.shortcut(x) 48 | 49 | return out 50 | 51 | 52 | class WideResNet(nn.Module): 53 | def __init__(self, num_classes=10, depth=28, widen_factor=10, dropout_rate=0.): 54 | super(WideResNet, self).__init__() 55 | self.in_planes = 16 56 | 57 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' 58 | n = (depth - 4) / 6 59 | k = widen_factor 60 | 61 | nstages = [16, 16 * k, 32 * k, 64 * k] 62 | 63 | self.conv1 = conv3x3(3, nstages[0]) 64 | self.layer1 = self._wide_layer(WideBasic, nstages[1], n, dropout_rate, stride=1) 65 | self.layer2 = self._wide_layer(WideBasic, nstages[2], n, dropout_rate, stride=2) 66 | self.layer3 = self._wide_layer(WideBasic, nstages[3], n, dropout_rate, 
stride=2) 67 | self.bn1 = nn.BatchNorm2d(nstages[3], momentum=0.9) 68 | self.linear = nn.Linear(nstages[3], num_classes) 69 | 70 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 71 | strides = [stride] + [1] * int(num_blocks - 1) 72 | layers = [] 73 | 74 | for stride in strides: 75 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 76 | self.in_planes = planes 77 | 78 | return nn.Sequential(*layers) 79 | 80 | def forward(self, x): 81 | out = self.conv1(x) 82 | out = self.layer1(out) 83 | out = self.layer2(out) 84 | out = self.layer3(out) 85 | out = F.relu(self.bn1(out)) 86 | out = F.avg_pool2d(out, 8) 87 | out = out.view(out.size(0), -1) 88 | out = self.linear(out) 89 | 90 | return out 91 | 92 | 93 | class WideResNet28x10: 94 | base = WideResNet 95 | args = list() 96 | kwargs = {'depth': 28, 'widen_factor': 10} 97 | transform_train = transforms.Compose([ 98 | transforms.Resize(32), 99 | transforms.RandomCrop(32, padding=4), 100 | transforms.RandomHorizontalFlip(), 101 | transforms.ToTensor(), 102 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 103 | ]) 104 | transform_test = transforms.Compose([ 105 | transforms.Resize(32), 106 | transforms.ToTensor(), 107 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 108 | ]) 109 | -------------------------------------------------------------------------------- /optimizers/adam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | # An adapted adam optimiser combining Adam, AdamW and Adam with L2 regularisation 6 | # Incomplete 7 | 8 | 9 | class Adam(Optimizer): 10 | r"""Implements Adam algorithm. 11 | 12 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 
class Adam(Optimizer):
    r"""Implements Adam, optionally with AMSGrad and decoupled weight decay (AdamW).

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 0).
            Applied as an L2 penalty added to the gradient unless
            ``decoupled_wd`` is set.
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        decoupled_wd (boolean, optional): if True, apply weight decay directly
            to the parameters (``p *= 1 - lr * weight_decay``) instead of
            adding it to the gradient (default: False)

    Raises:
        ValueError: if ``lr``, ``eps``, ``betas`` or ``weight_decay`` is out of range.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False,
                 decoupled_wd=False,):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(Adam, self).__init__(params, defaults)
        self.decoupled_wd = decoupled_wd

    def __setstate__(self, state):
        super(Adam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            eps = group['eps']
            weight_decay = group['weight_decay']
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                if self.decoupled_wd and weight_decay != 0:
                    # AdamW: shrink the weights directly, independent of the gradient.
                    p.data.mul_(1 - lr * weight_decay)

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                if not self.decoupled_wd and weight_decay != 0:
                    # Out-of-place add: ``p.grad`` itself must stay untouched so
                    # that gradient accumulation, clipping or logging done by
                    # the caller does not see the L2 term baked in.
                    grad = grad.add(p.data, alpha=weight_decay)

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

                step_size = lr / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
class SWAG(torch.nn.Module):
    """Running first/second moments of a model's flattened weights plus a low-rank subspace.

    :param base: callable building the wrapped model, called as ``base(*args, **kwargs)``.
    :param subspace_type: key passed to ``Subspace.create`` to pick the subspace class.
    :param subspace_kwargs: extra keyword arguments for the subspace constructor.
    :param var_clamp: lower bound applied to the diagonal variance estimate.
    """

    def __init__(self, base, subspace_type,
                 subspace_kwargs=None, var_clamp=1e-6, *args, **kwargs):
        super(SWAG, self).__init__()

        self.base_model = base(*args, **kwargs)
        self.num_parameters = sum(param.numel() for param in self.base_model.parameters())

        self.register_buffer('mean', torch.zeros(self.num_parameters))
        self.register_buffer('sq_mean', torch.zeros(self.num_parameters))
        self.register_buffer('n_models', torch.zeros(1, dtype=torch.long))

        # Initialize subspace
        if subspace_kwargs is None:
            subspace_kwargs = dict()
        self.subspace = Subspace.create(subspace_type, num_parameters=self.num_parameters,
                                        **subspace_kwargs)

        self.var_clamp = var_clamp

        self.cov_factor = None
        self.model_device = 'cpu'

    # dont put subspace on cuda
    def cuda(self, device=None):
        """Move only the wrapped model to CUDA; returns ``self`` like ``nn.Module.cuda``."""
        self.model_device = 'cuda'
        self.base_model.cuda(device=device)
        return self

    def to(self, *args, **kwargs):
        """Move/cast the wrapped model; the subspace is cast but always kept on CPU.

        Returns ``self`` so that ``model = model.to(device)`` keeps working.
        """
        self.base_model.to(*args, **kwargs)
        # ``_parse_to`` returns more than three values on newer PyTorch
        # releases; only the first three are needed here.
        device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)[:3]
        if device is not None:  # e.g. ``to(torch.float64)`` carries no device
            self.model_device = device.type
        self.subspace.to(device=torch.device('cpu'), dtype=dtype, non_blocking=non_blocking)
        return self

    def forward(self, *args, **kwargs):
        return self.base_model(*args, **kwargs)

    def collect_model(self, base_model, *args, **kwargs):
        """Fold the current weights of ``base_model`` into the running statistics."""
        # need to refit the space after collecting a new model
        self.cov_factor = None

        w = torch.cat([param.detach().cpu().view(-1) for param in base_model.parameters()])
        n = self.n_models.item()
        # first moment
        self.mean.mul_(n / (n + 1.0))
        self.mean.add_(w / (n + 1.0))

        # second moment
        self.sq_mean.mul_(n / (n + 1.0))
        self.sq_mean.add_(w ** 2 / (n + 1.0))

        # Deviation from the *updated* mean is what the subspace collects.
        dev_vector = w - self.mean

        self.subspace.collect_vector(dev_vector, *args, **kwargs)
        self.n_models.add_(1)

    def _get_mean_and_variance(self):
        variance = torch.clamp(self.sq_mean - self.mean ** 2, self.var_clamp)
        return self.mean, variance

    def fit(self):
        """Materialise the covariance factor from the subspace (cached until the next collect)."""
        if self.cov_factor is not None:
            return
        self.cov_factor = self.subspace.get_space()

    def set_swa(self):
        """Load the running mean into the wrapped model."""
        set_weights(self.base_model, self.mean, self.model_device)

    def sample(self, scale=0.5, diag_noise=True):
        """Draw a weight sample, load it into the wrapped model and return it.

        :param scale: variance scale; the noise is multiplied by ``scale ** 0.5``.
        :param diag_noise: whether to add the diagonal Gaussian component.
        :return: the sampled flat weight vector.
        """
        self.fit()
        mean, variance = self._get_mean_and_variance()

        eps_low_rank = torch.randn(self.cov_factor.size()[0])
        z = self.cov_factor.t() @ eps_low_rank
        if diag_noise:
            # A Gaussian with diagonal covariance ``variance`` has standard
            # deviation ``sqrt(variance)``; multiplying by the variance itself
            # badly mis-scales the noise.
            z += variance.sqrt() * torch.randn_like(variance)
        z *= scale ** 0.5
        sample = mean + z

        # apply to parameters
        set_weights(self.base_model, sample, self.model_device)
        return sample

    def get_space(self, export_cov_factor=True):
        """Return copies of ``(mean, variance)`` and, if requested, the covariance factor."""
        mean, variance = self._get_mean_and_variance()
        if not export_cov_factor:
            return mean.clone(), variance.clone()
        else:
            self.fit()
            return mean.clone(), variance.clone(), self.cov_factor.clone()


class SWA(torch.nn.Module):
    """Running average of a model's flattened weights (no covariance information).

    :param base: callable building the wrapped model, called as ``base(*args, **kwargs)``.
    """

    def __init__(self, base, *args, **kwargs):
        super(SWA, self).__init__()

        self.base_model = base(*args, **kwargs)
        self.num_parameters = sum(param.numel() for param in self.base_model.parameters())

        self.register_buffer('mean', torch.zeros(self.num_parameters))
        self.register_buffer('n_models', torch.zeros(1, dtype=torch.long))

        self.model_device = 'cpu'

    def cuda(self, device=None):
        """Move only the wrapped model to CUDA; returns ``self`` like ``nn.Module.cuda``."""
        self.model_device = 'cuda'
        self.base_model.cuda(device=device)
        return self

    def to(self, *args, **kwargs):
        """Move/cast the wrapped model and remember its device type; returns ``self``."""
        self.base_model.to(*args, **kwargs)
        # Newer PyTorch returns more than three values from ``_parse_to``.
        device = torch._C._nn._parse_to(*args, **kwargs)[0]
        if device is not None:
            self.model_device = device.type
        return self

    def forward(self, *args, **kwargs):
        return self.base_model(*args, **kwargs)

    def collect_model(self, base_model, *args, **kwargs):
        """Fold the current weights of ``base_model`` into the running mean."""
        w = torch.cat([param.detach().cpu().view(-1) for param in base_model.parameters()])
        n = self.n_models.item()
        # first moment
        self.mean.mul_(n / (n + 1.0))
        self.mean.add_(w / (n + 1.0))
        self.n_models.add_(1)

    def _get_mean_and_variance(self):
        return self.mean, None

    def fit(self):
        """No-op: there is no subspace to fit."""
        pass

    def set_swa(self):
        """Load the running mean into the wrapped model."""
        set_weights(self.base_model, self.mean, self.model_device)

    def sample(self, scale=0.5, diag_noise=True):
        """No-op kept for interface parity with :class:`SWAG`; returns None."""
        pass
def datasets(
        dataset,
        path,
        transform_train,
        transform_test,
        use_validation=True,
        val_size=5000,
        train_subset=None,
        train_subset_seed=None):
    """Build the train and evaluation datasets for one of the supported benchmarks.

    Relies on the ``train_data`` / ``train_labels`` / ``test_data`` /
    ``test_labels`` attributes of the dataset objects.
    NOTE(review): recent torchvision releases expose ``data`` / ``targets``
    instead - confirm the pinned torchvision version.

    :param dataset: one of ``'CIFAR10'``, ``'CIFAR100'``, ``'MNIST'``, ``'ImageNet32'``.
    :param path: root directory; the data lives in ``path/<dataset lower-case>``.
    :param transform_train: transform attached to the training set.
    :param transform_test: transform attached to the evaluation set.
    :param use_validation: if True, hold out the last ``val_size`` training
        samples as the evaluation set; otherwise use the real test split.
    :param val_size: number of held-out samples (int) or fraction of the
        training set (float smaller than 1).
    :param train_subset: if given, keep only this many training samples.
    :param train_subset_seed: seed used to shuffle before taking the subset;
        without it the first ``train_subset`` samples are kept.
    :return: ``({'train': ..., 'test': ...}, num_classes)``.
    :raises TypeError: if ``val_size`` is neither an int nor a float.
    """
    assert dataset in {'CIFAR10', 'CIFAR100', 'MNIST', 'ImageNet32'}
    print('Loading %s from %s' % (dataset, path))

    path = os.path.join(path, dataset.lower())
    if dataset == 'ImageNet32':
        ds = IMAGENET32
        download = False  # ImageNet32 is never downloaded automatically
    else:
        ds = getattr(torchvision.datasets, dataset)
        download = True
    train_set = ds(root=path, train=True, download=download, transform=transform_train)
    n_train_samples = len(train_set)
    if isinstance(val_size, float):
        assert val_size < 1, "If entered as a float number to represent the fraction " \
                             "of validation data, this number must be smaller than 1."
        val_size = int(n_train_samples * val_size)
    elif not isinstance(val_size, int):
        # ``str + type`` would itself raise a confusing TypeError; use the type name.
        raise TypeError("val_size needs to be either an int or a float, but got " + type(val_size).__name__)
    num_classes = np.max(train_set.train_labels) + 1
    if use_validation:
        print('Using %d samples for validation [deterministic split]' % (val_size))
        # Explicit split index: ``[:-val_size]`` would empty the training set
        # (and ``[-val_size:]`` would take everything) when ``val_size`` is 0.
        split = max(0, len(train_set.train_data) - val_size)
        train_set.train_data = train_set.train_data[:split]
        train_set.train_labels = train_set.train_labels[:split]

        # The held-out part is a second copy of the training split, relabelled
        # as a test set and carrying the evaluation transform.
        test_set = ds(root=path, train=True, download=download, transform=transform_test)
        test_set.train = False
        test_set.test_data = test_set.train_data[split:]
        test_set.test_labels = test_set.train_labels[split:]
        delattr(test_set, 'train_data')
        delattr(test_set, 'train_labels')
    else:
        print('You are going to run models on the test set. Are you sure?')
        test_set = ds(root=path, train=False, download=download, transform=transform_test)

    if train_subset is not None:
        order = np.arange(train_set.train_data.shape[0])
        if train_subset_seed is not None:
            rng = np.random.RandomState(train_subset_seed)
            rng.shuffle(order)
        train_set.train_data = train_set.train_data[order[:train_subset]]
        train_set.train_labels = np.array(train_set.train_labels)[order[:train_subset]].tolist()

    print('Using train (%d) + test (%d)' % (train_set.train_data.shape[0], test_set.test_data.shape[0]))

    return \
        {
            'train': train_set,
            'test': test_set
        }, \
        num_classes


def loaders(
        dataset,
        path,
        batch_size,
        num_workers,
        transform_train,
        transform_test,
        use_validation=True,
        val_size=5000,
        shuffle_train=True,
        train_subset=None,
        train_subset_seed=None):
    """Wrap the datasets returned by :func:`datasets` in ``DataLoader`` objects.

    :param batch_size: batch size of both loaders.
    :param num_workers: worker processes of both loaders.
    :param shuffle_train: whether the training loader shuffles; the evaluation
        loader never does.
    :param train_subset: forwarded to :func:`datasets`.
    :param train_subset_seed: forwarded to :func:`datasets`.
    :return: ``({'train': loader, 'test': loader}, num_classes)``.

    The remaining parameters have the same meaning as in :func:`datasets`.
    """

    ds_dict, num_classes = datasets(
        dataset, path, transform_train, transform_test, use_validation=use_validation, val_size=val_size,
        train_subset=train_subset, train_subset_seed=train_subset_seed)

    return \
        {
            'train': torch.utils.data.DataLoader(
                ds_dict['train'],
                batch_size=batch_size,
                shuffle=shuffle_train,
                num_workers=num_workers,
                pin_memory=True
            ),
            'test': torch.utils.data.DataLoader(
                ds_dict['test'],
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=True
            ),
        }, \
        num_classes
class CIFAR10AUG(torch.utils.data.Dataset):
    """Dataset enumerating every (flip, crop offset) variant of each base image.

    Each base image yields ``(2 * pad + 1) ** 2 * 2`` variants: every crop
    offset of the padded image, with and without a horizontal flip.  The
    variants are visited in a fixed pseudo-random order derived from
    ``shuffle_seed``.
    """
    base_class = torchvision.datasets.CIFAR10

    def __init__(self, root, train=True, transform=None, download=False, shuffle_seed=1):
        # The wrapped dataset returns untransformed images; ``transform`` is
        # applied here, after flipping and cropping.
        self.base = self.base_class(root, train=train, transform=None, target_transform=None, download=download)
        self.transform = transform

        self.pad = 4
        offsets_per_axis = 2 * self.pad + 1
        self.size = len(self.base) * offsets_per_axis * offsets_per_axis * 2
        self.order = np.random.RandomState(shuffle_seed).permutation(self.size)

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        """Return ``(image, target)`` for the ``index``-th variant in the shuffled order."""
        offsets_per_axis = 2 * self.pad + 1
        crops_per_image = offsets_per_axis * offsets_per_axis
        variants_per_image = crops_per_image * 2

        # Decode the flat variant id: base image -> flip flag -> crop offsets.
        variant = self.order[index]
        base_index = variant // variants_per_image
        transform_index = variant % variants_per_image
        flip_index = transform_index // crops_per_image
        crop_index = transform_index % crops_per_image
        crop_x = crop_index // offsets_per_axis
        crop_y = crop_index % offsets_per_axis

        img, target = self.base[base_index]

        functional = torchvision.transforms.functional
        if flip_index:
            img = functional.hflip(img)
        img = functional.pad(img, self.pad)
        img = functional.crop(img, crop_x, crop_y, 32, 32)

        if self.transform is not None:
            img = self.transform(img)

        return img, target


class CIFAR100AUG(CIFAR10AUG):
    """Same enumeration of augmented variants, built on CIFAR-100."""
    base_class = torchvision.datasets.CIFAR100
class Subspace(torch.nn.Module, metaclass=abc.ABCMeta):
    """Abstract base of all subspaces, with a string-keyed registry of subclasses."""
    subclasses = {}

    @classmethod
    def register_subclass(cls, subspace_type):
        """Class decorator registering the decorated class under ``subspace_type``."""
        def decorator(subclass):
            cls.subclasses[subspace_type] = subclass
            return subclass
        return decorator

    @classmethod
    def create(cls, subspace_type, **kwargs):
        """Instantiate the subclass registered under ``subspace_type``.

        :raises ValueError: if ``subspace_type`` has not been registered.
        """
        if subspace_type not in cls.subclasses:
            raise ValueError('Bad subspaces type {}'.format(subspace_type))
        return cls.subclasses[subspace_type](**kwargs)

    def __init__(self):
        super(Subspace, self).__init__()

    @abc.abstractmethod
    def collect_vector(self, vector):
        pass

    @abc.abstractmethod
    def get_space(self):
        pass


@Subspace.register_subclass('empty')
class EmptySpace(Subspace):
    """Placeholder subspace that stores nothing and cannot produce a space."""

    def __init__(self, num_parameters, rank=20):
        super(EmptySpace, self).__init__()

        self.num_parameters = num_parameters
        self.rank = rank

    # nothing is stored, so collected vectors are ignored
    def collect_vector(self, vector):
        pass

    def get_space(self):
        raise NotImplementedError


@Subspace.register_subclass('random')
class RandomSpace(Subspace):
    """Fixed random Gaussian subspace of shape ``(rank, num_parameters)``."""

    def __init__(self, num_parameters, rank=20):

        super(RandomSpace, self).__init__()

        self.num_parameters = num_parameters
        self.rank = rank

        self.subspace = torch.randn(rank, num_parameters)

    # random subspace is independent of data
    def collect_vector(self, vector):
        pass

    def get_space(self):
        return self.subspace


@Subspace.register_subclass('covariance')
class CovarianceSpace(Subspace):
    """Sliding window holding the last ``max_rank`` collected vectors as rows."""

    def __init__(self, num_parameters, max_rank=20):
        super(CovarianceSpace, self).__init__()

        self.num_parameters = num_parameters

        self.register_buffer('rank', torch.zeros(1, dtype=torch.long))
        self.register_buffer('cov_mat_sqrt',
                             torch.empty(0, self.num_parameters, dtype=torch.float32))

        self.max_rank = max_rank

    def collect_vector(self, vector):
        """Append ``vector`` as a new row, dropping the oldest row once full."""
        if self.rank.item() + 1 > self.max_rank:
            self.cov_mat_sqrt = self.cov_mat_sqrt[1:, :]
        self.cov_mat_sqrt = torch.cat((self.cov_mat_sqrt, vector.view(1, -1)), dim=0)
        self.rank = torch.clamp(self.rank + 1, max=self.max_rank).view(-1)

    def get_space(self):
        """Return a copy of the stored rows scaled by ``1 / sqrt(n - 1)``."""
        # Guard the normaliser as PCASpace/FreqDirSpace do: with zero or one
        # stored vector ``(n - 1) ** 0.5`` is complex or zero.
        n_rows = self.cov_mat_sqrt.size(0)
        return self.cov_mat_sqrt.clone() / max(1, n_rows - 1) ** 0.5

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Resize the buffer first so that the saved matrix fits when copied in.
        rank = state_dict[prefix + 'rank'].item()
        self.cov_mat_sqrt = self.cov_mat_sqrt.new_empty((rank, self.cov_mat_sqrt.size()[1]))
        super(CovarianceSpace, self)._load_from_state_dict(state_dict, prefix, local_metadata,
                                                           strict, missing_keys, unexpected_keys,
                                                           error_msgs)


@Subspace.register_subclass('pca')
class PCASpace(CovarianceSpace):
    """Covariance window whose space is its leading ``pca_rank`` singular directions.

    :param pca_rank: number of components, or ``'mle'`` to use as many as
        there are stored vectors.
    """

    def __init__(self, num_parameters, pca_rank=20, max_rank=20):
        super(PCASpace, self).__init__(num_parameters, max_rank=max_rank)

        # 'mle' is handled by get_space, so it must pass validation here.
        if pca_rank != 'mle':
            assert isinstance(pca_rank, int)
            assert 1 <= pca_rank <= max_rank

        self.pca_rank = pca_rank

    def get_space(self):
        """Return ``diag(s) @ Vt`` of a randomized truncated SVD of the scaled window."""

        cov_mat_sqrt_np = self.cov_mat_sqrt.clone().numpy()

        # perform PCA on DD'
        cov_mat_sqrt_np /= (max(1, self.rank.item() - 1))**0.5

        if self.pca_rank == 'mle':
            pca_rank = self.rank.item()
        else:
            pca_rank = self.pca_rank

        pca_rank = max(1, min(pca_rank, self.rank.item()))

        # A single randomized SVD is enough; the extra TruncatedSVD fit that
        # used to precede it computed a decomposition that was never read.
        _, s, Vt = randomized_svd(cov_mat_sqrt_np, n_components=pca_rank, n_iter=5)

        return torch.FloatTensor(s[:, None] * Vt)


@Subspace.register_subclass('freq_dir')
class FreqDirSpace(CovarianceSpace):
    """Sketch of the collected vectors that is shrunk via SVD once it holds ``2 * max_rank`` rows."""

    def __init__(self, num_parameters, max_rank=20):
        super(FreqDirSpace, self).__init__(num_parameters, max_rank=max_rank)
        self.register_buffer('num_models', torch.zeros(1, dtype=torch.long))
        self.delta = 0.0
        self.normalized = False

    def collect_vector(self, vector):
        """Append ``vector`` to the sketch, shrinking the sketch first when it is full."""
        if self.rank >= 2 * self.max_rank:
            sketch = self.cov_mat_sqrt.numpy()
            [_, s, Vt] = np.linalg.svd(sketch, full_matrices=False)
            if s.size >= self.max_rank:
                # Subtract the max_rank-th squared singular value from the
                # leading ones and drop the rest.
                current_delta = s[self.max_rank - 1] ** 2
                self.delta += current_delta
                s = np.sqrt(s[:self.max_rank - 1] ** 2 - current_delta)
            self.cov_mat_sqrt = torch.from_numpy(s[:, None] * Vt[:s.size, :])

        self.cov_mat_sqrt = torch.cat((self.cov_mat_sqrt, vector.view(1, -1)), dim=0)
        # Keep the registered (1,) shape: a 0-dim tensor here makes the saved
        # state dict disagree with the buffer shape of a fresh instance.
        self.rank = self.rank.new_full((1,), self.cov_mat_sqrt.size(0))
        self.num_models.add_(1)
        self.normalized = False

    def get_space(self):
        """Return at most ``max_rank`` rows of the SVD-normalised sketch, scaled by ``1 / sqrt(num_models - 1)``."""
        if not self.normalized:
            sketch = self.cov_mat_sqrt.numpy()
            [_, s, Vt] = np.linalg.svd(sketch, full_matrices=False)
            self.cov_mat_sqrt = torch.from_numpy(s[:, None] * Vt)
            self.normalized = True
        curr_rank = min(self.rank.item(), self.max_rank)
        return self.cov_mat_sqrt[:curr_rank].clone() / max(1, self.num_models.item() - 1) ** 0.5
class IMAGENET32(data.Dataset):
    """Downsampled 32x32 ImageNet stored as ``.npz`` batches.

    Args:
        root (string): Root directory of dataset where directory
            ``Imagenet32_train_npz`` exists.
        train (bool, optional): If True, creates dataset from the training
            batches, otherwise from ``val_data.npz``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """
    base_folder = 'Imagenet32_train_npz'
    url = "http://www.image-net.org/image/downsample/Imagenet32_train_npz.zip"
    filename = "Imagenet32_train_npz.zip"
    tgz_md5 = 'b0d308fb0016e41348a90f0ae772ee38'
    # NOTE(review): the checksum of batch 5 is 33 characters long, one more
    # than an md5 hex digest - confirm the value before re-enabling the
    # integrity check in ``__init__``.
    train_list = [
        ['train_data_batch_1.npz', '464fde20de6eb44c28cc1a8c11544bb1'],
        ['train_data_batch_2.npz', 'bdb56e71882c3fd91619d789d5dd7c79'],
        ['train_data_batch_3.npz', '83ff36d76ea26867491a281ea6e1d03b'],
        ['train_data_batch_4.npz', '98ff184fe109d5c2a0f6da63843880c7'],
        ['train_data_batch_5.npz', '462b8803e13c3e6de9498da7aaaae57c8'],
        ['train_data_batch_6.npz', 'e0b06665f890b029f1d8d0a0db26e119'],
        ['train_data_batch_7.npz', '9731f469aac1622477813c132c5a847a'],
        ['train_data_batch_8.npz', '60aed934b9d26b7ee83a1a83bdcfbe0f'],
        ['train_data_batch_9.npz', 'b96328e6affd718660c2561a6fe8c14c'],
        ['train_data_batch_10.npz', '1dc618d544c554220dd118f72975470c'],
    ]

    test_list = [
        ['val_data.npz', 'a8c04a389f2649841fb7a01720da9dd9'],
    ]

    def __init__(self, root, train=True,
                 transform=None, target_transform=None,
                 download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        #if not self._check_integrity():
        #    raise RuntimeError('Dataset not found or corrupted.' +
        #                       ' You can use download=True to download it')

        # now load the stored numpy arrays
        if self.train:
            self.train_data = []
            self.train_labels = []
            for fentry in self.train_list:
                batch, labels = self._load_entry(
                    os.path.join(self.root, self.base_folder, fentry[0]))
                self.train_data.append(batch)
                self.train_labels.extend(labels)

            self.train_data = np.concatenate(self.train_data)
            # -1 instead of a hard-coded sample count: the shape follows the files.
            self.train_data = self.train_data.reshape((-1, 3, 32, 32))
            self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
        else:
            self.test_data, self.test_labels = self._load_entry(
                os.path.join(self.root, self.base_folder, self.test_list[0][0]))
            self.test_data = self.test_data.reshape((-1, 3, 32, 32))
            self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC

    @staticmethod
    def _load_entry(path):
        """Read one batch file and return its ``(data, labels)`` arrays.

        Labels are taken from the ``'labels'`` key when present and from
        ``'fine_labels'`` otherwise.
        """
        with open(path, 'rb') as fo:
            if sys.version_info[0] == 2:
                entry = pickle.load(fo)
            else:
                entry = np.load(fo)
            # Read the arrays before the file is closed: ``np.load`` on an
            # ``.npz`` archive loads its members lazily.
            batch = entry['data']
            labels = entry['labels'] if 'labels' in entry else entry['fine_labels']
        return batch, labels

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        # Report what is actually held: callers slice ``train_data`` /
        # ``test_data`` to carve out validation splits and subsets.
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_integrity(self):
        """Return True if every expected batch file exists with the recorded md5."""
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        """Download and unpack the archive into ``root`` unless it is already there."""
        import zipfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        root = self.root
        download_url(self.url, root, self.filename, self.tgz_md5)

        # extract file; the archive is a .zip, so tarfile cannot read it
        with zipfile.ZipFile(os.path.join(root, self.filename), 'r') as archive:
            archive.extractall(root)
def try_contiguous(x):
    """Return ``x`` itself if it is contiguous, otherwise a contiguous copy."""
    if not x.is_contiguous():
        x = x.contiguous()

    return x


def _extract_patches(x, kernel_size, stride, padding):
    """
    :param x: The input feature maps. (batch_size, in_c, h, w)
    :param kernel_size: the kernel size of the conv filter (tuple of two elements)
    :param stride: the stride of conv operation (tuple of two elements)
    :param padding: number of paddings. be a tuple of two elements
    :return: (batch_size, out_h, out_w, in_c*kh*kw)
    """
    if padding[0] + padding[1] > 0:
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        x = F.pad(x, (padding[1], padding[1], padding[0],
                      padding[0])).detach()
    x = x.unfold(2, kernel_size[0], stride[0])
    x = x.unfold(3, kernel_size[1], stride[1])
    # (b, c, oh, ow, kh, kw) -> (b, oh, ow, c, kh, kw)
    x = x.permute(0, 2, 3, 1, 4, 5).contiguous()
    x = x.view(
        x.size(0), x.size(1), x.size(2),
        x.size(3) * x.size(4) * x.size(5))
    return x


def update_running_stat(aa, m_aa, stat_decay):
    """Update ``m_aa`` in place to ``stat_decay * m_aa + (1 - stat_decay) * aa``.

    :param aa: the new statistic.
    :param m_aa: the running statistic, modified in place.
    :param stat_decay: weight kept on the old running value.
    """
    # Same result as rescaling by ``stat_decay / (1 - stat_decay)``, adding
    # ``aa`` and multiplying by ``1 - stat_decay``, but without the division,
    # which fails for ``stat_decay == 1``.  Still in place to save memory.
    m_aa.mul_(stat_decay).add_(aa, alpha=1 - stat_decay)


class ComputeMatGrad:
    """Per-sample gradients of Linear/Conv2d layers w.r.t. weight (and bias) as matrices."""

    @classmethod
    def __call__(cls, input, grad_output, layer):
        if isinstance(layer, nn.Linear):
            grad = cls.linear(input, grad_output, layer)
        elif isinstance(layer, nn.Conv2d):
            grad = cls.conv2d(input, grad_output, layer)
        else:
            raise NotImplementedError
        return grad

    @staticmethod
    def linear(input, grad_output, layer):
        """
        :param input: batch_size * input_dim
        :param grad_output: batch_size * output_dim
        :param layer: [nn.module] output_dim * input_dim
        :return: batch_size * output_dim * (input_dim + [1 if with bias])
        """
        with torch.no_grad():
            if layer.bias is not None:
                # A trailing column of ones makes the bias gradient the last column.
                input = torch.cat([input, input.new_ones(input.size(0), 1)], 1)
            input = input.unsqueeze(1)
            grad_output = grad_output.unsqueeze(2)
            grad = torch.bmm(grad_output, input)
        return grad

    @staticmethod
    def conv2d(input, grad_output, layer):
        """
        :param input: batch_size * in_c * in_h * in_w
        :param grad_output: batch_size * out_c * h * w
        :param layer: nn.module
        :return: batch_size * out_c * (in_c*k_h*k_w + [1 if with bias])
        """
        with torch.no_grad():
            input = _extract_patches(input, layer.kernel_size, layer.stride, layer.padding)
            input = input.view(-1, input.size(-1))  # (b * hw) * in_c*kh*kw
            grad_output = grad_output.transpose(1, 2).transpose(2, 3)
            grad_output = try_contiguous(grad_output).view(grad_output.size(0), -1, grad_output.size(-1))
            # b * hw * out_c
            if layer.bias is not None:
                input = torch.cat([input, input.new_ones(input.size(0), 1)], 1)
            input = input.view(grad_output.size(0), -1, input.size(-1))  # b * hw * in_c*kh*kw
            grad = torch.einsum('abm,abn->amn', (grad_output, input))
        return grad


class ComputeCovA:
    """Second-moment matrix of layer inputs, with a ones column when the layer has a bias."""

    @classmethod
    def compute_cov_a(cls, a, layer):
        return cls.__call__(a, layer)

    @classmethod
    def __call__(cls, a, layer):
        if isinstance(layer, nn.Linear):
            cov_a = cls.linear(a, layer)
        elif isinstance(layer, nn.Conv2d):
            cov_a = cls.conv2d(a, layer)
        else:
            # FIXME(CW): for extension to other layers.
            # raise NotImplementedError
            cov_a = None

        return cov_a

    @staticmethod
    def conv2d(a, layer):
        # a: batch_size * in_c * in_h * in_w
        batch_size = a.size(0)
        a = _extract_patches(a, layer.kernel_size, layer.stride, layer.padding)
        spatial_size = a.size(1) * a.size(2)
        a = a.view(-1, a.size(-1))
        if layer.bias is not None:
            a = torch.cat([a, a.new_ones(a.size(0), 1)], 1)
        a = a/spatial_size
        # FIXME(CW): do we need to divide the output feature map's size?
        return a.t() @ (a / batch_size)

    @staticmethod
    def linear(a, layer):
        # a: batch_size * in_dim
        batch_size = a.size(0)
        if layer.bias is not None:
            a = torch.cat([a, a.new_ones(a.size(0), 1)], 1)
        return a.t() @ (a / batch_size)


class ComputeCovG:
    """Second-moment matrix of the gradients w.r.t. layer outputs."""

    @classmethod
    def compute_cov_g(cls, g, layer, batch_averaged=False):
        """
        :param g: gradient
        :param layer: the corresponding layer
        :param batch_averaged: if the gradient is already averaged with the batch size?
        :return: the matrix for Conv2d/Linear layers, None for any other layer
        """
        # batch_size = g.size(0)
        return cls.__call__(g, layer, batch_averaged)

    @classmethod
    def __call__(cls, g, layer, batch_averaged):
        if isinstance(layer, nn.Conv2d):
            cov_g = cls.conv2d(g, layer, batch_averaged)
        elif isinstance(layer, nn.Linear):
            cov_g = cls.linear(g, layer, batch_averaged)
        else:
            cov_g = None

        return cov_g

    @staticmethod
    def conv2d(g, layer, batch_averaged):
        # g: batch_size * n_filters * out_h * out_w
        # n_filters is actually the output dimension (analogous to Linear layer)
        spatial_size = g.size(2) * g.size(3)
        batch_size = g.shape[0]
        g = g.transpose(1, 2).transpose(2, 3)
        g = try_contiguous(g)
        g = g.view(-1, g.size(-1))

        if batch_averaged:
            g = g * batch_size
        g = g * spatial_size
        cov_g = g.t() @ (g / g.size(0))

        return cov_g

    @staticmethod
    def linear(g, layer, batch_averaged):
        # g: batch_size * out_dim
        batch_size = g.size(0)

        if batch_averaged:
            cov_g = g.t() @ (g * batch_size)
        else:
            cov_g = g.t() @ (g / batch_size)
        return cov_g


if __name__ == '__main__':
    def test_ComputeCovA():
        pass

    def test_ComputeCovG():
        pass
def make_layers(cfg, batch_norm=True):
    """Turn a layer spec into the convolutional feature extractor.

    :param cfg: sequence whose entries are either ``'M'`` (a 2x2 max-pool with
        stride 2) or an int (a 3x3 convolution with that many output channels,
        followed by optional batch norm and an in-place ReLU).
    :param batch_norm: whether each convolution is followed by ``BatchNorm2d``.
    :return: an ``nn.Sequential`` of the resulting layers; the first
        convolution expects 3 input channels.
    """
    modules = []
    channels = 3
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*modules)


# Layer specs keyed by nominal depth.  Entry 6 currently repeats the spec of
# entry 11; the shallower variant is kept commented out.
cfg = {
    #6: [64, 'M', 128, 'M', 256, 'M', 512, 'M', 512, 'M'],
    6: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
         512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    """VGG feature extractor followed by a three-layer classifier.

    The classifier's first layer takes 512 features, so the flattened output
    of ``features`` must have exactly 512 elements per sample.

    :param num_classes: size of the output layer.
    :param depth: key into ``cfg`` selecting the layer spec.
    :param batch_norm: whether convolutions are followed by batch norm.
    """

    def __init__(self, num_classes=10, depth=16, batch_norm=False):
        super(VGG, self).__init__()
        self.features = make_layers(cfg[depth], batch_norm)
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, num_classes),
        )

        # Convolutions: zero bias, weights ~ N(0, 2 / (kh * kw * out_channels)).
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            kernel_h, kernel_w = module.kernel_size
            fan = kernel_h * kernel_w * module.out_channels
            module.weight.data.normal_(0, math.sqrt(2. / fan))
            module.bias.data.zero_()

    def forward(self, x):
        features = self.features(x)
        flat = features.view(features.size(0), -1)
        return self.classifier(flat)
math.sqrt(2. / n)) 57 | m.bias.data.zero_() 58 | 59 | def forward(self, x): 60 | x = self.features(x) 61 | x = x.view(x.size(0), -1) 62 | x = self.classifier(x) 63 | return x 64 | 65 | 66 | class Base: 67 | base = VGG 68 | args = list() 69 | kwargs = dict() 70 | transform_train = transforms.Compose([ 71 | transforms.RandomHorizontalFlip(), 72 | transforms.Resize(32), 73 | transforms.RandomCrop(32, padding=4), 74 | transforms.ToTensor(), 75 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 76 | #transforms.Normalize((0.4376821 , 0.4437697 , 0.47280442), (0.19803012, 0.20101562, 0.19703614)) 77 | ]) 78 | 79 | transform_test = transforms.Compose([ 80 | transforms.Resize(32), 81 | transforms.ToTensor(), 82 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 83 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 84 | #transforms.Normalize((0.45242316, 0.45249584, 0.46897713), (0.21943445, 0.22656967, 0.22850613)) 85 | ]) 86 | 87 | 88 | class Basic: 89 | base = VGG 90 | args = list() 91 | kwargs = dict() 92 | transform_train = transforms.Compose([transforms.Resize(32), 93 | transforms.ToTensor(), 94 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) 95 | transform_test = transforms.Compose([transforms.Resize(32), 96 | transforms.ToTensor(), 97 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) 98 | 99 | 100 | class VGG16basic(Basic): 101 | pass 102 | 103 | class VGG6(Base): 104 | kwargs = {'depth': 6} 105 | pass 106 | 107 | class VGG11(Base): 108 | kwargs = {'depth': 11} 109 | pass 110 | 111 | class VGG16(Base): 112 | pass 113 | 114 | 115 | class VGG16BN(Base): 116 | kwargs = {'batch_norm': True} 117 | 118 | 119 | class VGG19(Base): 120 | kwargs = {'depth': 19} 121 | 122 | 123 | class VGG19BN(Base): 124 | kwargs = {'depth': 19, 'batch_norm': True} 125 | 126 | class VGG11(Base): 127 | pass 128 | 129 | 130 | class VGG11BN(Base): 131 | kwargs = {'batch_norm': True} 132 | 133 | 134 | # The 
# VGG-16 model for Backpack - added by Xingchen Wan 30 Nov 2019


def make_layers_backpack(cfg, batch_norm=True):
    """Return the VGG feature layers described by ``cfg`` as a plain list.

    Same layout rules as ``make_layers`` (``'M'`` -> 2x2 max-pool, int ->
    3x3 conv with that many output channels, optional batch norm), but the
    layers are returned un-wrapped so the caller can splice them into a
    single flat ``nn.Sequential``.
    """
    layers = list()
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
        if batch_norm:
            layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
        else:
            layers += [conv2d, nn.ReLU(inplace=True)]
        in_channels = v
    return layers


def get_backpacked_VGG(model: "VGG", depth=16, batch_norm=False, num_classes=10, cuda=True):
    """Copy a trained ``VGG`` into a single flat ``nn.Sequential``.

    The flat model is ``features layers + Flatten + classifier layers`` and
    receives the weights and buffers of ``model.features`` and
    ``model.classifier``.

    Args:
        model: source network; must have been built with the same ``depth``,
            ``batch_norm`` and ``num_classes`` as passed here.
        depth: key into the module-level ``cfg`` table.
        batch_norm: whether the source network uses batch norm.
        num_classes: output size of the final linear layer.
        cuda: move the result to ``'cuda'`` if True, else to ``'cpu'``.

    Returns:
        The flat ``nn.Sequential`` holding the copied parameters.
    """
    # Same import form as all_cnn.py; a bare ``import backpack`` does not
    # guarantee that the ``backpack.core.layers`` submodule is loaded.
    from backpack.core.layers import Flatten

    features_layer_list = make_layers_backpack(cfg[depth], batch_norm)
    classifier_list = [
        nn.Dropout(),
        nn.Linear(512, 512),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(512, 512),
        nn.ReLU(True),
        nn.Linear(512, num_classes)
    ]
    # Initialise the Backpack-ready model
    backpacked_model = nn.Sequential(*features_layer_list, Flatten(), *classifier_list)

    def _copy_block_content(model1, model2, offset=0):
        """Copy every state-dict entry of model1 into model2, shifting the layer index by ``offset``."""
        m2_state_dict = model2.state_dict()
        for k, v in model1.state_dict().items():
            layer_idx, entry_name = k.split(".", 1)
            model2_key = str(int(layer_idx) + offset) + "." + entry_name
            assert model2_key in m2_state_dict, \
                model2_key + " is not in m2_state_key! m2_state_key is " + str(m2_state_dict.keys())
            # state_dict tensors share storage with the module, so this
            # in-place copy updates model2 itself.
            m2_state_dict[model2_key].copy_(v)

    _copy_block_content(model.features, backpacked_model)
    # FIX: the classifier offset used to be the literal 32, which is only the
    # right position for depth 16 without batch norm (31 feature layers plus
    # the Flatten layer).  Derive it from the layers actually built.
    classifier_offset = len(features_layer_list) + 1
    _copy_block_content(model.classifier, backpacked_model, classifier_offset)
    backpacked_model.to('cuda' if cuda else 'cpu')
    return backpacked_model


# -------------------------------------------------------------------------------- /optimizers/swats.py: --------------------------------------------------------------------------------
"""
@author:    Patrik Purgai
@copyright: Copyright 2019, swats
@license:   MIT
@email:     purgai.patrik@gmail.com
@date:      2019.05.30.
"""

# pylint: disable=no-member

from torch.optim.optimizer import Optimizer
import torch


class SWATS(Optimizer):
    r"""Implements Switching from Adam to SGD technique. Proposed in
    `Improving Generalization Performance by Switching from Adam to SGD`
    by Nitish Shirish Keskar, Richard Socher (2017).
    The method applies Adam in the first phase of the training, then
    switches to SGD when a criteria is met.
    Implementation of Adam and SGD update are from `torch.optim.Adam` and
    `torch.optim.SGD`.

    Arguments:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate of the Adam phase (replaced by the estimated
            SGD learning rate at the switch).
        betas: Adam moment coefficients; ``betas[0]`` is also used as the
            momentum of the SGD phase.
        eps: term added to the Adam denominator.
        weight_decay: weight-decay coefficient.
        amsgrad: use the running maximum of the second moment.
        verbose: print a message when a group switches to SGD.
        nesterov: add a Nesterov-style term in the SGD phase.
        decoupled_wd: if True, weight decay multiplies the weights directly
            (``w *= 1 - lr * weight_decay``) instead of being added to the
            gradient.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, verbose=False,
                 nesterov=False, decoupled_wd=False):
        self.decoupled_wd = decoupled_wd
        if not 0.0 <= lr:
            raise ValueError(
                "Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError(
                "Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps, phase='ADAM',
                        weight_decay=weight_decay, amsgrad=amsgrad,
                        verbose=verbose, nesterov=nesterov)

        super(SWATS, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SWATS, self).__setstate__(state)
        # ``decoupled_wd`` lives on the instance, not in the param groups, so
        # it may be missing from a restored state; fall back to the default.
        if not hasattr(self, 'decoupled_wd'):
            self.decoupled_wd = False
        for group in self.param_groups:
            group.setdefault('amsgrad', False)
            group.setdefault('nesterov', False)
            group.setdefault('verbose', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional):
                A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for w in group['params']:
                if w.grad is None:
                    continue
                if self.decoupled_wd and group['weight_decay'] != 0:
                    w.data.mul_(1 - group['lr'] * group['weight_decay'])

                grad = w.grad.data

                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, '
                        'please consider SparseAdam instead')

                amsgrad = group['amsgrad']

                state = self.state[w]

                # state initialization
                if len(state) == 0:
                    state['step'] = 0
                    # exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(w.data)
                    # exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(w.data)
                    # moving average for the non-orthogonal projection scaling
                    state['exp_avg2'] = w.new(1).fill_(0)
                    if amsgrad:
                        # maintains max of all exp. moving avg.
                        # of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(w.data)

                exp_avg = state['exp_avg']
                exp_avg2 = state['exp_avg2']
                exp_avg_sq = state['exp_avg_sq']

                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if not self.decoupled_wd and group['weight_decay'] != 0:
                    # FIX: out-of-place, so the caller's ``w.grad`` is not
                    # modified (the old in-place add also relied on the
                    # deprecated ``add_(scalar, tensor)`` overload).
                    grad = grad.add(w.data, alpha=group['weight_decay'])

                # if its SGD phase, take an SGD update and continue
                if group['phase'] == 'SGD':
                    if 'momentum_buffer' not in state:
                        buf = state['momentum_buffer'] = torch.clone(
                            grad).detach()
                    else:
                        buf = state['momentum_buffer']
                        buf.mul_(beta1).add_(grad)

                    # FIX: the (1 - beta1) scaling used to be applied in place
                    # to a tensor aliasing the momentum buffer (or ``w.grad``
                    # on the first SGD step), shrinking the stored momentum on
                    # every step.  Scale a fresh tensor instead.
                    update = buf * (1 - beta1)
                    if group['nesterov']:
                        # NOTE(review): reads the original ``grad.add_(beta1,
                        # buf)`` without the aliasing; confirm this is the
                        # intended Nesterov form.
                        update.add_(buf, alpha=beta1)

                    w.data.add_(update, alpha=-group['lr'])
                    continue

                # decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # maintains the maximum of all 2nd
                    # moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * \
                    (bias_correction2 ** 0.5) / bias_correction1

                p = -step_size * (exp_avg / denom)
                w.data.add_(p)

                p_view = p.reshape(-1)
                pg = p_view.dot(grad.reshape(-1))

                if pg != 0:
                    # the non-orthogonal scaling estimate
                    scaling = p_view.dot(p_view) / -pg
                    exp_avg2.mul_(beta2).add_(scaling, alpha=1 - beta2)

                    # bias corrected exponential average
                    corrected_exp_avg = exp_avg2 / bias_correction2

                    # checking criteria of switching to SGD training
                    if state['step'] > 1 and \
                            corrected_exp_avg.allclose(scaling, rtol=1e-6) and \
                            corrected_exp_avg > 0:
                        group['phase'] = 'SGD'
                        group['lr'] = corrected_exp_avg.item()
                        if group['verbose']:
                            print('Switching to SGD after '
                                  '{} steps with lr {:.5f} '
                                  'and momentum {:.5f}.'.format(
                                      state['step'], group['lr'], beta1))

        return loss


# NOTE: what follows is the start of the next file in this dump.  Its module
# docstring continues beyond this span, so the opening lines are kept as
# comment text rather than as an unterminated string.
# -------------------------------------------------------------------------------- /curvature/models/resnext.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# """
# Creates a ResNeXt Model as defined in:
# Xie, S., Girshick, R., Dollár, P., Tu, Z., & He, K. (2016).
# Aggregated residual transformations for deep neural networks.
# arXiv preprint arXiv:1611.05431.
# """  <- closes the resnext.py module docstring whose opening lines precede this span

__author__ = "Pau Rodríguez López, ISELAB, CVC-UAB"
__email__ = "pau.rodri1@gmail.com"

import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


class ResNeXtBottleneck(nn.Module):
    """
    RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
    """

    def __init__(self, in_channels, out_channels, stride, cardinality, base_width, widen_factor):
        """ Constructor
        Args:
            in_channels: input channel dimensionality
            out_channels: output channel dimensionality
            stride: conv stride. Replaces pooling layer.
            cardinality: num of convolution groups.
            base_width: base number of channels in each group.
            widen_factor: factor to reduce the input dimensionality before convolution.
        """
        super(ResNeXtBottleneck, self).__init__()
        width_ratio = out_channels / (widen_factor * 64.)
        D = cardinality * int(base_width * width_ratio)
        self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_reduce = nn.BatchNorm2d(D)
        self.conv_conv = nn.Conv2d(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
        self.bn = nn.BatchNorm2d(D)
        self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_expand = nn.BatchNorm2d(out_channels)

        # Identity shortcut unless the channel count changes, in which case a
        # strided 1x1 conv + batch norm projects the input.
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut.add_module('shortcut_conv',
                                     nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0,
                                               bias=False))
            self.shortcut.add_module('shortcut_bn', nn.BatchNorm2d(out_channels))

    def forward(self, x):
        """Return ``relu(shortcut(x) + bottleneck(x))``."""
        # Sub-modules are called directly (not via ``.forward``) so that any
        # hooks registered on them are executed.
        bottleneck = self.conv_reduce(x)
        bottleneck = F.relu(self.bn_reduce(bottleneck), inplace=True)
        bottleneck = self.conv_conv(bottleneck)
        bottleneck = F.relu(self.bn(bottleneck), inplace=True)
        bottleneck = self.conv_expand(bottleneck)
        bottleneck = self.bn_expand(bottleneck)
        residual = self.shortcut(x)
        return F.relu(residual + bottleneck, inplace=True)


class CifarResNeXt(nn.Module):
    """
    ResNext optimized for the Cifar dataset, as specified in
    https://arxiv.org/pdf/1611.05431.pdf
    """

    def __init__(self, cardinality, depth, num_classes, base_width, widen_factor=4):
        """ Constructor
        Args:
            cardinality: number of convolution groups.
            depth: number of layers.
            num_classes: number of classes
            base_width: base number of channels in each group.
            widen_factor: factor to adjust the channel dimensionality
        """
        super(CifarResNeXt, self).__init__()
        self.cardinality = cardinality
        self.depth = depth
        self.block_depth = (self.depth - 2) // 9
        self.base_width = base_width
        self.widen_factor = widen_factor
        self.nlabels = num_classes
        self.output_size = 64
        self.stages = [64, 64 * self.widen_factor, 128 * self.widen_factor, 256 * self.widen_factor]

        self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn_1 = nn.BatchNorm2d(64)
        self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1)
        self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2)
        self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2)
        self.classifier = nn.Linear(self.stages[3], self.nlabels)
        # ``kaiming_normal_`` is the in-place replacement for the deprecated
        # ``kaiming_normal``.
        init.kaiming_normal_(self.classifier.weight)

        # Name-based initialisation: conv weights -> Kaiming (fan_out),
        # batch-norm weights -> 1, every bias -> 0.  state_dict tensors share
        # storage with the parameters, so in-place writes reach the module.
        for key, tensor in self.state_dict().items():
            leaf = key.split('.')[-1]
            if leaf == 'weight':
                if 'conv' in key:
                    init.kaiming_normal_(tensor, mode='fan_out')
                if 'bn' in key:
                    tensor[...] = 1
            elif leaf == 'bias':
                tensor[...] = 0

    def block(self, name, in_channels, out_channels, pool_stride=2):
        """ Stack n bottleneck modules where n is inferred from the depth of the network.
        Args:
            name: string name of the current block.
            in_channels: number of input channels
            out_channels: number of output channels
            pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
        Returns: a Module consisting of n sequential bottlenecks.
        """
        block = nn.Sequential()
        for bottleneck in range(self.block_depth):
            name_ = '%s_bottleneck_%d' % (name, bottleneck)
            if bottleneck == 0:
                # Only the first bottleneck changes channels / resolution.
                block.add_module(name_, ResNeXtBottleneck(in_channels, out_channels, pool_stride, self.cardinality,
                                                          self.base_width, self.widen_factor))
            else:
                block.add_module(name_,
                                 ResNeXtBottleneck(out_channels, out_channels, 1, self.cardinality, self.base_width,
                                                   self.widen_factor))
        return block

    def forward(self, x):
        """Stem conv -> three stages -> 8x8 average pool -> linear classifier."""
        x = self.conv_1_3x3(x)
        x = F.relu(self.bn_1(x), inplace=True)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        x = F.avg_pool2d(x, 8, 1)
        x = x.view(-1, self.stages[3])
        return self.classifier(x)


class ResNeXt29CIFAR:
    import torchvision.transforms as transforms

    base = CifarResNeXt
    args = list()
    kwargs = {"cardinality": 8,
              "depth": 29,
              "base_width": 64,
              }

    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
    transform_train = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    transform_test = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)])


class ResNeXt29:
    import torchvision.transforms as transforms
    # FIX: ``base`` used to be ``ResNeXtBottleneck``, whose constructor takes
    # neither ``depth`` nor ``num_classes``; the kwargs below are those of
    # ``CifarResNeXt``.
    base = CifarResNeXt
    args = list()
    kwargs = {'cardinality': 8,
              "depth": 29,
              "base_width": 64,
              }
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
    transform_train = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    transform_test = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)])


# -------------------------------------------------------------------------------- /core/loss_landscape.py: --------------------------------------------------------------------------------
import torch
import numpy as np
import tabulate
import time

from curvature import data, models, utils, losses
from curvature.methods.swag import SWAG


def build_loss_landscape(
        dataset: str,
        data_path: str,
        model: str,
        spectrum_path: str,
        checkpoint_path: str,
        use_test: bool = True,
        batch_size: int = 128,
        num_workers: int = 4,
        save_path: str = None,
        dist: float = 1.,
        n_points: int = 21,
        seed: int = None,
        device: str = 'cuda',
        swag: bool = False,
) -> dict:
    """
    This function loads a checkpoint from the network training, and the spectrum result from Lanczos algorithm, and
    perturbs the weights by a specified amount in each of the selected eigenvector directions, and then stores the
    resulting train/test loss and accuracy after the perturbation. This tool is only for spectra computed using the
    Lanczos algorithm.

    :param dataset: str: ['CIFAR10', 'CIFAR100', 'MNIST', 'ImageNet32'*]: the dataset on which you would like to train the
    model. For ImageNet 32, we use the downsampled 32 x 32 Full ImageNet dataset. We do not provide download due to
    the proprietary issues, and please drop the data of ImageNet 32 in 'data/' folder

    :param data_path: str: the path string of the dataset

    :param model: str: the neural network architecture you would like to train. All available models are listed under
    'models'/ Example: VGG16BN, PreResNet110 (Preactivated ResNet - 110 layers)

    :param spectrum_path: str: the output spectrum from the Lanczos eigenspectrum
    Note: only results using Lanczos algorithm can be used; diagonal approximations are not applicable here

    :param checkpoint_path: str: the checkpoint from network training

    :param use_test: bool: if True, you will test the model on the test set. If not, a portion of the training data will
    be assigned as the validation set.

    :param batch_size: int: the minibatch size

    :param num_workers: number of workers for the dataloader

    :param save_path: if provided, the returned statistics are additionally saved with ``np.savez`` at this path.

    :param dist: float. distance to travel along each direction; the grid spans [-dist, dist] (default: 1.0)

    :param n_points: number of points on a grid (default: 21)

    :param seed: if not None, seeds the torch, torch.cuda and numpy random number generators

    :param device: 'cuda' or 'cpu'; 'cuda' falls back to 'cpu' when CUDA is not available

    :param swag: if True, the checkpoint is loaded into a SWAG wrapper, ``set_swa()`` is called and its ``base_model``
    is the network that gets evaluated

    :return: dict with keys 'dim', 'ts', 'eigvals', 'gammas', 'idx', 'train_acc', 'train_err', 'train_loss',
    'test_acc', 'test_err', 'test_loss'; the per-point arrays have shape (len(idx), n_points).
    """
    if device == 'cuda':
        if not torch.cuda.is_available():
            device = 'cpu'
    torch.backends.cudnn.benchmark = True
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
    print('Using model ', model)
    model_cfg = getattr(models, model)

    print('Loading dataset %s from %s' % (dataset, data_path))
    # The train loader deliberately uses the test-time transform and no
    # shuffling, so every grid point is evaluated on identical inputs.
    loaders, num_classes = data.loaders(
        dataset,
        data_path,
        batch_size,
        num_workers,
        transform_train=model_cfg.transform_test,
        transform_test=model_cfg.transform_test,
        use_validation=not use_test,
        shuffle_train=False,
    )
    print('Preparing model')

    # FIX: load onto the CPU explicitly.  Without ``map_location`` a
    # checkpoint saved from a GPU cannot be loaded after the CPU fallback
    # above; the model is moved to ``device`` afterwards anyway.
    if not swag:
        model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
        print('Loading %s' % checkpoint_path)
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
    else:
        swag_model = SWAG(model_cfg.base,
                          subspace_type='random',
                          *model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
        print('Loading %s' % checkpoint_path)
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        swag_model.load_state_dict(checkpoint['state_dict'], strict=False)
        swag_model.set_swa()
        model = swag_model.base_model

    model.to(device)
    num_parameters = sum([p.numel() for p in model.parameters()])
    print('Loading %s' % spectrum_path)
    # FIX: the tensors are converted with ``.numpy()`` below, which requires
    # them to live on the CPU.
    basis_dict = torch.load(spectrum_path, map_location='cpu')

    mean = basis_dict['w'].detach().numpy()
    eigvals = basis_dict['eigvals'].numpy()[:, 0]
    gammas = basis_dict['gammas'].numpy()
    V = basis_dict['V'].numpy()

    rank = eigvals.size
    criterion = losses.cross_entropy

    # Directions to probe: ranks 0, 1, 2 and 5 (clipped to the available
    # rank) among the smallest, the largest and the smallest-magnitude
    # eigenvalues, de-duplicated and sorted.
    picks = np.minimum(rank - 1, [0, 1, 2, 5])
    idx = np.concatenate([
        np.argsort(eigvals)[picks],
        np.argsort(-eigvals)[picks],
        np.argsort(np.abs(eigvals))[picks],
    ])
    idx = np.sort(np.unique(np.minimum(idx, rank - 1)))
    K = len(idx)

    ts = np.linspace(-dist, dist, n_points)

    train_acc = np.zeros((K, n_points))
    train_loss = np.zeros((K, n_points))
    test_acc = np.zeros((K, n_points))
    test_loss = np.zeros((K, n_points))

    columns = ['#', 't', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time']

    for i, eig_id in enumerate(idx):
        v = V[eig_id, :].copy()
        for j, t in enumerate(ts):
            start_time = time.time()
            w = mean + t * v

            # Write the flat weight vector back into the parameters, in
            # ``model.parameters()`` order.
            offset = 0
            for param in model.parameters():
                size = param.numel()
                param.data.copy_(param.new_tensor(w[offset:offset + size].reshape(param.size())))
                offset += size

            utils.bn_update(loaders['train'], model)
            train_res = utils.eval(loaders['train'], model, criterion)
            test_res = utils.eval(loaders['test'], model, criterion)

            train_acc[i, j] = train_res['accuracy']
            train_loss[i, j] = train_res['loss']
            test_acc[i, j] = test_res['accuracy']
            test_loss[i, j] = test_res['loss']

            run_time = time.time() - start_time
            values = [eig_id, t, train_loss[i, j], train_acc[i, j], test_loss[i, j], test_acc[i, j], run_time]
            table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
            if j == 0:
                # First point of a direction: print the header framed by rules.
                table = table.split('\n')
                table = '\n'.join([table[1]] + table)
            else:
                # Later points: print the data row only.
                table = table.split('\n')[2]
            print('Iteration: '+str(i * len(ts) + j) + '/' + str(len(ts) * len(idx)))
            print(table)

    # Built once so the saved file and the returned dict cannot diverge.
    results = {
        'dim': num_parameters,
        'ts': ts,
        'eigvals': eigvals,
        'gammas': gammas,
        'idx': idx,
        'train_acc': train_acc,
        'train_err': 100.0 - train_acc,
        'train_loss': train_loss,
        'test_acc': test_acc,
        'test_err': 100.0 - test_acc,
        'test_loss': test_loss,
    }

    if save_path is not None:
        np.savez(save_path, **results)

    return results


# -------------------------------------------------------------------------------- /curvature/models/all_cnn.py: --------------------------------------------------------------------------------
# All-CNN-C architecture for *CIFAR-100* dataset.
# Added by Xingchen Wan on 2 Dec. Modified from this github repository:
# https://github.com/fsschneider/DeepOBS/blob/develop/deepobs/pytorch/testproblems/testproblems_modules.py


import torch.nn as nn
from math import ceil

__all__ = ['AllCNN_CIFAR100']


def _determine_padding_from_tf_same(
    input_dimensions, kernel_dimensions, stride_dimensions
):
    """Implements tf's padding 'same' for kernel processes like convolution or pooling.
    Args:
        input_dimensions (int or tuple): dimension of the input image
        kernel_dimensions (int or tuple): dimensions of the convolution kernel
        stride_dimensions (int or tuple): the stride of the convolution
     Returns: A padding 4-tuple for padding layer creation that mimics tf's padding 'same'.
     """

    # get dimensions
    in_height, in_width = input_dimensions

    if isinstance(kernel_dimensions, int):
        kernel_height = kernel_dimensions
        kernel_width = kernel_dimensions
    else:
        kernel_height, kernel_width = kernel_dimensions

    if isinstance(stride_dimensions, int):
        stride_height = stride_dimensions
        stride_width = stride_dimensions
    else:
        stride_height, stride_width = stride_dimensions

    # determine the output size that is to be achieved by the padding
    out_height = ceil(in_height / stride_height)
    out_width = ceil(in_width / stride_width)

    # determine the pad size along each dimension
    pad_along_height = max(
        (out_height - 1) * stride_height + kernel_height - in_height, 0
    )
    pad_along_width = max(
        (out_width - 1) * stride_width + kernel_width - in_width, 0
    )

    # determine padding 4-tuple (can be asymmetric)
    pad_top = pad_along_height // 2
    pad_bottom = pad_along_height - pad_top
    pad_left = pad_along_width // 2
    pad_right = pad_along_width - pad_left

    return pad_left, pad_right, pad_top, pad_bottom


def hook_factory_tf_padding_same(kernel_size, stride):
    """Generates the torch pre forward hook that needs to be registered on
    the padding layer to mimic tf's padding 'same'"""

    def hook(module, input):
        """The hook overwrites the padding attribute of the padding layer."""
        image_dimensions = input[0].size()[-2:]
        module.padding = _determine_padding_from_tf_same(
            image_dimensions, kernel_size, stride
        )

    return hook


def tfconv2d(
    in_channels,
    out_channels,
    kernel_size,
    stride=1,
    dilation=1,
    groups=1,
    bias=True,
    tf_padding_type=None,
):
    """Return an ``nn.Sequential`` holding an un-padded ``nn.Conv2d``.

    When ``tf_padding_type == "same"`` the convolution is preceded by a
    ``ZeroPad2d`` whose padding is recomputed from the input size on every
    forward pass; for any other value no padding layer is added.
    """
    modules = []
    if tf_padding_type == "same":
        padding = nn.ZeroPad2d(0)
        hook = hook_factory_tf_padding_same(kernel_size, stride)
        padding.register_forward_pre_hook(hook)
        modules.append(padding)

    modules.append(
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
    )
    return nn.Sequential(*modules)


def mean_allcnnc():
    """The all convolution layer implementation of torch.mean().
    Use the backpack version of the flatten layer - edited by Xingchen Wan"""
    from backpack.core.layers import Flatten
    return nn.Sequential(nn.AvgPool2d(kernel_size=(6, 6)), Flatten())


class AllCNN_C(nn.Sequential):
    """All-CNN-C: nine conv+ReLU pairs with three dropout layers and a mean head.

    Sub-module names ("dropout1", "conv1", "relu1", ..., "relu9", "mean") and
    their order are part of the state-dict layout and must not change.
    """

    def __init__(self, num_classes=100):
        super(AllCNN_C, self).__init__()

        same = "same"
        # (in_channels, out_channels, kernel_size, stride, tf_padding_type)
        # for conv1..conv9; conv7 is the only one without 'same' padding.
        conv_specs = [
            (3, 96, 3, 1, same),
            (96, 96, 3, 1, same),
            (96, 96, 3, (2, 2), same),
            (96, 192, 3, 1, same),
            (192, 192, 3, 1, same),
            (192, 192, 3, (2, 2), same),
            (192, 192, 3, 1, None),
            (192, 192, 1, 1, same),
            (192, num_classes, 1, 1, same),
        ]
        # Dropout layers inserted immediately before the given conv index.
        dropout_before = {1: ("dropout1", 0.2), 4: ("dropout2", 0.5), 7: ("dropout3", 0.5)}

        for i, (c_in, c_out, kernel, stride, pad_type) in enumerate(conv_specs, start=1):
            if i in dropout_before:
                drop_name, drop_p = dropout_before[i]
                self.add_module(drop_name, nn.Dropout(p=drop_p))
            self.add_module(
                "conv%d" % i,
                tfconv2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=kernel,
                    stride=stride,
                    tf_padding_type=pad_type,
                ),
            )
            self.add_module("relu%d" % i, nn.ReLU())

        self.add_module("mean", mean_allcnnc())

        # init the layers
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.constant_(module.bias, 0.1)
                nn.init.xavier_normal_(module.weight)


import torchvision.transforms as transforms


class AllCNN_CIFAR100:
    base = AllCNN_C
    args = list()
    kwargs = dict()
    transform_train = transforms.Compose([
        transforms.Resize(32),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    transform_test = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])


# NOTE: what follows is the start of the next file in this dump.  Its module
# docstring continues beyond this span, so the opening lines are kept as
# comment text rather than as an unterminated string.
# -------------------------------------------------------------------------------- /curvature/models/preresnet.py: --------------------------------------------------------------------------------
# """
# PreResNet model definition
# ported from
# Source URL retained from the module header:
# https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/preresnet.py

import torch.nn as nn
import torchvision.transforms as transforms
import math

__all__ = ['PreResNet110', 'PreResNet56', 'PreResNet8', 'PreResNet83', 'PreResNet164']


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


class BasicBlock(nn.Module):
    """Pre-activation residual block: two (BN -> ReLU -> 3x3 conv) stages plus a shortcut."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Registration order is kept as-is: it fixes state_dict key order and
        # the order in which layers draw their random initial weights.
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the block to ``x`` and add the (optionally projected) shortcut."""
        out = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(out)))

        # The shortcut is taken from the raw input, projected only when a
        # downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return out


class Bottleneck(nn.Module):
    """Pre-activation bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions, output widened 4x."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Registration order is kept as-is (see BasicBlock).
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the three BN -> ReLU -> conv stages and add the shortcut."""
        out = x
        for norm, conv in (
            (self.bn1, self.conv1),
            (self.bn2, self.conv2),
            (self.bn3, self.conv3),
        ):
            out = conv(self.relu(norm(out)))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return out


class PreResNet(nn.Module):
    """Three-stage pre-activation ResNet.

    ``depth >= 44`` selects ``Bottleneck`` blocks (depth must be 9n+2);
    smaller depths select ``BasicBlock`` (depth must be 6n+2).
    """

    def __init__(self, num_classes=10, depth=110):
        super().__init__()
        if depth >= 44:
            assert (depth - 2) % 9 == 0, 'depth should be 9n+2'
            block, blocks_per_stage = Bottleneck, (depth - 2) // 9
        else:
            assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
            block, blocks_per_stage = BasicBlock, (depth - 2) // 6

        # ``inplanes`` tracks the channel count entering the next stage and is
        # updated by ``_make_layer``.
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 16, blocks_per_stage)
        self.layer2 = self._make_layer(block, 32, blocks_per_stage, stride=2)
        self.layer3 = self._make_layer(block, 64, blocks_per_stage, stride=2)
        self.bn = nn.BatchNorm2d(64 * block.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        # Re-initialise: convolutions ~ N(0, sqrt(2 / (kh * kw * out_channels))),
        # batch-norm scale = 1 and shift = 0.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        out_planes = planes * block.expansion

        # A 1x1 projection is needed whenever the shortcut changes resolution
        # or channel count.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class scores for the image batch ``x``."""
        x = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        x = self.relu(self.bn(x))

        x = self.avgpool(x)
        flat = x.view(x.size(0), -1)
        return self.fc(flat)


# Per-channel normalisation constants shared by every configuration below.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)


def _train_transform(resize=True):
    """Build the training pipeline: optional resize, random crop, flip, tensor, normalise."""
    steps = [transforms.Resize(32)] if resize else []
    steps.extend([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])
    return transforms.Compose(steps)


def _test_transform(resize=True):
    """Build the evaluation pipeline: optional resize, tensor, normalise."""
    steps = [transforms.Resize(32)] if resize else []
    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])
    return transforms.Compose(steps)


class PreResNet164:
    """Configuration bundle for a depth-164 PreResNet."""
    base = PreResNet
    args = list()
    kwargs = {'depth': 164}
    transform_train = _train_transform()
    transform_test = _test_transform()


class PreResNet110:
    """Configuration bundle for a depth-110 PreResNet."""
    base = PreResNet
    args = list()
    kwargs = {'depth': 110}
    transform_train = _train_transform()
    transform_test = _test_transform()


class PreResNet83:
    """Configuration bundle for a depth-83 PreResNet (no resize step in its transforms)."""
    base = PreResNet
    args = list()
    kwargs = {'depth': 83}
    transform_train = _train_transform(resize=False)
    transform_test = _test_transform(resize=False)


class PreResNet56:
    """Configuration bundle for a depth-56 PreResNet."""
    base = PreResNet
    args = list()
    kwargs = {'depth': 56}
    transform_train = _train_transform()
    transform_test = _test_transform()


class PreResNet8:
    """Configuration bundle for a depth-8 PreResNet."""
    base = PreResNet
    args = list()
    kwargs = {'depth': 8}
    transform_train = _train_transform()
    transform_test = _test_transform()
import argparse
import tabulate
import time
import numpy as np
import os
import torch

from curvature import data, models, losses, utils
from curvature.methods.swag import SWAG


def compute_loss_stats(
        dataset: str,
        data_path: str,
        model: str,
        checkpoint_path: tuple,
        use_test: bool = True,
        batch_size: int = 128,
        num_workers: int = 4,
        num_subsamples: int = None,
        subsample_seed: int = None,
        stats_batch: int = 128,
        swag: bool = True,
        save_path: str = None,
        seed: int = None,
        curvature_matrix: str = 'hessian',
        device: str = 'cuda',
) -> dict:
    """
    Compute loss statistics for one or more checkpoints saved by train_network.

    Parameters
    ----------
    dataset: str: name of the dataset, forwarded to ``data.datasets``
        (e.g. 'CIFAR10', 'CIFAR100', 'MNIST', 'ImageNet32').

    data_path: str: the path string of the dataset.

    model: str: name of a model configuration exposed by ``curvature.models``
        (e.g. VGG16BN, PreResNet110).

    checkpoint_path: a sequence of checkpoint file paths, a sequence holding a single directory (every
        ``*.pt`` file inside it is used, in sorted order), or a single path given as a plain string.

    use_test: bool: if True the test set is used for evaluation; otherwise a validation split is requested
        from ``data.datasets``.

    batch_size: int: minibatch size for the train/test evaluation loaders.

    num_workers: int: number of workers for the dataloaders.

    num_subsamples: int: number of training samples to draw for the statistics loader. If None, the entire
        training set is used.

    subsample_seed: int: seed for the subsample draw above.

    stats_batch: int: minibatch size of the loader handed to ``utils.loss_stats``. Larger values are faster
        but need more memory.

    swag: bool: if True the checkpoint is loaded into a SWAG wrapper and its SWA solution is evaluated.

    save_path: str: if given, the statistics are additionally saved there with ``numpy.savez``.

    seed: int: if not None, seeds the torch (CPU and CUDA) random number generators.

    curvature_matrix: str: forwarded unchanged to ``utils.loss_stats``.

    device: ['cpu', 'cuda']: device used for the model; falls back to 'cpu' when CUDA is unavailable.

    Returns
    -------
    dict mapping each statistic name to a numpy array with one entry per checkpoint:
        'train_loss', 'train_acc', 'test_loss', 'test_acc', 'train_err', 'test_err',
        'weight_norm_l2', 'weight_norm_linf', plus every entry returned by ``utils.loss_stats``
        ('loss_mean', 'loss_var', 'grad_mean_norm_sq', 'grad_var', 'hess_mean_norm_sq', 'hess_var',
        'hess_mu', 'delta', 'alpha').
    """
    if device == 'cuda' and not torch.cuda.is_available():
        device = 'cpu'
    use_cuda = device == 'cuda'

    torch.backends.cudnn.benchmark = True
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    print('Using model %s' % model)
    model_cfg = getattr(models, model)

    # Full training set with the training-time transforms: used to refresh the
    # batch-norm statistics and to measure the training loss/accuracy.
    full_datasets, _ = data.datasets(
        dataset,
        data_path,
        transform_train=model_cfg.transform_train,
        transform_test=model_cfg.transform_test,
        use_validation=not use_test,
    )

    full_loader = torch.utils.data.DataLoader(
        full_datasets['train'],
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    # (Optionally subsampled) training set with the test-time transforms: used
    # for the gradient/curvature statistics.
    datasets, num_classes = data.datasets(
        dataset,
        data_path,
        transform_train=model_cfg.transform_test,
        transform_test=model_cfg.transform_test,
        use_validation=not use_test,
        train_subset=num_subsamples,
        train_subset_seed=subsample_seed,
    )

    loader = torch.utils.data.DataLoader(
        datasets['train'],
        batch_size=stats_batch,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    test_loader = torch.utils.data.DataLoader(
        datasets['test'],
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    print('Preparing model')
    print(*model_cfg.args, dict(**model_cfg.kwargs))
    if not swag:
        model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
        model.to(device)
        swag_model = None
    else:
        swag_model = SWAG(model_cfg.base, subspace_type='random',
                          *model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
        swag_model.to(device)
        model = None

    criterion = losses.cross_entropy

    stat_labels = [
        'train_loss', 'train_acc', 'test_loss', 'test_acc',
        'loss_mean', 'loss_var',
        'grad_mean_norm_sq', 'grad_var',
        'hess_mean_norm_sq', 'hess_var', 'hess_mu',
        'delta', 'alpha',
        'weight_norm_l2', 'weight_norm_linf'
    ]

    # Normalise the checkpoint argument. A bare string would otherwise be
    # iterated character by character in the loop below.
    if isinstance(checkpoint_path, str):
        checkpoint_path = [checkpoint_path]
    else:
        checkpoint_path = list(checkpoint_path)

    # A single directory expands to every ``*.pt`` file inside it. The directory
    # name is captured before the list is rebuilt; sorting makes the order
    # reproducible.
    if len(checkpoint_path) == 1 and os.path.isdir(checkpoint_path[0]):
        ckpt_dir = checkpoint_path[0]
        checkpoint_path = [
            os.path.join(ckpt_dir, filename)
            for filename in sorted(os.listdir(ckpt_dir))
            if filename.endswith(".pt")
        ]
        print("File list: ", checkpoint_path)

    K = len(checkpoint_path)
    stat_dict = {
        label: np.zeros(K) for label in stat_labels
    }

    columns = ['#'] + stat_labels + ['time']

    for i, ckpt_path in enumerate(checkpoint_path):
        start_time = time.time()
        print('Loading %s' % ckpt_path)
        # map_location lets checkpoints written on a GPU be read after the
        # CPU fallback above.
        checkpoint = torch.load(ckpt_path, map_location=device)
        if not swag:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            swag_model.load_state_dict(checkpoint['state_dict'], strict=False)
            swag_model.set_swa()
            model = swag_model.base_model

        utils.bn_update(full_loader, model)
        train_res = utils.eval(full_loader, model, criterion)
        test_res = utils.eval(test_loader, model, criterion)

        stat_dict['train_loss'][i] = train_res['loss']
        stat_dict['train_acc'][i] = train_res['accuracy']
        stat_dict['test_loss'][i] = test_res['loss']
        stat_dict['test_acc'][i] = test_res['accuracy']

        loss_stats = utils.loss_stats(loader, model, criterion, cuda=use_cuda, verbose=False,
                                      bn_train_mode=True, curvature_matrix=curvature_matrix)
        w = torch.cat([param.detach().cpu().view(-1) for param in model.parameters()])

        for label, value in loss_stats.items():
            stat_dict[label][i] = value
        # Written per checkpoint: assigning to the key itself would replace the
        # length-K array with a scalar.
        stat_dict['weight_norm_l2'][i] = torch.norm(w).item()
        stat_dict['weight_norm_linf'][i] = torch.norm(w, float('inf')).item()
        ckpt_time = time.time() - start_time

        values = ['%d/%d' % (i + 1, K)] + [stat_dict[label][i] for label in stat_labels] + [ckpt_time]

        table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='0.2g')
        table = table.split('\n')
        table = '\n'.join([table[1]] + table)
        print(table)

    stat_dict['train_err'] = 100.0 - stat_dict['train_acc']
    stat_dict['test_err'] = 100.0 - stat_dict['test_acc']

    # With SWAG and no checkpoints ``model`` was never bound to the base model.
    param_source = model if model is not None else swag_model.base_model
    num_parameters = sum(p.numel() for p in param_source.parameters())

    if save_path is not None:
        np.savez(
            save_path,
            checkpoints=checkpoint_path,
            num_parameters=num_parameters,
            **stat_dict
        )
    return stat_dict
# /optimizers/ekfac.py
import math

import torch
import torch.optim as optim

from utils.kfac_utils import (ComputeCovA, ComputeCovG, ComputeMatGrad)
from utils.kfac_utils import update_running_stat


class EKFACOptimizer(optim.Optimizer):
    """Eigenvalue-corrected K-FAC optimizer for the ``Linear`` and ``Conv2d`` layers of a model.

    Forward-pre and backward hooks are registered on every supported layer to
    accumulate running input / output-gradient covariances. ``step`` rewrites
    each layer's ``.grad`` with a preconditioned gradient, applies a KL-based
    clip, then performs an SGD-with-momentum update.

    :param model: module whose parameters are optimised (a model instance is required).
    :param lr: learning rate.
    :param momentum: momentum factor of the final SGD update.
    :param stat_decay: decay passed to ``update_running_stat`` for the running statistics.
    :param damping: value added to the scaling matrix before dividing.
    :param kl_clip: bound used to rescale the update in ``_kl_clip_and_update_grad``.
    :param weight_decay: L2 penalty, applied only once ``steps >= 20 * TCov``.
    :param TCov: covariance statistics are refreshed every ``TCov`` steps.
    :param TScal: scaling statistics are refreshed every ``TScal`` steps (never at step 0).
    :param TInv: eigendecompositions are recomputed every ``TInv`` steps.
    :param batch_averaged: forwarded to the covariance handler; also rescales
        per-sample gradients in ``_update_scale``.
    """

    def __init__(self,
                 model,
                 lr=0.001,
                 momentum=0.9,
                 stat_decay=0.95,
                 damping=0.001,
                 kl_clip=0.001,
                 weight_decay=0,
                 TCov=10,
                 TScal=10,
                 TInv=100,
                 batch_averaged=True):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, momentum=momentum, damping=damping,
                        weight_decay=weight_decay)
        # TODO (CW): EKFAC optimizer now only support model as input
        super(EKFACOptimizer, self).__init__(model.parameters(), defaults)
        self.CovAHandler = ComputeCovA()
        self.CovGHandler = ComputeCovG()
        self.MatGradHandler = ComputeMatGrad()
        self.batch_averaged = batch_averaged

        self.known_modules = {'Linear', 'Conv2d'}

        self.modules = []
        self.grad_outputs = {}

        # The backward hook reads this flag; it must exist before the hooks are
        # registered, otherwise the first backward pass raises AttributeError.
        # Training loops may toggle it to control when gradient statistics are
        # accumulated.
        self.acc_stats = True

        self.model = model
        self._prepare_model()

        self.steps = 0

        self.m_aa, self.m_gg = {}, {}
        self.Q_a, self.Q_g = {}, {}
        self.d_a, self.d_g = {}, {}
        self.S_l = {}
        self.A, self.DS = {}, {}
        self.stat_decay = stat_decay

        self.kl_clip = kl_clip
        self.TCov = TCov
        self.TScal = TScal
        self.TInv = TInv

    def _save_input(self, module, input):
        """Forward-pre hook: update the running input covariance and stash the raw input."""
        if torch.is_grad_enabled() and self.steps % self.TCov == 0:
            aa = self.CovAHandler(input[0].data, module)
            # Initialize buffers
            if self.steps == 0:
                self.m_aa[module] = torch.diag(aa.new(aa.size(0)).fill_(1))
            update_running_stat(aa, self.m_aa[module], self.stat_decay)
        if torch.is_grad_enabled() and self.steps % self.TScal == 0 and self.steps > 0:
            # Kept for ``_update_scale``, which releases it again.
            self.A[module] = input[0].data

    def _save_grad_output(self, module, grad_input, grad_output):
        """Backward hook: update the running output-gradient covariance and stash the gradient."""
        # Accumulate statistics for Fisher matrices
        if self.acc_stats and self.steps % self.TCov == 0:
            gg = self.CovGHandler(grad_output[0].data, module, self.batch_averaged)
            # Initialize buffers
            if self.steps == 0:
                self.m_gg[module] = torch.diag(gg.new(gg.size(0)).fill_(1))
            update_running_stat(gg, self.m_gg[module], self.stat_decay)

        if self.acc_stats and self.steps % self.TScal == 0 and self.steps > 0:
            # Kept for ``_update_scale``, which releases it again.
            self.DS[module] = grad_output[0].data

    def _prepare_model(self):
        """Register the statistics hooks on every ``Linear`` / ``Conv2d`` layer of the model."""
        count = 0
        print(self.model)
        print("=> We keep following layers in EKFAC. ")
        for module in self.model.modules():
            classname = module.__class__.__name__
            if classname in self.known_modules:
                self.modules.append(module)
                module.register_forward_pre_hook(self._save_input)
                module.register_backward_hook(self._save_grad_output)
                print('(%s): %s' % (count, module))
                count += 1

    def _update_inv(self, m):
        """Do eigen decomposition for computing inverse of the ~ fisher.
        :param m: The layer
        :return: no returns.
        """
        eps = 1e-10  # for numerical stability
        self.d_a[m], self.Q_a[m] = torch.symeig(
            self.m_aa[m], eigenvectors=True)
        self.d_g[m], self.Q_g[m] = torch.symeig(
            self.m_gg[m], eigenvectors=True)

        # Zero out eigenvalues that are not safely positive.
        self.d_a[m].mul_((self.d_a[m] > eps).float())
        self.d_g[m].mul_((self.d_g[m] > eps).float())
        # Outer product of the two spectra: the initial (Kronecker) scaling.
        self.S_l[m] = self.d_g[m].unsqueeze(1) @ self.d_a[m].unsqueeze(0)

    @staticmethod
    def _get_matrix_form_grad(m, classname):
        """
        :param m: the layer
        :param classname: the class name of the layer
        :return: a matrix form of the gradient. it should be a [output_dim, input_dim] matrix
                 (with the bias gradient appended as a last column when the layer has a bias).
        """
        if classname == 'Conv2d':
            p_grad_mat = m.weight.grad.data.view(m.weight.grad.data.size(0), -1)  # n_filters * (in_c * kw * kh)
        else:
            p_grad_mat = m.weight.grad.data
        if m.bias is not None:
            p_grad_mat = torch.cat([p_grad_mat, m.bias.grad.data.view(-1, 1)], 1)
        return p_grad_mat

    def _get_natural_grad(self, m, p_grad_mat, damping):
        """
        :param m: the layer
        :param p_grad_mat: the gradients in matrix form
        :param damping: value added to the scaling matrix before the division
        :return: a list of gradients w.r.t to the parameters in `m`
        """
        # p_grad_mat is of output_dim * input_dim
        # inv((ss')) p_grad_mat inv(aa') = [ Q_g (1/R_g) Q_g^T ] @ p_grad_mat @ [Q_a (1/R_a) Q_a^T]
        v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m]
        v2 = v1 / (self.S_l[m] + damping)
        v = self.Q_g[m] @ v2 @ self.Q_a[m].t()
        if m.bias is not None:
            # we always put gradient w.r.t weight in [0]
            # and w.r.t bias in [1]
            v = [v[:, :-1], v[:, -1:]]
            v[0] = v[0].view(m.weight.grad.data.size())
            v[1] = v[1].view(m.bias.grad.data.size())
        else:
            v = [v.view(m.weight.grad.data.size())]

        return v

    def _kl_clip_and_update_grad(self, updates, lr):
        """Rescale the preconditioned updates by the KL clip factor and write them into ``.grad``.

        :param updates: dict mapping each tracked layer to its list of preconditioned gradients
        :param lr: learning rate used in the clip factor
        """
        # do kl clip
        vg_sum = 0
        for m in self.modules:
            v = updates[m]
            vg_sum += (v[0] * m.weight.grad.data * lr ** 2).sum().item()
            if m.bias is not None:
                vg_sum += (v[1] * m.bias.grad.data * lr ** 2).sum().item()
        # vg_sum is zero for all-zero gradients (or lr == 0) and could be
        # non-positive numerically; the ratio is undefined then, so skip clipping.
        nu = min(1.0, math.sqrt(self.kl_clip / vg_sum)) if vg_sum > 0 else 1.0

        for m in self.modules:
            v = updates[m]
            m.weight.grad.data.copy_(v[0])
            m.weight.grad.data.mul_(nu)
            if m.bias is not None:
                m.bias.grad.data.copy_(v[1])
                m.bias.grad.data.mul_(nu)

    def _step(self, closure):
        """Apply an SGD-with-momentum update using the current ``.grad`` values.

        ``closure`` is accepted for signature compatibility and is not called.
        """
        # FIXME (CW): Modified based on SGD (removed nestrov and dampening in momentum.)
        # FIXME (CW): 1. no nesterov, 2. buf.mul_(momentum).add_(1 - dampening , d_p)
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # Weight decay only kicks in after 20 covariance refreshes.
                if weight_decay != 0 and self.steps >= 20 * self.TCov:
                    d_p.add_(weight_decay, p.data)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(1, d_p)
                    d_p = buf

                p.data.add_(-group['lr'], d_p)

    def _update_scale(self, m):
        """Refresh the running scaling matrix of layer ``m`` from the stashed input and output gradient."""
        with torch.no_grad():
            A, S = self.A[m], self.DS[m]
            grad_mat = self.MatGradHandler(A, S, m)  # batch_size * out_dim * in_dim
            if self.batch_averaged:
                grad_mat *= S.size(0)

            # NOTE(review): ``_get_natural_grad`` projects with Q_g^T ... Q_a, while
            # this uses Q_g ... Q_a^T. Left unchanged; confirm which is intended.
            s_l = (self.Q_g[m] @ grad_mat @ self.Q_a[m].t()) ** 2  # <- this consumes too much memory!
            s_l = s_l.mean(dim=0)
            if self.steps == 0:
                self.S_l[m] = s_l.new(s_l.size()).fill_(1)
            update_running_stat(s_l, self.S_l[m], self.stat_decay)
            # remove reference for reducing memory cost.
            self.A[m] = None
            self.DS[m] = None

    def step(self, closure=None):
        """Precondition the gradients of all tracked layers and take one optimisation step.

        ``closure`` is forwarded to ``_step``, which does not call it; nothing is returned.
        """
        # FIXME(CW): temporal fix for compatibility with Official LR scheduler.
        group = self.param_groups[0]
        lr = group['lr']
        damping = group['damping']
        updates = {}
        for m in self.modules:
            classname = m.__class__.__name__
            if self.steps % self.TInv == 0:
                self._update_inv(m)

            if self.steps % self.TScal == 0 and self.steps > 0:
                self._update_scale(m)

            p_grad_mat = self._get_matrix_form_grad(m, classname)
            v = self._get_natural_grad(m, p_grad_mat, damping)
            updates[m] = v
        self._kl_clip_and_update_grad(updates, lr)

        self._step(closure)
        self.steps += 1
# /core/spectrum.py
import time
import tabulate
import numpy as np
from gpytorch.utils.lanczos import lanczos_tridiag
import torch
from curvature import data, models, losses, utils
from curvature.methods.swag import SWAG, SWA


def compute_eigenspectrum(
        dataset: str,
        data_path: str,
        model: str,
        checkpoint_path: str,
        curvature_matrix: str = 'hessian_lanczos',
        use_test: bool = True,
        batch_size: int = 128,
        num_workers: int = 4,
        swag: bool = False,
        lanczos_iters: int = 100,
        num_subsamples: int = None,
        subsample_seed: int = None,
        bn_train_mode: bool = True,
        save_spectrum_path: str = None,
        save_eigvec: bool = False,
        seed: int = None,
        device: str = 'cuda',
):
    """
    Estimate the eigenspectrum of a curvature matrix of a trained network with the Lanczos algorithm.

    Parameters
    ----------
    dataset: str: name of the dataset, forwarded to ``data.datasets``
        (e.g. 'CIFAR10', 'CIFAR100', 'MNIST', 'ImageNet32').

    data_path: str: the path string of the dataset.

    model: str: name of a model configuration exposed by ``curvature.models``
        (e.g. VGG16BN, PreResNet110).

    checkpoint_path: str: path to a checkpoint generated by train_network; its 'state_dict' entry is loaded.

    curvature_matrix: str: which matrix-vector product drives Lanczos. Possible values are:
        hessian_lanczos: Lanczos algorithm on the Hessian matrix
        ggn_lanczos: Lanczos algorithm on the Generalised Gauss-Newton (GGN)
        cov_grad_lanczos: Lanczos algorithm on the Covariance of Gradients

    use_test: bool: if True the test set is used; otherwise a validation split is requested from
        ``data.datasets``.

    batch_size: int: the minibatch size.

    num_workers: int: number of workers for the dataloaders.

    swag: bool: if True the checkpoint is loaded into a SWAG wrapper and its SWA solution is analysed.

    lanczos_iters: int: number of Lanczos iterations, which is also the number of Ritz value/vector pairs
        produced.

    num_subsamples: int: number of training samples to draw for the matrix-vector products. If None, the
        entire training set is used.

    subsample_seed: int: seed for the subsample draw above.

    bn_train_mode: bool: forwarded to the matrix-vector product helpers; toggles whether batch-norm layers
        run in train or eval mode.

    save_spectrum_path: str: if provided, results are saved to this path (see ``save_eigvec``).

    save_eigvec: bool: if True, the weights, Ritz values, ``gammas`` and the vectors ``V`` are additionally
        saved as torch tensors with ``torch.save`` to ``save_spectrum_path``. ``V`` has one row of P
        entries per Lanczos iteration (P = number of parameters), so this can take a lot of storage.
        The numpy archive (without ``V``) is still written by ``numpy.savez``, unless
        ``save_spectrum_path`` already ends in '.npz' and would therefore overwrite the torch file.

    seed: int: if not None, seeds the torch (CPU and CUDA) random number generators.

    device: ['cpu', 'cuda']: device used for the model; falls back to 'cpu' when CUDA is unavailable.

    Returns
    -------
    dict with keys:
        'w': flattened model parameters (CPU tensor)
        'eigvals': eigenvalues of the Lanczos tridiagonal matrix as returned by ``Tensor.eig``
        'gammas': squared first components of the tridiagonal eigenvectors
        'V': tridiagonal eigenvectors mapped back through the Lanczos basis
    """
    if device == 'cuda' and not torch.cuda.is_available():
        device = 'cpu'
    use_cuda = device == 'cuda'
    assert curvature_matrix in ['hessian_lanczos', 'ggn_lanczos', 'cov_grad_lanczos', ]

    torch.backends.cudnn.benchmark = True
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    print('Using model %s' % model)
    model_cfg = getattr(models, model)

    # (Optionally subsampled) training set with test-time transforms: drives
    # the curvature matrix-vector products.
    datasets, num_classes = data.datasets(
        dataset,
        data_path,
        transform_train=model_cfg.transform_test,
        transform_test=model_cfg.transform_test,
        use_validation=not use_test,
        train_subset=num_subsamples,
        train_subset_seed=subsample_seed,
    )

    loader = torch.utils.data.DataLoader(
        datasets['train'],
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    # Full training set with training-time transforms: used only to refresh
    # the batch-norm statistics.
    full_datasets, _ = data.datasets(
        dataset,
        data_path,
        transform_train=model_cfg.transform_train,
        transform_test=model_cfg.transform_test,
        use_validation=not use_test,
    )

    full_loader = torch.utils.data.DataLoader(
        full_datasets['train'],
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    print('Preparing model')
    print(*model_cfg.args, dict(**model_cfg.kwargs))

    if not swag:
        model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
        swag_model = None
    else:
        swag_model = SWAG(model_cfg.base,
                          subspace_type='random',
                          *model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)

    print('Loading %s' % checkpoint_path)
    # The model still lives on the CPU here, so map the checkpoint there; this
    # also lets GPU-written checkpoints load on CPU-only machines.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if swag_model is None:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        swag_model.load_state_dict(checkpoint['state_dict'], strict=False)
        swag_model.set_swa()
        model = swag_model.base_model

    model.to(device)

    num_parameters = sum(p.numel() for p in model.parameters())

    criterion = losses.cross_entropy

    # Name of the ``utils`` helper that implements each matrix-vector product.
    # Resolved lazily so only the requested helper has to exist.
    matvec_names = {
        'hessian_lanczos': 'hess_vec',
        'ggn_lanczos': 'gn_vec',
        'cov_grad_lanczos': 'covgrad_vec',
    }

    class CurvVecProduct(object):
        """Callable handed to Lanczos: maps a vector to (curvature matrix) @ vector."""

        def __init__(self, loader, model, criterion, curvature_matrix, full_loader=None):
            self.loader = loader
            self.full_loader = full_loader
            self.model = model
            self.criterion = criterion
            self.iters = 0
            self.timestamp = time.time()
            self.curvature_matrix = curvature_matrix

        def __call__(self, vector):
            start_time = time.time()
            if self.curvature_matrix not in matvec_names:
                raise ValueError("Unrecognised curvature_matrix argument " + self.curvature_matrix)
            matvec = getattr(utils, matvec_names[self.curvature_matrix])
            output = matvec(
                vector,
                self.loader,
                self.model,
                self.criterion,
                cuda=use_cuda,
                bn_train_mode=bn_train_mode,
            )
            time_diff = time.time() - start_time
            self.iters += 1
            print('Iter %d. Time: %.2f' % (self.iters, time_diff))
            # Lanczos runs on the CPU and expects a column vector.
            return output.cpu().unsqueeze(1)

    w = torch.cat([param.detach().cpu().view(-1) for param in model.parameters()])
    productor = CurvVecProduct(loader, model, criterion, curvature_matrix)
    utils.bn_update(full_loader, model)
    Q, T = lanczos_tridiag(productor, lanczos_iters, dtype=torch.float32, device='cpu',
                           matrix_shape=(num_parameters, num_parameters))
    eigvals, eigvects = T.eig(eigenvectors=True)
    gammas = eigvects[0, :] ** 2
    V = eigvects.t() @ Q.t()
    if save_spectrum_path is not None:
        if save_eigvec:
            torch.save(
                {
                    'w': w,
                    'eigvals': eigvals,
                    'gammas': gammas,
                    'V': V,
                },
                save_spectrum_path,
            )
        # numpy.savez appends '.npz' unless the name already ends with it. Only
        # in that case would it overwrite the torch file written above, so the
        # archive is skipped to keep the eigenvectors.
        npz_would_clobber = save_eigvec and str(save_spectrum_path).endswith('.npz')
        if not npz_would_clobber:
            np.savez(
                save_spectrum_path,
                w=w.numpy(),
                eigvals=eigvals.numpy(),
                gammas=gammas.numpy()
            )
    return {
        'w': w,
        'eigvals': eigvals,
        'gammas': gammas,
        'V': V
    }
def train_network(
        dir,
        dataset,
        data_path,
        model: str,
        optimizer: str = 'SGD',
        optimizer_kwargs: dict = None,
        use_test: bool = True,
        batch_size: int = 128,
        num_workers: int = 4,
        resume: str = None,
        epochs: int = 300,
        save_freq: int = 25,
        eval_freq: int = 5,
        schedule: str = 'linear',
        swag: bool = False,
        swag_no_cov: bool = True,
        swag_resume: str = None,
        swag_subspace: str = 'pca',
        swag_lr: float = 0.05,
        swag_rank: int = 20,
        swag_start: int = 161,
        swag_c_epochs: int = 1,
        verbose: bool = False,
        device: str = 'cuda',
        seed: int = None
):
    """
    Train a neural network with the given model, dataset, optimiser and other relevant configurations.

    Parameters
    ----------
    dir: str: the directory to which the models and statistics are saved. Per-epoch statistics are written to
        ``dir + 'stats-<epoch>'`` (plain string concatenation), so pass a path ending with a separator if the
        statistics should live inside the directory.

    dataset: str: ['CIFAR10', 'CIFAR100', 'MNIST', 'ImageNet32'*]: the dataset on which you would like to train the
        model. For ImageNet 32, we use the downsampled 32 x 32 Full ImageNet dataset. We do not provide a download
        for proprietary reasons; please drop the ImageNet 32 data in the 'data/' folder.

    data_path: str: the path string of the dataset

    model: str: the neural network architecture you would like to train. All available models are listed under
        'models/'. Example: VGG16BN, PreResNet110 (Preactivated ResNet - 110 layers)

    optimizer: str: the name of the optimizer, looked up as an attribute of the ``optimizers`` package. You may
        add your own by exposing it in optimizers/__init__.py

    optimizer_kwargs: dict: the keyword arguments supplied to the optimizer object (learning rate 'lr', momentum,
        weight decay, ...). ``None`` is treated as an empty dict. If 'lr' is absent, the learning rate the
        optimizer was constructed with is used as the base learning rate of the schedule.

    use_test: bool: if True, the model is evaluated on the test set. If not, a portion of the training data is
        assigned as the validation set.

    batch_size: int: the minibatch size

    num_workers: int: number of workers for the dataloader

    resume: str: if not None, path of a checkpoint holding 'epoch', 'state_dict' (model) and 'optimizer' entries
        from which training resumes

    epochs: int: total number of epochs of training

    save_freq: int: how often (in epochs) to save the model.
        Caution: for large models, saving too often may quickly take up storage space.

    eval_freq: int: how often (in epochs) the model is evaluated on the validation/test dataset

    schedule: learning rate schedule. Allowed values: 'linear' (piecewise-linear decaying learning rate) and
        'constant' (constant learning rate). Any other value raises NotImplementedError.

    swag: whether to use Stochastic Weight Averaging (Gaussian)

    swag_no_cov: if True, no covariance matrix is generated and we only have Stochastic Weight Averaging (instead
        of SWA-Gaussian)

    swag_resume: similar to the ``resume`` argument, but for the SWA(G) model

    swag_subspace: *only applicable if swag=True and swag_no_cov=False* subspace of the SWAG model

    swag_lr: *only applicable if swag=True* the learning rate after SWA is activated.

    swag_rank: *only applicable if swag=True and swag_no_cov=False* rank of the SWAG Gaussian approximation

    swag_start: *only applicable if swag=True*: the starting epoch number of weight averaging

    swag_c_epochs: *only applicable if swag=True*: frequency of model collection for averaging

    verbose: if True, verbose and debugging information is displayed

    device: ['cpu', 'cuda']: the device on which the model and all computations are performed. 'cuda' silently
        falls back to 'cpu' when CUDA is not available.

    seed: if not None, a manual seed for the pseudo-random number generation is used.

    Returns
    -------
    None. Checkpoints, weight norms and per-epoch statistics are written to disk and a progress table is printed.
    """
    if device == 'cuda' and not torch.cuda.is_available():
        device = 'cpu'
    print('Preparing directory %s' % dir)
    os.makedirs(dir, exist_ok=True)

    torch.backends.cudnn.benchmark = True
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
    print('Using model ', model)
    model_cfg = getattr(models, model)

    loaders, num_classes = data.loaders(
        dataset,
        data_path,
        batch_size,
        num_workers,
        model_cfg.transform_train,
        model_cfg.transform_test,
        use_validation=not use_test,
    )

    print('Preparing model')
    print(*model_cfg.args, dict(**model_cfg.kwargs))
    model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
    model.to(device)

    swag_model = None
    if swag:
        if not swag_no_cov:
            print('SWA-Gaussian Enabled')
            swag_model = SWAG(model_cfg.base,
                              subspace_type=swag_subspace, subspace_kwargs={'max_rank': swag_rank},
                              *model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
        else:
            print('SWA Enabled')
            swag_model = SWA(model_cfg.base, *model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
        swag_model.to(device)
    print(optimizer + ' training')

    # Initialise a criterion
    criterion = losses.cross_entropy

    # Initialise the optimizer. The default ``optimizer_kwargs=None`` used to crash on ``**None``.
    optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
    optim = getattr(optimizers, optimizer)(
        model.parameters(),
        **optimizer_kwargs
    )
    # Base learning rate of the schedule; captured before any checkpoint overwrites the param groups.
    base_lr = optimizer_kwargs.get('lr', optim.param_groups[0]['lr'])

    def scheduler(epoch, mode):
        """Return the learning rate to use at ``epoch`` under schedule ``mode``."""
        if mode == 'constant':
            return base_lr
        elif mode == 'linear':
            t = epoch / (swag_start if swag else epochs)
            lr_ratio = swag_lr / base_lr if swag else 0.01
            if t <= 0.5:
                factor = 1.0
            elif t <= 0.9:
                factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
            else:
                factor = lr_ratio
            return base_lr * factor
        else:
            raise NotImplementedError

    start_epoch = 0
    if resume is not None:
        print('Resume training from %s' % resume)
        # map_location keeps GPU-saved checkpoints loadable after the CPU fallback above.
        checkpoint = torch.load(resume, map_location=device)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        # The optimizer *object* is ``optim``; ``optimizer`` is only its name (a str).
        optim.load_state_dict(checkpoint['optimizer'])

    if swag and swag_resume is not None:
        checkpoint = torch.load(swag_resume, map_location=device)
        swag_model.load_state_dict(checkpoint['state_dict'])

    utils.save_checkpoint(
        dir,
        start_epoch,
        epoch=start_epoch,
        state_dict=model.state_dict(),
        optimizer=optim.state_dict()
    )

    columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'te_top5_acc', 'time', 'mem_usage']
    if swag:
        columns = columns[:-2] + ['swa_te_loss', 'swa_te_acc', 'swa_te_top5_acc'] + columns[-2:]
    no_result = {'loss': None, 'accuracy': None, 'top5_accuracy': None}

    for epoch in range(start_epoch, epochs):
        time_ep = time.time()

        lr = scheduler(epoch, schedule)
        utils.adjust_learning_rate(optim, lr)
        train_res = utils.train_epoch(loaders['train'], model, criterion, optim, verbose=verbose)

        # update batch norm parameters before testing
        utils.bn_update(loaders['train'], model)

        eval_now = epoch == 0 or epoch % eval_freq == eval_freq - 1 or epoch == epochs - 1
        if eval_now:
            test_res = utils.eval(loaders['test'], model, criterion)
        else:
            test_res = dict(no_result)

        # Default for every epoch in which the SWA model is not evaluated; previously the name was
        # unbound until the first collection, which raised NameError when swag=True.
        swag_res = dict(no_result)
        if swag and (epoch + 1) > swag_start and (epoch + 1 - swag_start) % swag_c_epochs == 0:
            swag_model.collect_model(model)
            if eval_now:
                swag_model.set_swa()
                utils.bn_update(loaders['train'], swag_model)
                swag_res = utils.eval(loaders['test'], swag_model, criterion)

        if (epoch + 1) % save_freq == 0:
            utils.save_checkpoint(
                dir,
                epoch + 1,
                epoch=epoch + 1,
                state_dict=model.state_dict(),
                optimizer=optim.state_dict()
            )
            utils.save_weight_norm(
                dir,
                epoch + 1,
                name='weight_norm',
                model=model
            )
            if swag and (epoch + 1) > swag_start:
                utils.save_checkpoint(
                    dir,
                    epoch + 1,
                    name='swag',
                    epoch=epoch + 1,
                    state_dict=swag_model.state_dict(),
                )
                utils.save_weight_norm(
                    dir,
                    epoch + 1,
                    name='swa_weight_norm',
                    model=swag_model
                )

        time_ep = time.time() - time_ep
        # Only query the CUDA allocator when CUDA exists; report 0 GB otherwise.
        if torch.cuda.is_available():
            memory_usage = torch.cuda.memory_allocated() / (1024.0 ** 3)
        else:
            memory_usage = 0.0

        values = [epoch + 1, lr, train_res['loss'], train_res['accuracy'], test_res['loss'], test_res['accuracy'],
                  test_res['top5_accuracy'], time_ep, memory_usage]

        stats = dict(
            train_loss=train_res['loss'],
            time_ep=time_ep,
            memory_usage=memory_usage,
            train_accuracy=train_res['accuracy'],
            train_top5_accuracy=train_res['top5_accuracy'],
            test_loss=test_res['loss'],
            test_accuracy=test_res['accuracy'],
            test_top5_accuracy=test_res['top5_accuracy'],
        )
        if swag:
            # Insert the SWA columns exactly once so the row matches ``columns``.
            values = values[:-2] + [swag_res['loss'], swag_res['accuracy'], swag_res['top5_accuracy']] + values[-2:]
            stats.update(
                swag_loss=swag_res['loss'],
                swag_accuracy=swag_res['accuracy'],
                swag_top5_accuracy=swag_res['top5_accuracy'],
            )
        # Path is built by plain concatenation on purpose: other tools may read it back the same way.
        np.savez(dir + 'stats-' + str(epoch), **stats)

        table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
        if epoch % 40 == 0:
            # Re-print the header (framed by the separator line) every 40 epochs.
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            table = table.split('\n')[2]
        print(table)

    if epochs % save_freq != 0:
        utils.save_checkpoint(
            dir,
            epochs,
            epoch=epochs,
            state_dict=model.state_dict(),
            optimizer=optim.state_dict()
        )
        if swag:
            utils.save_checkpoint(
                dir,
                epochs,
                name='swag',
                epoch=epochs,
                state_dict=swag_model.state_dict(),
            )
class HessianFree(torch.optim.Optimizer):
    """
    Implements the Hessian-free algorithm presented in `Training Deep and
    Recurrent Networks with Hessian-Free Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups (exactly one group is supported by ``step``)
        lr (float, optional): learning rate, must lie in (0, 1] (default: 1)
        damping (float, optional): Initial value of the Tikhonov damping
            coefficient, must lie in (0, 1] (default: 0.5)
        delta_decay (float, optional): Decay of the previous result of
            computing delta with conjugate gradient method for the
            initialization of the next conjugate gradient iteration
            (default: 0.95)
        cg_max_iter (int, optional): Maximum number of Conjugate-Gradient
            iterations (default: 100)
        use_gnm (bool, optional): Use the generalized Gauss-Newton matrix:
            probably solves the indefiniteness of the Hessian (Section 20.6)
        verbose (bool, optional): Print statements (debugging)

    .. _Training Deep and Recurrent Networks with Hessian-Free Optimization:
        https://doi.org/10.1007/978-3-642-35289-8_27
    """

    def __init__(self, params,
                 lr=1,
                 damping=0.5,
                 delta_decay=0.95,
                 cg_max_iter=100,
                 use_gnm=True,
                 verbose=False):

        if not (0.0 < lr <= 1):
            raise ValueError("Invalid lr: {}".format(lr))

        if not (0.0 < damping <= 1):
            raise ValueError("Invalid damping: {}".format(damping))

        if not cg_max_iter > 0:
            raise ValueError("Invalid cg_max_iter: {}".format(cg_max_iter))

        # The learning rate is stored under the key 'alpha' in the param group.
        defaults = dict(alpha=lr,
                        damping=damping,
                        delta_decay=delta_decay,
                        cg_max_iter=cg_max_iter,
                        use_gnm=use_gnm,
                        verbose=verbose)
        super(HessianFree, self).__init__(params, defaults)

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

    def _numel(self):
        """Return (and cache) the total number of scalar parameters."""
        if self._numel_cache is None:
            self._numel_cache = sum(p.numel() for p in self._params)
        return self._numel_cache

    def _modify_params(self, new_params):
        """Overwrite every parameter with the matching slice of the flat vector ``new_params``."""
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.data = new_params[offset:offset + numel].view_as(p.data)
            offset += numel
        assert offset == self._numel()

    def _cast_like_params(self, vec):
        """Split the flat vector ``vec`` into a list of tensors shaped like the parameters."""
        views = []
        offset = 0
        for p in self._params:
            numel = p.numel()
            views.append(vec[offset:offset + numel].view_as(p).data)
            offset += numel
        assert offset == self._numel()

        return views

    def _gather_flat_params(self):
        """Return all parameters concatenated into one flat tensor."""
        return torch.cat([p.contiguous().view(-1) for p in self._params], 0)

    def _gather_flat_grad(self):
        """Return all gradients concatenated into one flat tensor (zeros where a gradient is missing)."""
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.data.new_zeros(p.data.numel())
            elif p.grad.data.is_sparse:
                view = p.grad.data.to_dense().view(-1)
            else:
                view = p.grad.contiguous().view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def step(self, closure, b=None, M_inv=None):
        """
        Performs a single optimization step.

        Arguments:
            closure (callable): A closure that re-evaluates the model
                and returns a tuple of the loss and the output.
            b (callable, optional): A closure that calculates the vector b in
                the minimization problem x^T . A . x + x^T b. Defaults to the
                current flat gradient.
            M_inv (callable, optional): A closure returning the INVERSE
                preconditioner of A, either as a 1-D tensor (diagonal) or as
                a square matrix.

        Returns:
            The loss after the accepted update (the initial loss when the
            line search finds no acceptable step).
        """
        assert len(self.param_groups) == 1

        group = self.param_groups[0]
        alpha = group['alpha']
        delta_decay = group['delta_decay']
        cg_max_iter = group['cg_max_iter']
        damping = group['damping']
        use_gnm = group['use_gnm']
        verbose = group['verbose']

        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

        loss_before, output = closure()
        state['func_evals'] += 1

        # Gather current parameters and respective gradients
        flat_params = self._gather_flat_params()
        flat_grad = self._gather_flat_grad()

        # Define linear operator
        if use_gnm:
            # Generalized Gauss-Newton vector product
            def A(x):
                return self._Gv(loss_before, output, x, damping)
        else:
            # Hessian-vector product
            def A(x):
                return self._Hv(flat_grad, x, damping)

        if M_inv is not None:
            m_inv = M_inv()

            # Preconditioner recipe (Section 20.13)
            if m_inv.dim() == 1:
                m = (m_inv + damping) ** (-0.85)

                def M(x):
                    return m * x
            else:
                # Build the identity on the same device/dtype as m_inv so the
                # damped sum does not fail for GPU or non-default dtypes.
                eye = torch.eye(*m_inv.shape, dtype=m_inv.dtype,
                                device=m_inv.device)
                m = torch.inverse(m_inv + damping * eye)

                def M(x):
                    return m @ x
        else:
            M = None

        b = flat_grad.detach() if b is None else b().detach().flatten()

        # Initializing Conjugate-Gradient (Section 20.10)
        if state.get('init_delta') is not None:
            init_delta = delta_decay * state.get('init_delta')
        else:
            init_delta = torch.zeros_like(flat_params)

        eps = torch.finfo(b.dtype).eps

        # Conjugate-Gradient
        deltas, Ms = self._CG(A=A, b=b.neg(), x0=init_delta,
                              M=M, max_iter=cg_max_iter,
                              tol=1e1 * eps, eps=eps, martens=True)

        # Update parameters
        delta = state['init_delta'] = deltas[-1]
        # Value of the quadratic model at the chosen delta (not the
        # preconditioner, which is no longer needed past this point).
        quad_change = Ms[-1]

        self._modify_params(flat_params + delta)
        loss_now = closure()[0]
        state['func_evals'] += 1

        # Conjugate-Gradient backtracking (Section 20.8.7)
        if verbose:
            print("Original loss: \t{}".format(float(loss_before)))
            print("Loss before bt: {}".format(float(loss_now)))

        for (d, m) in zip(reversed(deltas[:-1][::2]), reversed(Ms[:-1][::2])):
            self._modify_params(flat_params + d)
            loss_prev = closure()[0]
            if float(loss_prev) > float(loss_now):
                break
            delta = d
            quad_change = m
            loss_now = loss_prev

        if verbose:
            print("Loss after bt: \t{}".format(float(loss_now)))

        # The Levenberg-Marquardt Heuristic (Section 20.8.5)
        reduction_ratio = (float(loss_now) -
                           float(loss_before)) / quad_change if quad_change != 0 else 1

        if reduction_ratio < 0.25:
            group['damping'] *= 3 / 2
        elif reduction_ratio > 0.75:
            group['damping'] *= 2 / 3
        if reduction_ratio < 0:
            # Discard the CG warm start after a bad step. The warm start is
            # read from ``state``; writing it to ``group`` had no effect.
            state['init_delta'] = None

        if verbose:
            print("Reduction_ratio: {}".format(reduction_ratio))
            print("Damping: {}".format(group['damping']))

        # Line Searching (Section 20.8.8)
        beta = 0.8
        c = 1e-2
        min_improv = min(c * torch.dot(b, delta), 0)

        for _ in range(60):
            if float(loss_now) <= float(loss_before) + alpha * min_improv:
                break

            alpha *= beta
            self._modify_params(flat_params + alpha * delta)
            loss_now = closure()[0]
        else:  # No good update found
            alpha = 0.0
            loss_now = loss_before

        # Update the parameters (this time for real)
        self._modify_params(flat_params + alpha * delta)

        if verbose:
            print("Final loss: {}".format(float(loss_now)))
            print("Lr: {}".format(alpha), end='\n\n')

        return loss_now

    def _CG(self, A, b, x0, M=None, max_iter=50, tol=1.2e-6, eps=1.2e-7,
            martens=False):
        """
        Minimizes the linear system x^T.A.x - x^T b using the conjugate
        gradient method

        Arguments:
            A (callable): An abstract linear operator implementing the
                product A.x. A must represent a hermitian, positive definite
                matrix.
            b (torch.Tensor): The vector b.
            x0 (torch.Tensor): An initial guess for x.
            M (callable, optional): An abstract linear operator implementing
                the product of the preconditioner (for A) matrix with a vector.
            max_iter (int, optional): Maximum number of iterations.
            tol (float, optional): Tolerance for convergence.
            eps (float, optional): Small constant added to denominators.
            martens (bool, optional): Flag for Martens' convergence criterion.

        Returns:
            A tuple ``(x, m)``: ``x`` is the list of iterates (starting with
            ``x0``) and ``m`` the list of quadratic-model values per iterate,
            or ``None`` when ``martens`` is False.
        """

        x = [x0]
        r = A(x[0]) - b

        if M is not None:
            y = M(r)
            p = -y
        else:
            p = -r

        res_i_norm = r @ r

        if martens:
            m = [0.5 * (r - b) @ x0]

        for i in range(max_iter):
            Ap = A(p)

            alpha = res_i_norm / ((p @ Ap) + eps)

            x.append(x[i] + alpha * p)
            r = r + alpha * Ap

            if M is not None:
                y = M(r)
                res_ip1_norm = y @ r
            else:
                res_ip1_norm = r @ r

            beta = res_ip1_norm / (res_i_norm + eps)
            res_i_norm = res_ip1_norm

            # Martens' Relative Progress stopping condition (Section 20.4)
            if martens:
                m.append(0.5 * A(x[i + 1]) @ x[i + 1] - b @ x[i + 1])

                k = max(10, int(i / 10))
                if i > k:
                    stop = (m[i] - m[i - k]) / (m[i] + eps)
                    if stop < 1e-4:
                        break

            if res_i_norm < tol or torch.isnan(res_i_norm):
                break

            if M is not None:
                p = - y + beta * p
            else:
                p = - r + beta * p

        return (x, m) if martens else (x, None)

    def _Hv(self, gradient, vec, damping):
        """
        Computes the (Tikhonov-damped) Hessian vector product.
        """
        vec_ = self._cast_like_params(vec)

        Hv = self._Rop(gradient, self._params, vec_)
        Hv = torch.cat([h.flatten() for h in Hv])

        return Hv + damping * vec  # Tikhonov damping (Section 20.8.1)

    def _Gv(self, loss, output, vec, damping):
        """
        Computes the (Tikhonov-damped) generalized Gauss-Newton vector product.
        """
        vec_ = self._cast_like_params(vec)
        Jv = self._Rop(output, self._params, vec_)

        gradient = torch.autograd.grad(loss, output, create_graph=True)
        HJv = self._Rop(gradient, output, Jv)

        JHJv = torch.autograd.grad(
            output, self._params, grad_outputs=HJv, retain_graph=True)

        Gv = torch.cat([j.detach().flatten() for j in JHJv])
        return Gv + damping * vec  # Tikhonov damping (Section 20.8.1)

    def _Rop(self, y, x, v):
        """
        Computes the product (dy_i/dx_j) v_j: R-operator

        Implemented with two backward passes: first differentiate ``y`` with
        respect to ``x`` against dummy cotangents ``ws``, then differentiate
        that result with respect to ``ws`` against ``v``.
        """
        if isinstance(y, tuple):
            ws = [torch.zeros_like(
                y_i).requires_grad_(True) for y_i in y]
        else:
            ws = torch.zeros_like(y).requires_grad_(True)

        jacobian = torch.autograd.grad(
            y, x, grad_outputs=ws, create_graph=True)

        Jv = torch.autograd.grad(
            jacobian, ws, grad_outputs=v, retain_graph=True)

        return tuple([j.detach() for j in Jv])


# The empirical Fisher diagonal (Section 20.11.3)
def empirical_fisher_diagonal(net, xs, ys, criterion):
    """Return the flat mean of squared per-sample gradients of ``criterion(net(x), y)``."""
    params = list(net.parameters())
    grads = []
    for (x, y) in zip(xs, ys):
        fi = criterion(net(x), y)
        grads.append(torch.autograd.grad(fi, params,
                                         retain_graph=False))

    # zip(*grads) regroups the per-sample tuples into one tuple per parameter.
    vec = torch.cat([(torch.stack(p) ** 2).mean(0).detach().flatten()
                     for p in zip(*grads)])
    return vec


# The empirical Fisher matrix (Section 20.11.3)
def empirical_fisher_matrix(net, xs, ys, criterion):
    """Return the mean outer product of the flat per-sample gradients of ``criterion(net(x), y)``."""
    params = list(net.parameters())
    grads = []
    for (x, y) in zip(xs, ys):
        fi = criterion(net(x), y)
        grad = torch.autograd.grad(fi, params,
                                   retain_graph=False)
        grads.append(torch.cat([g.detach().flatten() for g in grad]))

    grads = torch.stack(grads)
    n_batch = grads.shape[0]
    return torch.einsum('ij,ik->jk', grads, grads) / n_batch
/optimizers/kfac.py: -------------------------------------------------------------------------------- 1 | # Edited by Xingchen Wan: added KFAC-w (decoupled weight decay), KFAC-L2 (L2 regularisation) and Adaptive damping etc. 2 | 3 | import math 4 | 5 | import torch 6 | import torch.optim as optim 7 | import numpy as np 8 | 9 | from utils.kfac_utils import (ComputeCovA, ComputeCovG) 10 | from utils.kfac_utils import update_running_stat 11 | 12 | 13 | class KFACOptimizer(optim.Optimizer): 14 | def __init__(self, 15 | model, 16 | lr=0.001, 17 | momentum=0.9, 18 | stat_decay=0.95, 19 | damping=0.001, 20 | kl_clip=0.001, 21 | weight_decay=0, 22 | TCov=10, 23 | TInv=100, 24 | batch_averaged=True, 25 | decoupled_wd=False, 26 | adaptive_mode=False, 27 | Tadapt=5, 28 | omega=19./20., 29 | cuda=True): 30 | if lr < 0.0: 31 | raise ValueError("Invalid learning rate: {}".format(lr)) 32 | if momentum < 0.0: 33 | raise ValueError("Invalid momentum value: {}".format(momentum)) 34 | if weight_decay < 0.0: 35 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 36 | defaults = dict(lr=lr, momentum=momentum, damping=damping, 37 | weight_decay=weight_decay,) 38 | # TODO (CW): KFAC optimizer now only support model as input 39 | super(KFACOptimizer, self).__init__(model.parameters(), defaults) 40 | self.CovAHandler = ComputeCovA() 41 | self.CovGHandler = ComputeCovG() 42 | self.batch_averaged = batch_averaged 43 | 44 | self.known_modules = {'Linear', 'Conv2d'} 45 | 46 | self.modules = [] 47 | self.grad_outputs = {} 48 | 49 | self.model = model 50 | self._prepare_model() 51 | 52 | self.steps = 0 53 | 54 | self.m_aa, self.m_gg = {}, {} 55 | self.Q_a, self.Q_g = {}, {} 56 | self.d_a, self.d_g = {}, {} 57 | self.stat_decay = stat_decay 58 | 59 | self.kl_clip = kl_clip 60 | self.TCov = TCov 61 | self.TInv = TInv 62 | 63 | self.acc_stats = True 64 | self.device = 'cuda' if cuda else "cpu" 65 | 66 | # Whether we toggle decoupled weight decay - this is the result of the 
paper "DECOUPLED WEIGHT DECAY REGULARIZATION" 67 | # which showed that in Adam, decoupled weight decay demonstrated better result and established that L2 regu- 68 | # larisation is not equivalent to weight decay for optimisers with a non-identity preconditioning matrix to the 69 | # gradient update - which is true for all second order method including K-FAC. use this option to activate 70 | # decoupled style weight decay. 71 | self.decoupled_weight_decay = decoupled_wd 72 | self.wd = weight_decay 73 | 74 | # Auto-damping facility: adaptively compute lambda value for damping - requires one additional forward pass 75 | self.adaptive_mode = adaptive_mode 76 | # Turning on adaptive mode will activate both adaptive scaling and adaptive damping 77 | self.omega = omega 78 | self.Tadapt = Tadapt 79 | 80 | def _save_input(self, module, input): 81 | if torch.is_grad_enabled() and self.steps % self.TCov == 0: 82 | aa = self.CovAHandler(input[0].data, module) 83 | # Initialize buffers 84 | if self.steps == 0: 85 | self.m_aa[module] = torch.diag(aa.new(aa.size(0)).fill_(1)) 86 | update_running_stat(aa, self.m_aa[module], self.stat_decay) 87 | 88 | def _save_grad_output(self, module, grad_input, grad_output): 89 | # Accumulate statistics for Fisher matrices 90 | if self.acc_stats and self.steps % self.TCov == 0: 91 | gg = self.CovGHandler(grad_output[0].data, module, self.batch_averaged) 92 | # Initialize buffers 93 | if self.steps == 0: 94 | self.m_gg[module] = torch.diag(gg.new(gg.size(0)).fill_(1)) 95 | update_running_stat(gg, self.m_gg[module], self.stat_decay) 96 | 97 | def _prepare_model(self): 98 | count = 0 99 | #print(self.model) 100 | #print("=> We keep following layers in KFAC. ") 101 | for module in self.model.modules(): 102 | classname = module.__class__.__name__ 103 | # print('=> We keep following layers in KFAC. 
<=') 104 | if classname in self.known_modules: 105 | self.modules.append(module) 106 | module.register_forward_pre_hook(self._save_input) 107 | module.register_backward_hook(self._save_grad_output) 108 | #print('(%s): %s' % (count, module)) 109 | count += 1 110 | 111 | def _update_inv(self, m): 112 | """Do eigen decomposition for computing inverse of the ~ fisher. 113 | :param m: The layer 114 | :return: no returns. 115 | """ 116 | eps = 1e-10 # for numerical stability 117 | self.d_a[m], self.Q_a[m] = torch.symeig( 118 | self.m_aa[m], eigenvectors=True) 119 | self.d_g[m], self.Q_g[m] = torch.symeig( 120 | self.m_gg[m], eigenvectors=True) 121 | 122 | # XW: squaring the eigenvalues? 123 | self.d_a[m].mul_((self.d_a[m] > eps).float()) 124 | self.d_g[m].mul_((self.d_g[m] > eps).float()) 125 | 126 | def _get_matrix_form_grad(self, m, classname): 127 | """ 128 | :param m: the layer 129 | :param classname: the class name of the layer 130 | :return: a matrix form of the gradient. it should be a [output_dim, input_dim] matrix. 131 | return 1) the matrix form of the gradient 132 | 2) the list form of the gradient 133 | """ 134 | # Xingchen edit on 28 Oct - if using l2 regularisation, the weight norm should be added to the grad before 135 | # the conditioning step. 
136 | if classname == 'Conv2d': 137 | p_grad_mat = m.weight.grad.data.view(m.weight.grad.data.size(0), -1) + self.wd * m.weight.data.view(m.weight.grad.data.size(0), -1) 138 | # n_filters * (in_c * kw * kh) 139 | p_grad_list = [m.weight.grad.detach().to(self.device)] 140 | #param_list = [m.weight.data.detach().requires_grad_(True)] 141 | else: 142 | p_grad_mat = m.weight.grad.data #+ self.l2_reg * m.weight.data 143 | p_grad_list = [m.weight.grad.detach().to(self.device)] 144 | #param_list = [m.weight.data.detach().requires_grad_(True)] 145 | if m.bias is not None: 146 | bias_grad = m.bias.grad.data.view(-1, 1) 147 | bias_grad += m.bias.data.view(-1, 1) * self.wd 148 | p_grad_list = [m.weight.grad.detach().to(self.device), m.bias.grad.detach().to(self.device)] 149 | p_grad_mat = torch.cat([p_grad_mat, bias_grad], 1) 150 | #param_list = [m.weight.data.detach().requires_grad_(True), m.bias.data.detach().requires_grad_(True)] 151 | 152 | return p_grad_mat, p_grad_list 153 | 154 | def _get_natural_grad(self, m, p_grad_mat, damping): 155 | """ 156 | :param m: the layer 157 | :param p_grad_mat: the gradients in matrix form 158 | :return: a list of gradients w.r.t to the parameters in `m` 159 | """ 160 | # p_grad_mat is of output_dim * input_dim 161 | # inv((ss')) p_grad_mat inv(aa') = [ Q_g (1/R_g) Q_g^T ] @ p_grad_mat @ [Q_a (1/R_a) Q_a^T] 162 | v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m] 163 | v2 = v1 / (self.d_g[m].unsqueeze(1) * self.d_a[m].unsqueeze(0) + damping) 164 | 165 | v = self.Q_g[m] @ v2 @ self.Q_a[m].t() 166 | 167 | if m.bias is not None: 168 | # we always put gradient w.r.t weight in [0] 169 | # and w.r.t bias in [1] 170 | v = [v[:, :-1], v[:, -1:]] 171 | v[0] = v[0].view(m.weight.grad.data.size()) 172 | v[1] = v[1].view(m.bias.grad.data.size()) 173 | else: 174 | v = [v.view(m.weight.grad.data.size())] 175 | return v, None 176 | 177 | def _kl_clip_and_update_grad(self, updates, lr, ): 178 | # do kl clip 179 | vg_sum = 0 180 | for m in self.modules: 181 
| v = updates[m] 182 | #v[0] *= scaling 183 | vg_sum += (v[0] * m.weight.grad.data * lr ** 2).sum().item() 184 | if m.bias is not None: 185 | #v[1] *= scaling 186 | vg_sum += (v[1] * m.bias.grad.data * lr ** 2).sum().item() 187 | nu = min(1.0, math.sqrt(self.kl_clip / vg_sum)) 188 | 189 | for m in self.modules: 190 | v = updates[m] 191 | m.weight.grad.data.copy_(v[0]) 192 | m.weight.grad.data.mul_(nu) 193 | if m.bias is not None: 194 | m.bias.grad.data.copy_(v[1]) 195 | m.bias.grad.data.mul_(nu) 196 | 197 | def _step(self, closure=None): 198 | # FIXME (CW): Modified based on SGD (removed nestrov and dampening in momentum.) 199 | # FIXME (CW): 1. no nesterov, 2. buf.mul_(momentum).add_(1 - dampening , d_p) 200 | for group in self.param_groups: 201 | weight_decay = group['weight_decay'] 202 | momentum = group['momentum'] 203 | 204 | for p in group['params']: 205 | if p.grad is None: 206 | continue 207 | if weight_decay != 0 and self.steps >= 20 * self.TCov: 208 | if self.decoupled_weight_decay: 209 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 210 | 211 | d_p = p.grad.data 212 | 213 | # If using normal weight decay... 214 | if weight_decay != 0 and self.steps >= 20 * self.TCov: 215 | if not self.decoupled_weight_decay: 216 | d_p.add_(weight_decay, p.data) 217 | 218 | if momentum != 0: 219 | param_state = self.state[p] 220 | if 'momentum_buffer' not in param_state: 221 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) 222 | buf.mul_(momentum).add_(d_p) 223 | else: 224 | buf = param_state['momentum_buffer'] 225 | buf.mul_(momentum).add_(1, d_p) 226 | d_p = buf 227 | 228 | # If using decoupled weight decay... 229 | 230 | 231 | p.data.add_(-group['lr'], d_p) 232 | 233 | def step(self, model_fn=None, loss_fn=None): 234 | # FIXME(CW): temporal fix for compatibility with Official LR scheduler. 
235 | group = self.param_groups[0] 236 | lr = group['lr'] 237 | damping = group['damping'] 238 | natural_grad = {} 239 | grad = {} 240 | param = {} 241 | 242 | if model_fn is not None and loss_fn is not None: 243 | output = model_fn() 244 | output_d = output.detach().requires_grad_(True) 245 | loss = loss_fn(output_d) 246 | elif self.adaptive_mode: 247 | raise ValueError("Model_fn and loss_fn need to be supplied for adaptive mode") 248 | else: 249 | loss, output, output_d = None, None, None 250 | 251 | for m in self.modules: 252 | classname = m.__class__.__name__ 253 | #if self.adaptive_mode: 254 | # Add Tikhonov Damping to the A_i and G_i matrices i.e. the Kronecker blocks 255 | # self._tikhonov_step(m) 256 | if self.steps % self.TInv == 0: 257 | self._update_inv(m) 258 | p_grad_mat, p_grad_list = self._get_matrix_form_grad(m, classname) 259 | # Save the gradient - this is required for auto-damping 260 | grad[m] = p_grad_list 261 | #param[m] = param_list 262 | if not self.adaptive_mode: 263 | v, _ = self._get_natural_grad(m, p_grad_mat, damping) 264 | else: 265 | v, _ = self._get_natural_grad(m, p_grad_mat, damping) 266 | 267 | natural_grad[m] = v 268 | # natural_grad[.] 
def auto_lambda(self, loss_fn, model_fn, prev_loss, M):
    """Adapt the damping (lambda) of the first parameter group.

    Re-evaluates the loss via ``loss_fn(model_fn())`` and compares the
    observed change with the change ``M`` predicted by the parabolic
    approximation (the reduction ratio of Section 6.5).

    :param loss_fn: callable mapping a model output to a loss
    :param model_fn: zero-argument callable producing the model output
    :param prev_loss: loss value before the update
    :param M: change of the objective predicted by the quadratic model
    :return: the freshly evaluated loss
    """
    new_loss = loss_fn(model_fn())
    group = self.param_groups[0]

    # rho: ratio between the observed and the predicted change.
    reduction_ratio = (new_loss - prev_loss) / M
    scale = self.omega ** self.Tadapt

    if reduction_ratio > 0.75:
        group['damping'] *= scale
    elif reduction_ratio < 0.25:
        group['damping'] /= scale

    print(reduction_ratio, group['lr'], group['damping'])
    return new_loss
def _rescale_and_get_quadratic_change(self, natural_grad, loss, output, output_d, grads):
    r"""
    Compute the scaling (/alpha) of Section 6.4 against the exact curvature - here the
    Generalised Gauss Newton - for the unscaled proposal /Delta given by ``natural_grad``.

        alpha = (g^T v) / (v^T G v + (damping + wd) * ||v||_2^2)

    where ``v`` is the natural gradient, ``g`` the raw gradient and ``G v`` comes from
    ``self.ggn_vector_product``.

    :param natural_grad: dict module -> sequence ``(weight_part, [bias_part])`` holding the
        natural gradient (i.e. gradient preconditioned by inverse Fisher)
    :param loss: the network loss
    :param output: the network output - these two are required for the GGN-vector product
    :param output_d: detached copy of ``output`` on which ``loss`` was evaluated
    :param grads: dict module -> sequence with the gradient without pre-conditioning
    :return: ``(alpha, M)`` with ``M = 0.5 * alpha * g^T v``
    """
    # NOTE(review): the step actually taken is -v, so the quadratic model's change
    # 0.5 * alpha * grad^T Delta would be negative, whereas the value returned here is
    # positive. Confirm the sign convention expected by auto_lambda before relying on rho.

    # First compute the numerator $-\nabla h^T \Delta$
    grad_delta_product = 0
    natural_grad_list = []
    param_list = []
    for m in self.modules:
        v = natural_grad[m]
        grad = grads[m]
        grad_delta_product += (v[0] * grad[0]).sum().item()
        natural_grad_list.append(v[0])
        param_list.append(m.weight)
        if m.bias is not None:
            grad_delta_product += (v[1] * grad[1]).sum().item()
            natural_grad_list.append(v[1])
            param_list.append(m.bias)

    # The update tensor - flattened.
    natural_grad_vec = torch.cat([v.flatten() for v in natural_grad_list])
    delta_F = self.ggn_vector_product(natural_grad_list, param_list, loss, output, output_d)
    delta_F_delta = (delta_F * natural_grad_vec).sum().item()
    # Damping / l2 contribution Delta^T (lambda + eta) I Delta: this needs the SQUARED
    # Euclidean norm of Delta, not the norm itself.
    sq_norm = natural_grad_vec.pow(2).sum().item()
    delta_F_delta += (self.param_groups[0]['damping'] + self.wd) * sq_norm

    # Compute the new scaling factor /alpha
    alpha = grad_delta_product / delta_F_delta
    return alpha, 0.5 * alpha * grad_delta_product
def FisherVectorProduct(self, vector_list, param_list, loss, output, output_d):
    """Alternative GGN/Fisher-vector product built from ``Rop`` / ``Lop``.

    Computes ``J^T H J v`` where ``J`` is the Jacobian of ``output`` w.r.t.
    ``param_list`` and ``H`` the Hessian of the loss w.r.t. ``output``. When the
    loss was produced by an NLL/cross-entropy node, ``H`` is formed in closed form
    as ``diag(p) - p p^T`` with ``p = softmax(output)`` and divided by the batch
    size; otherwise ``self.HesssianVectorProduct`` is used.

    :param vector_list: per-parameter tensors making up ``v``
    :param param_list: parameters matching ``vector_list``
    :param loss: scalar loss whose ``grad_fn`` selects the branch
    :param output: network output; indexed here as (batch, dims)
    :param output_d: unused in this implementation
    :return: flat tensor with the concatenated product
    """
    Jv = self.Rop(output, param_list, vector_list)
    batch, dims = output.size(0), output.size(1)
    # Recent PyTorch releases suffix autograd node names (e.g. 'NllLossBackward0'),
    # so match on the prefix rather than on the exact class name.
    if type(loss.grad_fn).__name__.startswith('NllLossBackward'):
        outputsoftmax = torch.nn.functional.softmax(output, dim=1)
        # Build diag(p) per sample on the same device/dtype as the probabilities.
        M = torch.zeros(batch, dims, dims, device=outputsoftmax.device, dtype=outputsoftmax.dtype)
        M.reshape(batch, -1)[:, ::dims + 1] = outputsoftmax
        H = M - torch.einsum('bi,bj->bij', outputsoftmax, outputsoftmax)
        # Only drop the trailing singleton; a bare squeeze() would also drop the
        # batch axis when batch == 1 and break the shape expected by Lop.
        HJv = [(H @ torch.unsqueeze(Jv[0], -1)).squeeze(-1) / batch]
    else:
        HJv = self.HesssianVectorProduct(loss, output, Jv)
    JHJv = self.Lop(output, param_list, HJv)

    return torch.cat([torch.flatten(v) for v in JHJv])
def train_epoch(loader, model, criterion, optimizer, cuda=True, verbose=False, subset=None, backpacked_model=False,
                *backpack_extensions):
    """
    Train the model with one pass over the entire dataset (i.e. one epoch)
    :param loader: sized iterable yielding ``(input, target)`` batches
    :param model: the network; put into train mode here
    :param criterion: callable ``criterion(model, input, target)`` returning ``(loss, output, stats)``
    :param optimizer: optimiser to step; a KFACOptimizer additionally gets its curvature statistics
        refreshed from sampled labels every ``optimizer.TCov`` steps
    :param cuda: move every batch to the GPU
    :param verbose: show a tqdm bar and print running metrics in ten stages
    :param subset: if not None, fraction of the batches to use
    :param backpacked_model: toggle to true if the model has additional backpack functionality
    :param backpack_extensions: names of the backpack extensions you would like to enable
    :return: dict with 'loss', 'accuracy', 'top5_accuracy' and 'stats', averaged over the samples seen
    """
    loss_criterion = torch.nn.CrossEntropyLoss()
    loss_sum = 0.0
    stats_sum = defaultdict(float)
    correct_1 = 0.0
    correct_5 = 0.0
    verb_stage = 0

    num_objects_current = 0
    num_batches = len(loader)

    extensions = []
    if backpacked_model and len(backpack_extensions) != 0:
        # The module-level name `backpack` is the context manager, not the package,
        # so the extensions module has to be imported explicitly.
        import backpack.extensions as bp_extensions
        for extension in backpack_extensions:
            if not hasattr(bp_extensions, extension):
                raise ValueError(str(extension) + " is not found in backpack.extensions list!")
            extensions.append(getattr(bp_extensions, extension)())

    model.train()

    if subset is not None:
        num_batches = int(num_batches * subset)
        loader = itertools.islice(loader, num_batches)

    if verbose:
        loader = tqdm.tqdm(loader, total=num_batches)

    for i, (input, target) in enumerate(loader):
        if cuda:
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        loss, output, stats = criterion(model, input, target)

        optimizer.zero_grad()

        if isinstance(optimizer, KFACOptimizer) and optimizer.steps % optimizer.TCov == 0:
            # Compute true fisher: backprop a loss on labels sampled from the model's
            # own predictive distribution while the optimiser records statistics.
            optimizer.acc_stats = True
            with torch.no_grad():
                probs = torch.nn.functional.softmax(output.cpu().data, dim=1)
                # squeeze(1) keeps the batch axis for batches of one; follow the
                # output's device instead of assuming CUDA.
                sampled_y = torch.multinomial(probs, 1).squeeze(1).to(output.device)
            loss_sample = loss_criterion(output, sampled_y)
            loss_sample.backward(retain_graph=True)
            optimizer.acc_stats = False
            optimizer.zero_grad()

        # If the list of backpack extension is non-empty
        if extensions:
            with backpack(*extensions):
                loss.backward()
        # Normal step
        else:
            loss.backward()

        optimizer.step()
        loss_sum += loss.data.item() * input.size(0)
        for key, value in stats.items():
            stats_sum[key] += value * input.size(0)

        _, pred = output.topk(5, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        # `correct` inherits the transposed layout and may be non-contiguous, where
        # view(-1) raises; reshape handles both layouts.
        correct_1 += correct[0].reshape(-1).float().sum(0)
        correct_5 += correct[:5].reshape(-1).float().sum(0)

        num_objects_current += input.size(0)

        if verbose and 10 * (i + 1) / num_batches >= verb_stage + 1:
            print('Stage %d/10. Loss: %12.4f. Acc: %6.2f. Top 5 Acc: %6.2f' % (
                verb_stage + 1, loss_sum / num_objects_current,
                correct_1 / num_objects_current * 100.0,
                correct_5 / num_objects_current * 100.0
            ))
            verb_stage += 1

    correct_5 = correct_5.cpu()
    correct_1 = correct_1.cpu()
    return {
        'loss': loss_sum / num_objects_current,
        'accuracy': correct_1 / num_objects_current * 100.0,
        'top5_accuracy': correct_5 / num_objects_current * 100.0,
        'stats': {key: value / num_objects_current for key, value in stats_sum.items()}
    }
def predict(loader, model, verbose=False, cuda=True):
    """
    Run the model over a loader and collect softmax predictions and targets.

    :param loader: iterable yielding ``(input, target)`` batches; targets must be CPU tensors
    :param model: the network; switched to eval mode here
    :param verbose: wrap the loader in a tqdm progress bar
    :param cuda: move each input batch to the GPU before the forward pass (default keeps the
        previous always-on behaviour)
    :return: dict with 'predictions' (row-stacked softmax outputs as a numpy array) and
        'targets' (concatenated numpy array)
    """
    predictions = []
    targets = []

    model.eval()

    if verbose:
        loader = tqdm.tqdm(loader)

    with torch.no_grad():
        for input, target in loader:
            if cuda:
                input = input.cuda(non_blocking=True)
            output = model(input)

            predictions.append(F.softmax(output, dim=1).cpu().numpy())
            targets.append(target.numpy())

    return {
        'predictions': np.vstack(predictions),
        'targets': np.concatenate(targets)
    }
def set_weights(model, vector, device=None):
    """
    Overwrite the model parameters, in ``model.parameters()`` order, with consecutive
    slices of a flat vector.

    :param model: module whose parameters are overwritten in place
    :param vector: flat tensor holding the new values
    :param device: passed to ``.to(...)`` on each slice before it is copied in
    :return: None
    """
    start = 0
    for param in model.parameters():
        count = param.numel()
        chunk = vector[start:start + count]
        param.data.copy_(chunk.view(param.size()).to(device))
        start += count
def hess_vec(vector, loader, model, criterion, cuda=True, bn_train_mode=True):
    """
    Hessian-vector product of the dataset loss, accumulated batch by batch.

    Each batch loss is weighted by ``batch_size / len(loader.dataset)``; the product is
    obtained by back-propagating the inner product between the gradient and the vector,
    so the result accumulates in the ``.grad`` of the parameters.

    :param vector: flat tensor laid out in ``model.parameters()`` order
    :param loader: iterable of ``(input, target)`` batches exposing ``.dataset``
    :param model: the network
    :param criterion: callable ``criterion(model, input, target)`` whose first return value is the loss
    :param cuda: move batches and the scalar accumulator to the GPU
    :param bn_train_mode: put batch-norm layers back into train mode after ``model.eval()``
    :return: flat tensor with the concatenated parameter gradients
    """
    params = list(model.parameters())

    # Cut the flat vector into pieces shaped like the parameters.
    pieces = []
    start = 0
    for p in params:
        stop = start + p.numel()
        pieces.append(vector[start:stop].detach().view_as(p).to(p.device))
        start = stop

    model.eval()
    if bn_train_mode:
        model.apply(_bn_train_mode)

    model.zero_grad()
    dataset_size = len(loader.dataset)
    for batch_x, batch_y in loader:
        if cuda:
            batch_x = batch_x.cuda(non_blocking=True)
            batch_y = batch_y.cuda(non_blocking=True)
        loss, _, _ = criterion(model, batch_x, batch_y)
        loss *= batch_x.size()[0] / dataset_size

        grads = torch.autograd.grad(loss, params, create_graph=True)
        directional = torch.zeros(1)
        if cuda:
            directional = directional.cuda()
        for piece, g in zip(pieces, grads):
            directional += torch.sum(piece * g)
        # Differentiating g^T v w.r.t. the parameters yields H v.
        directional.backward()
    model.eval()
    flat = [p.grad.view(-1) for p in params]
    return torch.cat(flat).view(-1)
Aborting') 340 | return 341 | 342 | if model_name not in ['VGG16', 'AllCNN_CIFAR100']: 343 | raise NotImplementedError(str(model_name) + " is not currently supported.") 344 | if model_name == 'VGG16': 345 | model = get_backpacked_VGG(model, depth=16, num_classes=num_classes) 346 | 347 | if extensions == 'ggn_diag': extensions = ('DiagGGNExact, ', ) 348 | elif extensions == 'hessian_diag': extensions = ('DiagHessian', ) 349 | elif extensions == 'ggn_diag_mc': extensions = ('DiagGGNMC', ) 350 | else: raise NotImplementedError 351 | 352 | # dictionary between the name of the method and the name of variable 353 | method2variable = { 354 | 'DiagGGNMC': 'diag_ggn_mc', 355 | 'DiagGGNExact': 'diag_ggn_exact', 356 | 'DiagHessian': 'diag_h' 357 | } 358 | 359 | bn_extensions = [] 360 | for e in extensions: 361 | assert e in list(method2variable.keys()), e + 'should be oe of ' + str(list(method2variable.keys())) 362 | ext_method = getattr(backpack.extensions, e) 363 | bn_extensions.append(ext_method()) 364 | 365 | # Extract the nn.Sequential(.) 
representation of the model required by the backpack package 366 | bp_model = backpack.extend(model, debug=False) 367 | bp_criterion = partial(criterion, backpacked_model=True) 368 | 369 | # bp_model.eval() 370 | if bn_train_mode: 371 | bp_model.apply(_bn_train_mode) 372 | 373 | # Create a dictionary of outputs 374 | result_list = [] 375 | result = {} 376 | 377 | # Initialise each curvature diagonal with list of zero tensors, each of which 378 | # has the same shape as the parameters of the layers 379 | param_list = list(bp_model.parameters()) 380 | for param in param_list: 381 | result_list.append(torch.zeros_like(param).to(param.device)) 382 | 383 | # Assign each curvature diagonal with this zero for all curvature diagonal in extensions 384 | for i in range(len(extensions)): 385 | if i == 0: 386 | result[extensions[i]] = result_list 387 | else: 388 | result[extensions[i]] = result_list.copy() 389 | 390 | bp_model.zero_grad() 391 | N = len(loader.dataset) 392 | for input, target in loader: 393 | if cuda: 394 | input = input.cuda(non_blocking=True) 395 | target = target.cuda(non_blocking=True) 396 | with backpack.backpack(*bn_extensions): 397 | loss, _, _ = bp_criterion(model, input, target) 398 | loss *= input.size()[0] / N 399 | loss.backward() 400 | for i in range(len(extensions)): 401 | j = 0 402 | for param in param_list: 403 | v = getattr(param, method2variable[extensions[i]]) 404 | result[extensions[i]][j] += v 405 | j += 1 406 | 407 | # Finally, apply vec(.) 
def covgrad_vec(vector, loader, model, criterion, cuda=True, bn_train_mode=True):
    """
    Accumulate ``sum_b g_b (g_b^T v)`` over the batches of ``loader``, where ``g_b`` is the
    gradient of the batch loss weighted by ``batch_size / len(loader.dataset)``.

    NOTE(review): the previous body multiplied a tensor by the tuple returned from
    ``torch.autograd.grad`` and could not run; the outer-product form above is the reading
    implied by the function name and the surrounding code - confirm the intended
    normalisation with the author.

    :param vector: flat tensor laid out in ``model.parameters()`` order
    :param loader: iterable of ``(input, target)`` batches exposing ``.dataset``
    :param model: the network
    :param criterion: callable ``criterion(model, input, target)`` whose first return value is the loss
    :param cuda: move batches to the GPU
    :param bn_train_mode: put batch-norm layers back into train mode after ``model.eval()``
    :return: flat tensor shaped like ``vector``
    """
    param_list = list(model.parameters())
    vector_list = []

    offset = 0
    for param in param_list:
        vector_list.append(vector[offset:offset + param.numel()].detach().view_as(param).to(param.device))
        offset += param.numel()

    # The result lives wherever `vector` lives, as with the original zeros_like(vector).
    result = torch.zeros_like(vector)

    model.eval()
    if bn_train_mode:
        model.apply(_bn_train_mode)

    model.zero_grad()
    N = len(loader.dataset)
    for input, target in loader:
        if cuda:
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
        loss, _, _ = criterion(model, input, target)
        loss *= input.size()[0] / N

        # Only first-order gradients are needed, so no graph is kept.
        grad_list = torch.autograd.grad(loss, param_list)

        # Scalar projection g^T v, then scale the flattened gradient by it.
        g_dot_v = sum(torch.sum(v * g) for v, g in zip(vector_list, grad_list))
        flat_grad = torch.cat([g.reshape(-1) for g in grad_list])
        result += (g_dot_v * flat_grad).to(device=result.device, dtype=result.dtype)
    model.eval()
    return result
def gn_vec(vector, loader, model, criterion, cuda=True, bn_train_mode=True):
    """
    Gauss-Newton-vector product of the dataset loss, accumulated batch by batch.

    For every batch this forms ``J^T H J v`` with ``J`` the Jacobian of the model output
    w.r.t. the parameters and ``H`` the Hessian of the (``batch_size / N`` weighted) loss
    w.r.t. the output, using ``R_op`` for the two forward products.

    :param vector: flat tensor laid out in ``model.parameters()`` order
    :param loader: iterable of ``(input, target)`` batches exposing ``.dataset``
    :param model: the network
    :param criterion: callable ``criterion(model, input, target)`` returning ``(loss, output, _)``
    :param cuda: move batches to the GPU and allocate the accumulator there
    :param bn_train_mode: put batch-norm layers back into train mode after ``model.eval()``
    :return: flat float32 tensor with the accumulated product
    """
    params = list(model.parameters())
    total = sum(p.numel() for p in params)

    # Cut the flat vector into pieces shaped like the parameters.
    direction = []
    start = 0
    for p in params:
        stop = start + p.numel()
        direction.append(vector[start:stop].detach().view_as(p).to(p.device))
        start = stop

    model.eval()
    if bn_train_mode:
        model.apply(_bn_train_mode)

    model.zero_grad()
    n_samples = len(loader.dataset)
    accum = torch.zeros(total, dtype=torch.float32, device="cuda" if cuda else "cpu")

    for batch_x, batch_y in loader:
        if cuda:
            batch_x = batch_x.cuda(non_blocking=True)
            batch_y = batch_y.cuda(non_blocking=True)
        loss, output, _ = criterion(model, batch_x, batch_y)
        loss *= batch_x.size()[0] / n_samples

        Jv = R_op(output, params, direction)
        loss_grad = torch.autograd.grad(loss, output, create_graph=True)
        HJv = R_op(loss_grad, output, Jv)
        JHJv = torch.autograd.grad(output, params, grad_outputs=HJv, retain_graph=True)
        accum += torch.cat([t.detach().view(-1) for t in JHJv])
    return accum
def R_op(y, x, v):
    """
    Compute the Jacobian-vector product (dy_i/dx_j)v_j with the double-backward trick:
    differentiate ``y`` w.r.t. ``x`` against a dummy cotangent, then differentiate that
    result w.r.t. the dummy against ``v``.

    :param y: output tensor, or tuple of output tensors
    :param x: input tensor(s) ``y`` was computed from
    :param v: tensor(s) matching ``x`` to multiply the Jacobian with
    :return: tuple of detached tensors
    """
    def _dummy_like(t):
        # Zero-valued cotangent that autograd can differentiate with respect to.
        return torch.zeros_like(t).requires_grad_(True)

    dummies = [_dummy_like(y_i) for y_i in y] if isinstance(y, tuple) else _dummy_like(y)
    vjp = torch.autograd.grad(y, x, grad_outputs=dummies, create_graph=True)
    products = torch.autograd.grad(vjp, dummies, grad_outputs=v, retain_graph=True)
    return tuple(p.detach() for p in products)
mean_value = torch.sum(z * H_z) 568 | 569 | beta = torch.sum((H_z - z * mean_value) ** 2).cpu() 570 | 571 | z_list = [] 572 | offset = 0 573 | for param in param_list: 574 | z_list.append(z[offset:offset + param.numel()].detach().view_as(param).to(param.device)) 575 | offset += param.numel() 576 | 577 | model.eval() 578 | if bn_train_mode: 579 | raise NotImplementedError 580 | model.apply(_bn_train_mode) 581 | 582 | gamma = torch.zeros(1) 583 | 584 | num_batches = len(loader) 585 | for input, target in tqdm.tqdm(loader): 586 | 587 | model.zero_grad() 588 | if cuda: 589 | input = input.cuda(non_blocking=True) 590 | target = target.cuda(non_blocking=True) 591 | loss, _, _ = criterion(model, input, target) 592 | 593 | grad_list = torch.autograd.grad(loss, param_list, create_graph=True) 594 | 595 | dL_dvec = torch.zeros(1) 596 | if cuda: 597 | dL_dvec = dL_dvec.cuda() 598 | for v, g in zip(z_list, grad_list): 599 | dL_dvec += torch.sum(v * g) 600 | dL_dvec.backward() 601 | 602 | H_z_i = torch.cat([p.grad.view(-1) for p in param_list]) 603 | gamma += (torch.sum((H_z - H_z_i) ** 2)).cpu() / num_batches 604 | model.eval() 605 | return 1.0 - beta / torch.max(beta, gamma), mean_value, beta, gamma 606 | 607 | 608 | # Xingchen Wan code modification: 609 | def loss_stats_old(loader, model, criterion, cuda=True, bn_train_mode=True, verbose=False, curvature_matrix='hessian'): 610 | param_list = list(model.parameters()) 611 | num_parameters = sum(p.numel() for p in param_list) 612 | 613 | model.eval() 614 | if bn_train_mode: 615 | # raise NotImplementedError 616 | model.apply(_bn_train_mode) 617 | 618 | loss_mean = torch.zeros(1) 619 | loss_sq_mean = torch.zeros(1) 620 | 621 | grad_mean = torch.zeros(num_parameters) 622 | grad_norm_sq_mean = torch.zeros(1) 623 | 624 | z = torch.randn(num_parameters) 625 | z /= torch.sqrt(torch.sum(z ** 2)) 626 | 627 | H_z_mean = torch.zeros(num_parameters) 628 | H_z_norm_sq_mean = torch.zeros(1) 629 | 630 | if cuda: 631 | grad_mean = 
def loss_stats(loader, model, criterion, cuda=True, bn_train_mode=True, verbose=False, curvature_matrix='hessian'):
    """
    Compute per-batch statistics of the loss, the gradient and a curvature-vector product
    along one random unit direction ``z``.

    :param loader: sized iterable yielding ``(input, target)`` batches
    :param model: the network
    :param criterion: callable ``criterion(model, input, target)`` returning ``(loss, output, _)``
    :param cuda: move batches and the accumulators to the GPU
    :param bn_train_mode: put batch-norm layers back into train mode after ``model.eval()``
    :param verbose: wrap the loader in a tqdm progress bar
    :param curvature_matrix: select the curvature matrix to be used. Available options:
        'hessian' - Hessian matrix
        'gn' - Gauss-Newton matrix
        Other curvature_matrix argument input will result in a ValueError.
    :return: dict with 'loss_mean', 'loss_var', 'grad_mean_norm_sq', 'grad_var',
        'hess_mean_norm_sq', 'hess_var', 'hess_mu', 'delta' and 'alpha'
    Note: for the sake of compatibility, in the final dictionary returned, regardless of the type of curvature matrix used
    the column names will be hess_*, etc.
    """
    # Fail before any expensive work rather than inside the first batch.
    if curvature_matrix not in ('hessian', 'gn'):
        raise ValueError('Invalid curvature matrix'+curvature_matrix)

    param_list = list(model.parameters())
    num_parameters = sum(p.numel() for p in param_list)
    z = torch.randn(num_parameters)
    z /= torch.sqrt(torch.sum(z ** 2))

    # Per-parameter views of z for the Gauss-Newton path. The offset has to advance,
    # otherwise every parameter would receive the leading entries of z.
    vector_list = []
    offset = 0
    for param in param_list:
        numel = param.numel()
        vector_list.append(z[offset:offset + numel].detach().view_as(param).to(param.device))
        offset += numel

    model.eval()
    if bn_train_mode:
        model.apply(_bn_train_mode)

    loss_mean = torch.zeros(1)
    loss_sq_mean = torch.zeros(1)

    grad_mean = torch.zeros(num_parameters)
    grad_norm_sq_mean = torch.zeros(1)

    H_z_mean = torch.zeros(num_parameters)
    H_z_norm_sq_mean = torch.zeros(1)

    if cuda:
        grad_mean = grad_mean.cuda()
        z = z.cuda()
        H_z_mean = H_z_mean.cuda()

    num_batches = len(loader)
    if verbose:
        loader = tqdm.tqdm(loader)
    for input, target in loader:
        model.zero_grad()
        if cuda:
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
        loss, output, _ = criterion(model, input, target)

        grad_list = torch.autograd.grad(loss, param_list, create_graph=True)
        grad_i = torch.cat([g.view(-1) for g in grad_list])

        if curvature_matrix == 'hessian':
            # Back-propagating g^T z leaves H z in the .grad of the parameters.
            dL_dz = torch.sum(grad_i * z)
            dL_dz.backward()
            H_z_i = torch.cat([p.grad.detach().view(-1) for p in param_list])
        else:
            Jv = R_op(output, param_list, vector_list)
            grad = torch.autograd.grad(loss, output, create_graph=True)
            HJv = R_op(grad, output, Jv)
            JHJv = torch.autograd.grad(
                output, param_list, grad_outputs=HJv, retain_graph=False)
            H_z_i = torch.cat([j.detach().view(-1) for j in JHJv])

        grad_i = grad_i.detach()
        # Detach so the running means do not keep every batch's autograd graph alive.
        loss_value = loss.detach().cpu()
        loss_mean += loss_value / num_batches
        loss_sq_mean += loss_value ** 2 / num_batches

        grad_mean += grad_i / num_batches
        grad_norm_sq_mean += torch.sum(grad_i ** 2).cpu() / num_batches
        H_z_mean += H_z_i / num_batches
        H_z_norm_sq_mean += torch.sum(H_z_i ** 2).cpu() / num_batches
    model.eval()

    loss_var = loss_sq_mean - loss_mean ** 2

    grad_mean_norm_sq = torch.sum(grad_mean ** 2).cpu()
    grad_var = grad_norm_sq_mean - grad_mean_norm_sq

    H_z_mean_norm_sq = torch.sum(H_z_mean ** 2).cpu()
    hess_var = H_z_norm_sq_mean - H_z_mean_norm_sq

    hess_mu = torch.sum(z * H_z_mean).cpu()
    delta = torch.sum((H_z_mean - z * hess_mu.item()) ** 2).cpu()
    alpha = torch.max(torch.tensor(0.0), 1.0 - hess_var / num_batches / delta)

    return {
        'loss_mean': loss_mean,
        'loss_var': loss_var,
        'grad_mean_norm_sq': grad_mean_norm_sq,
        'grad_var': grad_var,
        'hess_mean_norm_sq': H_z_mean_norm_sq,
        'hess_var': hess_var,
        'hess_mu': hess_mu,
        'delta': delta,
        'alpha': alpha
    }
target.cuda(non_blocking=True) 811 | loss, _, _ = criterion(model, input, target) 812 | loss *= input.size()[0] / N 813 | loss.backward() 814 | 815 | return torch.cat([param.grad.view(-1) for param in model.parameters()]).view(-1) 816 | 817 | 818 | def loss_stats_layerwise(loader, model, criterion, cuda=True, bn_train_mode=True, verbose=False): 819 | param_list = list(model.parameters()) 820 | num_parameters = sum(p.numel() for p in param_list) 821 | 822 | model.eval() 823 | if bn_train_mode: 824 | raise NotImplementedError 825 | model.apply(_bn_train_mode) 826 | 827 | z_list = [] 828 | H_z_mean_list = [] 829 | H_z_mean_norm_sq_list = [] 830 | 831 | for param in param_list: 832 | z = torch.randn(param.size()) 833 | z /= torch.sqrt(torch.sum(z ** 2)) 834 | z = z.to(param.device) 835 | z_list.append(z) 836 | H_z_mean_list.append(torch.zeros_like(param)) 837 | H_z_mean_norm_sq_list.append(torch.zeros(1).to(param.device)) 838 | 839 | num_batches = len(loader) 840 | if verbose: 841 | loader = tqdm.tqdm(loader) 842 | for input, target in loader: 843 | if cuda: 844 | input = input.cuda(non_blocking=True) 845 | target = target.cuda(non_blocking=True) 846 | loss, _, _ = criterion(model, input, target) 847 | 848 | grad_list = torch.autograd.grad(loss, param_list, create_graph=True) 849 | 850 | for param, grad, z, H_z_mean, H_z_mean_norm_sq in zip(param_list, grad_list, z_list, H_z_mean_list, H_z_mean_norm_sq_list): 851 | if param.grad is not None: 852 | param.grad.detach_() 853 | param.grad.zero_() 854 | 855 | dL_dz = torch.sum(grad * z) 856 | dL_dz.backward(retain_graph=True) 857 | 858 | H_z = param.grad 859 | H_z_mean += H_z / num_batches 860 | H_z_mean_norm_sq += torch.sum(H_z ** 2) / num_batches 861 | 862 | alpha_list = [] 863 | delta_list = [] 864 | hess_mu_list = [] 865 | hess_var_list = [] 866 | for z, H_z_mean, H_z_norm_sq_mean in zip(z_list, H_z_mean_list, H_z_mean_norm_sq_list): 867 | hess_mu = torch.sum(z * H_z_mean) 868 | hess_var = (H_z_norm_sq_mean - 
torch.sum(H_z_mean ** 2)) 869 | 870 | delta = torch.sum((H_z_mean - hess_mu * z) ** 2) 871 | alpha = torch.max(torch.tensor(0.0), (1.0 - hess_var / num_batches / delta).cpu()) 872 | 873 | hess_mu_list.append(hess_mu.cpu()) 874 | hess_var_list.append(hess_var.cpu()) 875 | delta_list.append(delta.cpu()) 876 | alpha_list.append(alpha.cpu()) 877 | 878 | model.eval() 879 | 880 | return { 881 | 'hess_mean_norm_sq_list': H_z_mean_norm_sq_list, 882 | 'hess_var_list': hess_var_list, 883 | 'hess_mu_list': hess_mu_list, 884 | 'delta_list': delta_list, 885 | 'alpha_list': alpha_list 886 | } 887 | 888 | 889 | # XW addition 890 | def save_weight_norm(dir, index, name, model): 891 | """Save the L2 and L-inf norms of the weights of a model""" 892 | filepath = os.path.join(dir, '%s-%05d.pt' % (name, index)) 893 | 894 | w = torch.cat([param.detach().cpu().view(-1) for param in model.parameters()]) 895 | l2_norm = torch.norm(w).numpy() 896 | linf_norm = torch.norm(w, float('inf')).numpy() 897 | np.savez( 898 | filepath, 899 | l2_norms=l2_norm, 900 | linf_norms=linf_norm 901 | ) --------------------------------------------------------------------------------