├── .gitignore
├── LICENSE
├── README.md
├── example
│   └── main.py
├── hessian_eigenthings
│   ├── __init__.py
│   ├── hvp_operator.py
│   ├── lanczos.py
│   ├── operator.py
│   ├── power_iter.py
│   └── utils.py
├── setup.py
└── tests
    ├── principle_eigenvec_tests.py
    ├── random_matrix_tests.py
    ├── utils.py
    └── variance_tests.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Noah Golmant

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-hessian-eigenthings

The `hessian-eigenthings` module provides an efficient (and scalable!) way to compute the eigendecomposition of the Hessian for an arbitrary PyTorch model. It uses PyTorch's Hessian-vector product and your choice of (a) the Lanczos method or (b) stochastic power iteration with deflation in order to compute the top eigenvalues and eigenvectors of the Hessian.

## Why use this?

The eigenvalues and eigenvectors of the Hessian have been implicated in many generalization properties of neural networks. For example, many people hypothesize that "flat minima" with lower eigenvalues generalize better, that the Hessians of large models are very low-rank, and that certain optimization algorithms may lead to flatter or sharper minima. However, computing and storing the full Hessian requires memory that is quadratic in the number of parameters, which is infeasible for anything but toy problems.

Iterative methods like Lanczos and power iteration can find the eigendecomposition of an arbitrary linear operator given access only to a matrix-vector multiplication function. The Hessian-vector product (HVP) is the matrix-vector multiplication between the Hessian and an arbitrary vector *v*. It can be computed with linear memory usage by taking the derivative of the inner product between the gradient and *v*. So this library combines the Hessian-vector product computation with these iterative methods to compute the eigendecomposition without the quadratic memory bottleneck.

You can use this library for Hessian-vector product computation, for the more general eigendecomposition routines for linear operators, or for the conjunction of the two for Hessian spectrum analysis.

## Installation

For now, you have to install from this repo; it's a tiny package, so it isn't published on PyPI.

`pip install --upgrade git+https://github.com/noahgolmant/pytorch-hessian-eigenthings.git@master#egg=hessian-eigenthings`

## Usage

The main function you're probably interested in is `compute_hessian_eigenthings`.
Sample usage is like so:

```python
import torch
from hessian_eigenthings import compute_hessian_eigenthings

model = ResNet18()
dataloader = ...
loss = torch.nn.functional.cross_entropy

num_eigenthings = 20  # compute top 20 eigenvalues/eigenvectors

eigenvals, eigenvecs = compute_hessian_eigenthings(model, dataloader,
                                                   loss, num_eigenthings)
```

The library also includes a more general implementation of power iteration with deflation in `power_iter.py`. `lanczos.py` calls a [`scipy` hook](https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.linalg.eigsh.html) to a battle-tested ARPACK implementation.

## Example file

The example file in `example/main.py` utilizes [`skeletor`](https://github.com/noahgolmant/skeletor) version `0.1.4` for experiment orchestration, which can be installed via `pip install skeletor-ml`, but the rest of this library does not depend on it. You can execute the example via a command like `python example/main.py --mode=power_iter <experimentname>`, where `<experimentname>` is a useful name like `resnet18_cifar10`. But it may just be easier to use a simpler codebase to instantiate PyTorch models and dataloaders (such as [`pytorch-cifar`](https://github.com/kuangliu/pytorch-cifar)), as in the sketch below.
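For a quick start without `skeletor`, here is a minimal, self-contained sketch along those lines (not part of this repo): it assumes `torchvision` is installed, and the dataset, model, and settings are illustrative placeholders.

```python
import torch
import torchvision
import torchvision.transforms as transforms

from hessian_eigenthings import compute_hessian_eigenthings

# Any dataloader of (input, target) pairs works; CIFAR-10 is just an example.
dataset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transforms.ToTensor()
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128)

model = torchvision.models.resnet18(num_classes=10)
loss = torch.nn.functional.cross_entropy

# Slow on CPU; pass use_gpu=True if CUDA is available.
eigenvals, eigenvecs = compute_hessian_eigenthings(
    model, dataloader, loss, num_eigenthings=5, mode="lanczos", use_gpu=False
)
print(eigenvals)
```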
## Citing this work

If you find this repo useful and would like to cite it in a publication (as [others](https://scholar.google.com/scholar?oi=bibs&hl=en&cites=18039594054930134223) have done, thank you!), here is a BibTeX entry:

    @misc{hessian-eigenthings,
      author = {Noah Golmant and Zhewei Yao and Amir Gholami and Michael Mahoney and Joseph Gonzalez},
      title = {pytorch-hessian-eigenthings: efficient PyTorch Hessian eigendecomposition},
      month = oct,
      year = 2018,
      version = {1.0},
      url = {https://github.com/noahgolmant/pytorch-hessian-eigenthings}
    }

## Acknowledgements

This code was written in collaboration with Zhewei Yao, Amir Gholami, Michael Mahoney, and Joseph Gonzalez in UC Berkeley's [RISELab](https://rise.cs.berkeley.edu).

The deflated power iteration routine is based on code in the [HessianFlow](https://github.com/amirgholami/HessianFlow) repository recently described in the following paper: Z. Yao, A. Gholami, Q. Lei, K. Keutzer, M. Mahoney. "Hessian-based Analysis of Large Batch Training and Robustness to Adversaries", *NIPS'18* ([arXiv:1802.08241](https://arxiv.org/abs/1802.08241))

Stochastic power iteration with acceleration is based on the following paper: C. De Sa, B. He, I. Mitliagkas, C. Ré, P. Xu. "Accelerated Stochastic Power Iteration", *PMLR-21* ([arXiv:1707.02670](https://arxiv.org/abs/1707.02670))

--------------------------------------------------------------------------------
/example/main.py:
--------------------------------------------------------------------------------
"""
A simple example that calculates the top eigenvectors of the Hessian of a
ResNet18 network on CIFAR-10.
"""

import track
import skeletor
from skeletor.datasets import build_dataset
from skeletor.models import build_model

import torch

from hessian_eigenthings import compute_hessian_eigenthings


def extra_args(parser):
    parser.add_argument(
        "--num_eigenthings",
        default=5,
        type=int,
        help="number of eigenvals/vecs to compute",
    )
    parser.add_argument(
        "--batch_size", default=128, type=int, help="train set batch size"
    )
    parser.add_argument(
        "--eval_batch_size", default=16, type=int, help="test set batch size"
    )
    parser.add_argument(
        "--momentum", default=0.0, type=float, help="power iteration momentum term"
    )
    parser.add_argument(
        "--num_steps", default=50, type=int, help="number of power iter steps"
    )
    parser.add_argument("--max_samples", default=2048, type=int)
    parser.add_argument("--cuda", action="store_true", help="if true, use CUDA/GPUs")
    parser.add_argument(
        "--full_dataset",
        action="store_true",
        help="if true, loop over all batches in set for each gradient step",
    )
    parser.add_argument("--fname", default="", type=str)
    parser.add_argument("--mode", type=str, choices=["power_iter", "lanczos"])


def main(args):
    trainloader, testloader = build_dataset(
        "cifar10",
        dataroot=args.dataroot,
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        num_workers=2,
    )
    if args.fname:
        print("Loading model from %s" % args.fname)
        model = torch.load(args.fname, map_location="cpu")
        if args.cuda:
            model = model.cuda()
    else:
        model = build_model("ResNet18", num_classes=10)
    criterion = torch.nn.CrossEntropyLoss()
    eigenvals, eigenvecs = compute_hessian_eigenthings(
        model,
        testloader,
        criterion,
        args.num_eigenthings,
        mode=args.mode,
        # power_iter_steps=args.num_steps,
        max_possible_gpu_samples=args.max_samples,
        # momentum=args.momentum,
        full_dataset=args.full_dataset,
        use_gpu=args.cuda,
    )
    print("Eigenvecs:")
    print(eigenvecs)
    print("Eigenvals:")
    print(eigenvals)
    # track.metric(iteration=0, eigenvals=eigenvals)


if __name__ == "__main__":
    skeletor.supply_args(extra_args)
    skeletor.execute(main)

--------------------------------------------------------------------------------
/hessian_eigenthings/__init__.py:
--------------------------------------------------------------------------------
""" Top-level module for hessian eigenvec computation """
from hessian_eigenthings.power_iter import power_iteration, deflated_power_iteration
from hessian_eigenthings.lanczos import lanczos
from hessian_eigenthings.hvp_operator import HVPOperator

name = "hessian_eigenthings"


def compute_hessian_eigenthings(
    model,
    dataloader,
    loss,
    num_eigenthings=10,
    full_dataset=True,
    mode="power_iter",
    use_gpu=True,
    fp16=False,
    max_possible_gpu_samples=2 ** 16,
    **kwargs
):
    """
    Computes the top `num_eigenthings` eigenvalues and eigenvecs for the
    Hessian of the given model, using implicit Hessian-vector products and
    either deflated power iteration or the Lanczos method (chosen by `mode`).

    Parameters
    ---------------

    model : Module
        PyTorch model for this network
    dataloader : torch.utils.data.DataLoader
        dataloader with x,y pairs for which we compute the loss.
    loss : torch.nn.modules.Loss | torch.nn.functional criterion
        loss function to differentiate through
    num_eigenthings : int
        number of eigenvalues/eigenvecs to compute. computed in order of
        decreasing eigenvalue magnitude.
    full_dataset : boolean
        if true, each power iteration call evaluates the gradient over the
        whole dataset.
        (if False, you might want to check if the eigenvalue estimate variance
        depends on batch size)
    mode : str ['power_iter', 'lanczos']
        which backend algorithm to use to compute the top eigenvalues.
    use_gpu:
        if true, attempt to use cuda for all linear algebra computations
    fp16: bool
        if true, store and do math with eigenvectors, gradients, etc. in fp16.
        (you should test if this is numerically stable for your application)
    max_possible_gpu_samples:
        the maximum number of samples that can fit in memory at once. used
        to accumulate gradients for large batches.
        (note: if smaller than dataloader batch size, this can have odd
        interactions with batch norm statistics)
    **kwargs:
        contains additional parameters passed onto lanczos or power_iter.
    """
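    # Both backends only ever see the HVP operator constructed below as a
    # black-box matrix-vector product; nothing else touches the model.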
    hvp_operator = HVPOperator(
        model,
        dataloader,
        loss,
        use_gpu=use_gpu,
        full_dataset=full_dataset,
        max_possible_gpu_samples=max_possible_gpu_samples,
    )
    eigenvals, eigenvecs = None, None
    if mode == "power_iter":
        eigenvals, eigenvecs = deflated_power_iteration(
            hvp_operator, num_eigenthings, use_gpu=use_gpu, fp16=fp16, **kwargs
        )
    elif mode == "lanczos":
        eigenvals, eigenvecs = lanczos(
            hvp_operator, num_eigenthings, use_gpu=use_gpu, fp16=fp16, **kwargs
        )
    else:
        raise ValueError("Unsupported mode %s (must be power_iter or lanczos)" % mode)
    return eigenvals, eigenvecs


__all__ = [
    "power_iteration",
    "deflated_power_iteration",
    "lanczos",
    "HVPOperator",
    "compute_hessian_eigenthings",
]

--------------------------------------------------------------------------------
/hessian_eigenthings/hvp_operator.py:
--------------------------------------------------------------------------------
"""
This module defines a linear operator to compute the hessian-vector product
for a given pytorch model using subsampled data.
"""

from typing import Callable

import torch
import torch.nn as nn
import torch.utils.data as data

import hessian_eigenthings.utils as utils

from hessian_eigenthings.operator import Operator


class HVPOperator(Operator):
    """
    Use PyTorch autograd for Hessian-vector product calculation
    model: PyTorch network to compute hessian for
    dataloader: pytorch dataloader that we get examples from to compute grads
    loss: Loss function to descend (e.g. F.cross_entropy)
    use_gpu: use cuda or not
    max_possible_gpu_samples: max number of examples per batch using all GPUs.
    """

    def __init__(
        self,
        model: nn.Module,
        dataloader: data.DataLoader,
        criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        use_gpu: bool = True,
        fp16: bool = False,
        full_dataset: bool = True,
        max_possible_gpu_samples: int = 256,
    ):
        size = int(sum(p.numel() for p in model.parameters()))
        super(HVPOperator, self).__init__(size)
        self.grad_vec = torch.zeros(size)
        self.model = model
        if use_gpu:
            self.model = self.model.cuda()
        self.dataloader = dataloader
        # Keep a persistent iterator so repeated calls walk through the loader.
        self.dataloader_iter = iter(dataloader)
        self.criterion = criterion
        self.use_gpu = use_gpu
        self.fp16 = fp16
        self.full_dataset = full_dataset
        self.max_possible_gpu_samples = max_possible_gpu_samples

        if not hasattr(self.dataloader, '__len__') and self.full_dataset:
            raise ValueError("For full-dataset averaging, dataloader must have '__len__'")

    def apply(self, vec: torch.Tensor):
        """
        Returns H*vec where H is the hessian of the loss w.r.t.
        the vectorized model parameters
        """
        if self.full_dataset:
            return self._apply_full(vec)
        else:
            return self._apply_batch(vec)

    def _apply_batch(self, vec: torch.Tensor) -> torch.Tensor:
        """
        Computes the Hessian-vector product for a mini-batch from the dataset.
        """
        # compute original gradient, tracking computation graph
        self._zero_grad()
        grad_vec = self._prepare_grad()
        self._zero_grad()
        # take the second gradient
        # this is the derivative of <grad_vec, vec> where <,> is an inner product.
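        # autograd.grad with grad_outputs=vec differentiates the scalar
        # <grad_vec, vec> w.r.t. the parameters, which yields H @ vec without
        # ever materializing the (num_params x num_params) Hessian.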
        hessian_vec_prod_dict = torch.autograd.grad(
            grad_vec, self.model.parameters(), grad_outputs=vec, only_inputs=True
        )
        # concatenate the results over the different components of the network
        hessian_vec_prod = torch.cat([g.contiguous().view(-1) for g in hessian_vec_prod_dict])
        hessian_vec_prod = utils.maybe_fp16(hessian_vec_prod, self.fp16)
        return hessian_vec_prod

    def _apply_full(self, vec: torch.Tensor) -> torch.Tensor:
        """
        Computes the Hessian-vector product averaged over all batches in the dataset.
        """
        n = len(self.dataloader)
        hessian_vec_prod = None
        for _ in range(n):
            if hessian_vec_prod is not None:
                hessian_vec_prod += self._apply_batch(vec)
            else:
                hessian_vec_prod = self._apply_batch(vec)
        hessian_vec_prod = hessian_vec_prod / n
        return hessian_vec_prod

    def _zero_grad(self):
        """
        Zeros out the gradient info for each parameter in the model
        """
        for p in self.model.parameters():
            if p.grad is not None:
                p.grad.data.zero_()

    def _prepare_grad(self) -> torch.Tensor:
        """
        Compute gradient w.r.t loss over all parameters and vectorize
        """
        try:
            all_inputs, all_targets = next(self.dataloader_iter)
        except StopIteration:
            self.dataloader_iter = iter(self.dataloader)
            all_inputs, all_targets = next(self.dataloader_iter)

        num_chunks = max(1, len(all_inputs) // self.max_possible_gpu_samples)

        grad_vec = None

        # This will do the "gradient chunking trick" to create micro-batches
        # when the batch size is larger than what will fit in memory.
        # WARNING: this may interact poorly with batch normalization.

        input_microbatches = all_inputs.chunk(num_chunks)
        target_microbatches = all_targets.chunk(num_chunks)
        for input, target in zip(input_microbatches, target_microbatches):
            if self.use_gpu:
                input = input.cuda()
                target = target.cuda()

            output = self.model(input)
            loss = self.criterion(output, target)
            grad_dict = torch.autograd.grad(
                loss, self.model.parameters(), create_graph=True
            )
            if grad_vec is not None:
                grad_vec += torch.cat([g.contiguous().view(-1) for g in grad_dict])
            else:
                grad_vec = torch.cat([g.contiguous().view(-1) for g in grad_dict])
            grad_vec = utils.maybe_fp16(grad_vec, self.fp16)
        grad_vec /= num_chunks
        self.grad_vec = grad_vec
        return self.grad_vec

--------------------------------------------------------------------------------
/hessian_eigenthings/lanczos.py:
--------------------------------------------------------------------------------
""" Use scipy/ARPACK implicitly restarted lanczos to find top k eigenthings """
from typing import Tuple

import numpy as np
import torch
import scipy.sparse.linalg as linalg
from scipy.sparse.linalg import LinearOperator as ScipyLinearOperator
from warnings import warn

import hessian_eigenthings.utils as utils

from hessian_eigenthings.operator import Operator


def lanczos(
    operator: Operator,
    num_eigenthings: int = 10,
    which: str = "LM",
    max_steps: int = 20,
    tol: float = 1e-6,
    num_lanczos_vectors: int = None,
    init_vec: np.ndarray = None,
    use_gpu: bool = False,
    fp16: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Use the scipy.sparse.linalg.eigsh hook to the ARPACK Lanczos algorithm
    to find the top k eigenvalues/eigenvectors.

    Please see scipy documentation for details on specific parameters
    such as 'which'.

    Parameters
    -------------
    operator: operator.Operator
        linear operator to solve.
    num_eigenthings : int
        number of eigenvalue/eigenvector pairs to compute
    which : str ['LM', 'SM', 'LA', 'SA']
        L, S = largest, smallest. M, A = in magnitude, algebraic.
        SM = smallest in magnitude. LA = largest algebraic.
    max_steps : int
        maximum number of arnoldi updates
    tol : float
        relative accuracy of eigenvalues / stopping criterion
    num_lanczos_vectors : int
        number of lanczos vectors to compute. if None, defaults to
        min(2 * num_eigenthings, size - 1); more is better for stability.
    init_vec: np.ndarray
        if None, use a random vector. this is the init vec for arnoldi updates.
    use_gpu: bool
        if true, use cuda tensors.
    fp16: bool
        if true, keep operator input/output in fp16 instead of fp32.

    Returns
    ----------------
    eigenvalues : np.ndarray
        array containing `num_eigenthings` eigenvalues of the operator
    eigenvectors : np.ndarray
        array containing `num_eigenthings` eigenvectors of the operator
    """
    if isinstance(operator.size, int):
        size = operator.size
    else:
        size = operator.size[0]
    shape = (size, size)

    if num_lanczos_vectors is None:
        num_lanczos_vectors = min(2 * num_eigenthings, size - 1)
    if num_lanczos_vectors < 2 * num_eigenthings:
        warn(
            "[lanczos] number of lanczos vectors should usually be > 2*num_eigenthings"
        )

    def _scipy_apply(x):
        x = torch.from_numpy(x)
        x = utils.maybe_fp16(x, fp16)
        if use_gpu:
            x = x.cuda()
        out = operator.apply(x)
        out = utils.maybe_fp16(out, fp16)
        out = out.cpu().numpy()
        return out

    scipy_op = ScipyLinearOperator(shape, _scipy_apply)
    if init_vec is None:
        init_vec = np.random.rand(size)

    eigenvals, eigenvecs = linalg.eigsh(
        A=scipy_op,
        k=num_eigenthings,
        which=which,
        maxiter=max_steps,
        tol=tol,
        ncv=num_lanczos_vectors,
        v0=init_vec,
        return_eigenvectors=True,
    )
    return eigenvals, eigenvecs.T
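
# Usage sketch (illustrative only, not part of the library's API surface): any
# square linear map can be wrapped in a LambdaOperator (see operator.py) and
# handed to lanczos(), e.g. a dense symmetric matrix:
#
#   import torch
#   from hessian_eigenthings.operator import LambdaOperator
#   A = torch.randn(50, 50)
#   A = (A + A.t()) / 2  # symmetrize; eigsh assumes a symmetric operator
#   op = LambdaOperator(lambda v: A @ v, size=50)
#   vals, vecs = lanczos(op, num_eigenthings=5)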
--------------------------------------------------------------------------------
/hessian_eigenthings/operator.py:
--------------------------------------------------------------------------------
""" Basic linear operator abstractions """


class Operator:
    """
    maps x -> Lx for a linear operator L
    """

    def __init__(self, size):
        self.size = size

    def apply(self, vec):
        """
        Function mapping vec -> L vec where L is a linear operator
        """
        raise NotImplementedError


class LambdaOperator(Operator):
    """
    Linear operator based on a provided lambda function
    """

    def __init__(self, apply_fn, size):
        super(LambdaOperator, self).__init__(size)
        self.apply_fn = apply_fn

    def apply(self, x):
        return self.apply_fn(x)

--------------------------------------------------------------------------------
/hessian_eigenthings/power_iter.py:
--------------------------------------------------------------------------------
"""
This module contains functions to perform power iteration with deflation
to compute the top eigenvalues and eigenvectors of a linear operator
"""
from typing import Tuple

import numpy as np
import torch

from hessian_eigenthings.operator import Operator, LambdaOperator
import hessian_eigenthings.utils as utils


def deflated_power_iteration(
    operator: Operator,
    num_eigenthings: int = 10,
    power_iter_steps: int = 20,
    power_iter_err_threshold: float = 1e-4,
    momentum: float = 0.0,
    use_gpu: bool = True,
    fp16: bool = False,
    to_numpy: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute top k eigenvalues by repeatedly subtracting out dyads
    operator: linear operator that gives us access to matrix vector product
    num_eigenthings: number of eigenvalues to compute
    power_iter_steps: number of steps per run of power iteration
    power_iter_err_threshold: early stopping threshold for power iteration
    momentum: acceleration term for the underlying power iteration
    returns: np.ndarray of top eigenvalues, np.ndarray of top eigenvectors
    """
    eigenvals = []
    eigenvecs = []
    current_op = operator
    prev_vec = None

    def _deflate(x, val, vec):
        return val * vec.dot(x) * vec

    utils.log("beginning deflated power iteration")
    for i in range(num_eigenthings):
        utils.log("computing eigenvalue/vector %d of %d" % (i + 1, num_eigenthings))
        eigenval, eigenvec = power_iteration(
            current_op,
            power_iter_steps,
            power_iter_err_threshold,
            momentum=momentum,
            use_gpu=use_gpu,
            fp16=fp16,
            init_vec=prev_vec,
        )
        utils.log("eigenvalue %d: %.4f" % (i + 1, eigenval))

        def _new_op_fn(x, op=current_op, val=eigenval, vec=eigenvec):
            return utils.maybe_fp16(op.apply(x), fp16) - _deflate(x, val, vec)

        current_op = LambdaOperator(_new_op_fn, operator.size)
        prev_vec = eigenvec
        eigenvals.append(eigenval)
        eigenvec = eigenvec.cpu()
        if to_numpy:
            # Clone so that power_iteration can continue to use torch.
            numpy_eigenvec = eigenvec.detach().clone().numpy()
            eigenvecs.append(numpy_eigenvec)
        else:
            eigenvecs.append(eigenvec)

    eigenvals = np.array(eigenvals)
    eigenvecs = np.array(eigenvecs)

    # sort them in descending order
    sorted_inds = np.argsort(eigenvals)
    eigenvals = eigenvals[sorted_inds][::-1]
    eigenvecs = eigenvecs[sorted_inds][::-1]
    return eigenvals, eigenvecs


def power_iteration(
    operator: Operator,
    steps: int = 20,
    error_threshold: float = 1e-4,
    momentum: float = 0.0,
    use_gpu: bool = True,
    fp16: bool = False,
    init_vec: torch.Tensor = None,
) -> Tuple[float, torch.Tensor]:
    """
    Compute dominant eigenvalue/eigenvector of a matrix
    operator: linear Operator giving us matrix-vector product access
    steps: number of update steps to take
    returns: (principal eigenvalue, principal eigenvector) pair
    """
    vector_size = operator.size  # input dimension of operator
    if init_vec is None:
        vec = torch.rand(vector_size)
    else:
        vec = init_vec

    vec = utils.maybe_fp16(vec, fp16)

    if use_gpu:
        vec = vec.cuda()

    prev_lambda = 0.0
    prev_vec = utils.maybe_fp16(torch.randn_like(vec), fp16)
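    # Each step applies the operator to the current (normalized) iterate and
    # subtracts a momentum term against the previous iterate; momentum = 0
    # recovers vanilla power iteration, and nonzero momentum follows the
    # accelerated scheme of De Sa et al. (arXiv:1707.02670).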
    for i in range(steps):
        prev_vec = vec / (torch.norm(vec) + 1e-6)
        new_vec = utils.maybe_fp16(operator.apply(prev_vec), fp16) - momentum * prev_vec
        # need to handle case where we end up in the nullspace of the operator.
        # in this case, we are done.
        if torch.norm(new_vec).item() == 0.0:
            return 0.0, new_vec
        lambda_estimate = prev_vec.dot(new_vec).item()
        diff = lambda_estimate - prev_lambda
        vec = new_vec.detach() / torch.norm(new_vec)
        if lambda_estimate == 0.0:  # for low-rank
            error = 1.0
        else:
            error = np.abs(diff / lambda_estimate)
        utils.progress_bar(i, steps, "power iter error: %.4f" % error)
        if error < error_threshold:
            break
        prev_lambda = lambda_estimate
    return lambda_estimate, vec

--------------------------------------------------------------------------------
/hessian_eigenthings/utils.py:
--------------------------------------------------------------------------------
""" small helpers """
import logging
import shutil
import sys
import time

TOTAL_BAR_LENGTH = 65.0

term_width = shutil.get_terminal_size().columns


def log(msg):
    logging.info("[hessian_eigenthings] " + str(msg))


def maybe_fp16(vec, fp16):
    return vec.half() if fp16 else vec.float()
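

# Note: when fp16 is False this *also* casts to fp32, so double-precision
# inputs are silently downcast to single precision.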

last_time = time.time()
begin_time = last_time


def format_time(seconds):
    """ converts seconds into day-hour-minute-second-ms string format """
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds * 1000)

    formatted = ""
    i = 1
    if days > 0:
        formatted += str(days) + "D"
        i += 1
    if hours > 0 and i <= 2:
        formatted += str(hours) + "h"
        i += 1
    if minutes > 0 and i <= 2:
        formatted += str(minutes) + "m"
        i += 1
    if secondsf > 0 and i <= 2:
        formatted += str(secondsf) + "s"
        i += 1
    if millis > 0 and i <= 2:
        formatted += str(millis) + "ms"
        i += 1
    if formatted == "":
        formatted = "0ms"
    return formatted


def progress_bar(current, total, msg=None):
    """handy utility to display an updating progress bar...
    percentage completed is computed as current/total

    from: https://github.com/noahgolmant/skeletor/blob/master/skeletor/utils.py
    """
    global last_time, begin_time  # pylint: disable=global-statement
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(" [")
    for _ in range(cur_len):
        sys.stdout.write("=")
    sys.stdout.write(">")
    for _ in range(rest_len):
        sys.stdout.write(".")
    sys.stdout.write("]")

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    msg_parts = []
    msg_parts.append(" Step: %s" % format_time(step_time))
    msg_parts.append(" | Tot: %s" % format_time(tot_time))
    if msg:
        msg_parts.append(" | " + msg)

    msg = "".join(msg_parts)
    sys.stdout.write(msg)
    for _ in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3):
        sys.stdout.write(" ")

    # Go back to the center of the bar.
    for _ in range(term_width - int(TOTAL_BAR_LENGTH / 2) + 2):
        sys.stdout.write("\b")
    sys.stdout.write(" %d/%d " % (current + 1, total))

    if current < total - 1:
        sys.stdout.write("\r")
    else:
        sys.stdout.write("\n")
    sys.stdout.flush()

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""setup.py for hessian_eigenthings"""

from setuptools import setup, find_packages

install_requires = [
    'numpy>=0.14',
    'torch>=0.4',
    'scipy>=1.2.1'
]

setup(name="hessian_eigenthings",
      author="Noah Golmant",
      install_requires=install_requires,
      packages=find_packages(),
      description='Eigendecomposition of model Hessians in PyTorch!',
      version='0.0.2')

--------------------------------------------------------------------------------
/tests/principle_eigenvec_tests.py:
--------------------------------------------------------------------------------
import argparse
import numpy as np
import torch
from hessian_eigenthings import compute_hessian_eigenthings
from utils import plot_eigenval_estimates, plot_eigenvec_errors

from torch.utils.data import DataLoader
from torch import nn
import matplotlib.pyplot as plt
from variance_tests import get_full_hessian

import scipy.linalg


def test_principal_eigenvec(model, criterion, x, y, ntrials, fp16):
    loss = criterion(model(x), y)
    loss_grad = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    print("computing real hessian")
    real_hessian = get_full_hessian(loss_grad, model)
    # add a small diagonal regularizer to keep the Hessian well-conditioned
    real_hessian += 1e-4 * np.eye(len(real_hessian))

    samples = [(x_i, y_i) for x_i, y_i in zip(x, y)]
    # full dataset
    dataloader = DataLoader(samples, batch_size=len(x))

    print("computing numpy principal eigenvec of hessian")
    num_params = len(real_hessian)
    real_eigenvals, real_eigenvecs = scipy.linalg.eigh(
        real_hessian, eigvals=(num_params - 1, num_params - 1)
    )
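    # eigvals=(n - 1, n - 1) asks scipy.linalg.eigh for only the largest
    # eigenpair; newer SciPy releases spell this keyword subset_by_index.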
    real_eigenval, real_eigenvec = real_eigenvals[0], real_eigenvecs[:, 0]

    eigenvals = []
    eigenvecs = []

    nparams = len(real_hessian)

    # for _ in range(ntrials):
    est_eigenvals, est_eigenvecs = compute_hessian_eigenthings(
        model,
        dataloader,
        criterion,
        num_eigenthings=1,
        power_iter_steps=10,
        power_iter_err_threshold=1e-5,
        momentum=0,
        use_gpu=False,
        fp16=fp16
    )
    est_eigenval, est_eigenvec = est_eigenvals[0], est_eigenvecs[0]

    # compute cosine similarity
    print(real_eigenvec, est_eigenvec)

    dotted = np.dot(real_eigenvec, est_eigenvec)
    if dotted == 0.0:
        score = 1.0  # both in nullspace... nice...
    else:
        norm = scipy.linalg.norm(real_eigenvec) * scipy.linalg.norm(est_eigenvec)
        score = abs(dotted / norm)
    print(score)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='power iteration tester')

    parser.add_argument('--data_dim', type=int, default=100)
    parser.add_argument('--hidden_dim', type=int, default=1000)
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--mode', default='power_iter',
                        choices=['power_iter', 'lanczos'])
    args = parser.parse_args()

    indim = outdim = args.data_dim
    hidden = args.hidden_dim
    nsamples = 10
    ntrials = 1
    bs = 10

    model = nn.Sequential(
        nn.Linear(indim, hidden),
        nn.ReLU(inplace=True),
        nn.Linear(hidden, outdim),
        nn.ReLU(inplace=True),
    )
    criterion = torch.nn.MSELoss()

    x = torch.rand((nsamples, indim))
    y = torch.rand((nsamples, outdim))

    test_principal_eigenvec(model, criterion, x, y, ntrials, fp16=args.fp16)

--------------------------------------------------------------------------------
/tests/random_matrix_tests.py:
--------------------------------------------------------------------------------
"""
This file tests the accuracy of the power iteration methods by comparing
against np.linalg.eig results for various random matrix configurations
"""

import argparse
import functools
import numpy as np
import torch
from hessian_eigenthings.operator import LambdaOperator
from hessian_eigenthings.power_iter import deflated_power_iteration
from hessian_eigenthings.lanczos import lanczos
import matplotlib.pyplot as plt
from utils import plot_eigenval_estimates, plot_eigenvec_errors


parser = argparse.ArgumentParser(description='power iteration tester')

parser.add_argument('--matrix_dim', type=int, default=100,
                    help='number of rows/columns in matrix')
parser.add_argument('--num_eigenthings', type=int, default=10,
                    help='number of eigenvalues to compute')
parser.add_argument('--power_iter_steps', default=20, type=int,
                    help='number of steps of power iteration')
parser.add_argument('--momentum', default=0, type=float,
                    help='acceleration term for stochastic power iter')
parser.add_argument('--num_trials', default=30, type=int,
                    help='number of matrices per test')
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--mode', default='power_iter',
                    choices=['power_iter', 'lanczos'])
args = parser.parse_args()


def test_matrix(mat, ntrials, mode):
    """
    Tests the accuracy of deflated power iteration on the given matrix.
    It computes the average percent eigenvalue error and the eigenvector
    similarity error.
    """
    tensor = torch.from_numpy(mat).float()

    # for non-gpu tests, addmv not implemented for fp16 on CPU. have to do float.
    op = LambdaOperator(lambda x: torch.matmul(tensor, x.float()), tensor.size()[:1])
    real_eigenvals, true_eigenvecs = np.linalg.eig(mat)
    real_eigenvecs = [true_eigenvecs[:, i] for i in range(len(real_eigenvals))]

    eigenvals = []
    eigenvecs = []

    if mode == 'lanczos':
        method = lanczos
    else:
        method = functools.partial(deflated_power_iteration,
                                   power_iter_steps=args.power_iter_steps,
                                   momentum=args.momentum)

    for _ in range(ntrials):
        est_eigenvals, est_eigenvecs = method(
            op,
            num_eigenthings=args.num_eigenthings,
            use_gpu=False,
            fp16=args.fp16
        )
        est_inds = np.argsort(est_eigenvals)
        est_eigenvals = np.array(est_eigenvals)[est_inds][::-1]
        est_eigenvecs = np.array(est_eigenvecs)[est_inds][::-1]

        eigenvals.append(est_eigenvals)
        eigenvecs.append(est_eigenvecs)

    eigenvals = np.array(eigenvals)
    eigenvecs = np.array(eigenvecs)

    # truncate estimates
    real_inds = np.argsort(real_eigenvals)
    real_eigenvals = np.array(real_eigenvals)[real_inds][-args.num_eigenthings:][::-1]
    real_eigenvecs = np.array(real_eigenvecs)[real_inds][-args.num_eigenthings:][::-1]

    # Plot eigenvalue error
    plt.suptitle('Random Matrix Eigendecomposition Errors: %d trials' % ntrials)
    plt.subplot(1, 2, 1)
    plt.title('Eigenvalues')
    plt.plot(list(range(len(real_eigenvals))), real_eigenvals, label='True Eigenvals', linestyle='--', linewidth=5)
    plot_eigenval_estimates(eigenvals, label='Estimates')
    plt.legend()
    # Plot eigenvector cosine similarity
    plt.subplot(1, 2, 2)
    plt.title('Eigenvector cosine similarity')
    plot_eigenvec_errors(real_eigenvecs, eigenvecs, label='Estimates')
    plt.legend()
    plt.show()


def generate_wishart(n, offset=0.0):
    """
    Generates a wishart PSD matrix with n rows/cols.
    Adds offset * I for conditioning testing.
    """
    matrix = np.random.random(size=(n, n)).astype(float)
    matrix = matrix.transpose().dot(matrix)
    matrix = matrix + offset * np.eye(n)
    return (1./n) * matrix
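
# The result is symmetric positive semi-definite, so all eigenvalues are real
# and nonnegative (bounded below by offset / n after the 1/n scaling), which
# keeps the comparison against np.linalg.eig well-behaved.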

def test_wishart():
    m = generate_wishart(args.matrix_dim)
    test_matrix(m, args.num_trials, mode=args.mode)


if __name__ == '__main__':
    test_wishart()

--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt


def compute_eigenvec_cos_similarity(actual, estimated):
    scores = []
    for estimate in estimated:
        score = np.abs(np.dot(actual, estimate))
        scores.append(score)
    return scores


def plot_eigenval_estimates(estimates, label):
    """
    estimates = 2D array (num_trials x num_eigenvalues)

    x-axis = eigenvalue index
    y-axis = eigenvalue estimate
    """
    if len(estimates.shape) == 1:
        var = np.zeros_like(estimates)
    else:
        var = np.var(estimates, axis=0)
    y = np.mean(estimates, axis=0)
    x = list(range(len(y)))
    error = np.sqrt(var)
    plt.plot(x, y, label=label)
    plt.fill_between(x, y-error, y+error, alpha=.2)


def plot_eigenvec_errors(true, estimates, label):
    """
    plots the cosine similarity between each true eigenvector and its
    estimates (the absolute value removes the eigenvector sign ambiguity)
    estimates = (num_trials x num_eigenvalues x num_params)
    true = (num_eigenvalues x num_params)
    """
    diffs = []
    num_eigenvals = true.shape[0]
    for i in range(num_eigenvals):
        cur_estimates = estimates[:, i, :]
        cur_eigenvec = true[i]
        diff = compute_eigenvec_cos_similarity(cur_eigenvec, cur_estimates)
        diffs.append(diff)
    diffs = np.array(diffs).T
    var = np.var(diffs, axis=0)
    y = np.mean(diffs, axis=0)
    x = list(range(len(y)))

    error = np.sqrt(var)
    plt.plot(x, y, label=label)
    plt.fill_between(x, y-error, y+error, alpha=.2)

--------------------------------------------------------------------------------
/tests/variance_tests.py:
--------------------------------------------------------------------------------
"""
This test looks at the variance of eigenvalue/eigenvector estimates
(1) Full dataset should have deterministic results
(2) Compute variance of repeated trials and the effect of averaging, error
    relative to full dataset
(3) Compute variance of full power iteration on a fixed mini-batch (vs.
    varying the mini-batch at each step) compared to full dataset
"""

import numpy as np
import torch
from hessian_eigenthings import compute_hessian_eigenthings
from utils import plot_eigenval_estimates, plot_eigenvec_errors

from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from hessian_eigenthings.utils import progress_bar


def get_full_hessian(loss_grad, model):
    # from https://discuss.pytorch.org/t/compute-the-hessian-matrix-of-a-network/15270/3
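    # Builds the dense Hessian explicitly, one row per backward pass: row idx
    # is the gradient of the scalar gradient component g_vector[idx] w.r.t.
    # all parameters. This takes O(num_params) backward passes, so it is only
    # viable for the tiny models used in these tests.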
    cnt = 0
    loss_grad = list(loss_grad)
    for i, g in enumerate(loss_grad):
        progress_bar(
            i,
            len(loss_grad),
            "flattening to full gradient: %d of %d" % (i, len(loss_grad)),
        )
        g_vector = (
            g.contiguous().view(-1)
            if cnt == 0
            else torch.cat([g_vector, g.contiguous().view(-1)])
        )
        cnt = 1
    hessian_size = g_vector.size(0)
    hessian = torch.zeros(hessian_size, hessian_size)
    for idx in range(hessian_size):
        progress_bar(
            idx, hessian_size, "full hessian columns: %d of %d" % (idx, hessian_size)
        )
        grad2rd = torch.autograd.grad(
            g_vector[idx], model.parameters(), create_graph=True
        )
        cnt = 0
        for g in grad2rd:
            g2 = (
                g.contiguous().view(-1)
                if cnt == 0
                else torch.cat([g2, g.contiguous().view(-1)])
            )
            cnt = 1
        hessian[idx] = g2
    return hessian.cpu().data.numpy()


def test_full_hessian(model, criterion, x, y, ntrials=10):
    loss = criterion(model(x), y)
    loss_grad = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    real_hessian = get_full_hessian(loss_grad, model)

    samples = [(x_i, y_i) for x_i, y_i in zip(x, y)]
    # full dataset
    dataloader = DataLoader(samples, batch_size=len(x))

    eigenvals = []
    eigenvecs = []

    nparams = len(real_hessian)

    for _ in range(ntrials):
        est_eigenvals, est_eigenvecs = compute_hessian_eigenthings(
            model,
            dataloader,
            criterion,
            num_eigenthings=nparams,
            power_iter_steps=100,
            power_iter_err_threshold=1e-9,
            momentum=0.0,
            use_gpu=False,
        )
        est_inds = np.argsort(est_eigenvals)
        est_eigenvals = np.array(est_eigenvals)[est_inds][::-1]
        est_eigenvecs = np.array(est_eigenvecs)[est_inds][::-1]

        eigenvals.append(est_eigenvals)
        eigenvecs.append(est_eigenvecs)

    eigenvals = np.array(eigenvals)
    eigenvecs = np.array(eigenvecs)

    real_eigenvals, real_eigenvecs = np.linalg.eig(real_hessian)
    real_inds = np.argsort(real_eigenvals)
    real_eigenvals = np.array(real_eigenvals)[real_inds][::-1]
    # eigenvectors come back as columns; transpose so rows are eigenvectors
    real_eigenvecs = np.array(real_eigenvecs).T[real_inds][::-1]

    # Plot eigenvalue error
    plt.suptitle("Hessian eigendecomposition errors: %d trials" % ntrials)
    plt.subplot(1, 2, 1)
    plt.title("Eigenvalues")
    plt.plot(list(range(nparams)), real_eigenvals, label="True Eigenvals", linewidth=3, linestyle='--')
    plot_eigenval_estimates(eigenvals, label="Estimates")
    plt.legend()
    # Plot eigenvector cosine similarity
    plt.subplot(1, 2, 2)
    plt.title("Eigenvector cosine similarity")
    plot_eigenvec_errors(real_eigenvecs, eigenvecs, label="Estimates")
    plt.legend()
    plt.savefig("full.png")
    plt.clf()
    return real_hessian


def test_stochastic_hessian(model, criterion, real_hessian, x, y, bs=10, ntrials=10):
    samples = [(x_i, y_i) for x_i, y_i in zip(x, y)]
    # mini-batches of size bs
    dataloader = DataLoader(samples, batch_size=bs)

    eigenvals = []
    eigenvecs = []

    nparams = len(real_hessian)

    for _ in range(ntrials):
        est_eigenvals, est_eigenvecs = compute_hessian_eigenthings(
            model,
            dataloader,
            criterion,
            num_eigenthings=nparams,
            power_iter_steps=100,
            power_iter_err_threshold=1e-9,
            momentum=0,
            use_gpu=False,
        )

        est_inds = np.argsort(est_eigenvals)
        est_eigenvals = np.array(est_eigenvals)[est_inds][::-1]
        est_eigenvecs = np.array(est_eigenvecs)[est_inds][::-1]

        eigenvals.append(est_eigenvals)
        eigenvecs.append(est_eigenvecs)

    eigenvals = np.array(eigenvals)
    eigenvecs = np.array(eigenvecs)

    real_eigenvals, real_eigenvecs = np.linalg.eig(real_hessian)
    real_inds = np.argsort(real_eigenvals)
    real_eigenvals = np.array(real_eigenvals)[real_inds][::-1]
    # eigenvectors come back as columns; transpose so rows are eigenvectors
    real_eigenvecs = np.array(real_eigenvecs).T[real_inds][::-1]

    # Plot eigenvalue error
    plt.suptitle("Stochastic Hessian eigendecomposition errors: %d trials" % ntrials)
    plt.subplot(1, 2, 1)
    plt.title("Eigenvalues")
    plt.plot(list(range(nparams)), real_eigenvals, label="True Eigenvals", linewidth=3, linestyle='--')
    plot_eigenval_estimates(eigenvals, label="Estimates")
    plt.legend()
    # Plot eigenvector cosine similarity
    plt.subplot(1, 2, 2)
    plt.title("Eigenvector cosine similarity")
    plot_eigenvec_errors(real_eigenvecs, eigenvecs, label="Estimates")
    plt.legend()
    plt.savefig("stochastic.png")
    plt.clf()


def test_fixed_mini(model, criterion, real_hessian, x, y, bs=10, ntrials=10):
    x = x[:bs]
    y = y[:bs]

    samples = [(x_i, y_i) for x_i, y_i in zip(x, y)]
    # single fixed mini-batch
    dataloader = DataLoader(samples, batch_size=len(x))

    eigenvals = []
    eigenvecs = []

    nparams = len(real_hessian)

    for _ in range(ntrials):
        # note: the power-iteration kwargs are dropped here since lanczos()
        # does not accept them
        est_eigenvals, est_eigenvecs = compute_hessian_eigenthings(
            model,
            dataloader,
            criterion,
            num_eigenthings=nparams,
            mode="lanczos",
            use_gpu=False,
        )
        est_eigenvals = np.array(est_eigenvals)
        est_eigenvecs = np.array(est_eigenvecs)

        est_inds = np.argsort(est_eigenvals)
        est_eigenvals = np.array(est_eigenvals)[est_inds][::-1]
        est_eigenvecs = np.array(est_eigenvecs)[est_inds][::-1]

        eigenvals.append(est_eigenvals)
        eigenvecs.append(est_eigenvecs)

    eigenvals = np.array(eigenvals)
    eigenvecs = np.array(eigenvecs)

    real_eigenvals, real_eigenvecs = np.linalg.eig(real_hessian)
    real_inds = np.argsort(real_eigenvals)
    real_eigenvals = np.array(real_eigenvals)[real_inds][::-1]
    # eigenvectors come back as columns; transpose so rows are eigenvectors
    real_eigenvecs = np.array(real_eigenvecs).T[real_inds][::-1]

    # Plot eigenvalue error
    plt.suptitle(
        "Fixed mini-batch Hessian eigendecomposition errors: %d trials" % ntrials
    )
    plt.subplot(1, 2, 1)
    plt.title("Eigenvalues")
    plt.plot(list(range(nparams)), real_eigenvals, label="True Eigenvals")
    plot_eigenval_estimates(eigenvals, label="Estimates")
    plt.legend()
    # Plot eigenvector cosine similarity
    plt.subplot(1, 2, 2)
    plt.title("Eigenvector cosine similarity")
    plot_eigenvec_errors(real_eigenvecs, eigenvecs, label="Estimates")
    plt.legend()
    plt.savefig("fixed.png")


if __name__ == "__main__":
    indim = 100
    outdim = 1
    nsamples = 10
    ntrials = 1
    bs = 10

    model = torch.nn.Linear(indim, outdim)
    criterion = torch.nn.MSELoss()

    x = torch.rand((nsamples, indim))
    y = torch.rand((nsamples, outdim))

    hessian = test_full_hessian(model, criterion, x, y, ntrials=ntrials)
    test_stochastic_hessian(model, criterion, hessian, x, y, bs=bs, ntrials=ntrials)
    # test_fixed_mini(model, criterion, hessian, x, y, bs=bs, ntrials=ntrials)

--------------------------------------------------------------------------------