├── torchtrustncg
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── utils.cpython-39.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   └── trust_region.cpython-39.pyc
│   ├── utils.py
│   └── trust_region.py
├── requirements.txt
├── examples
│   ├── SoftMaxRegression
│   │   ├── Results
│   │   │   ├── cg.png
│   │   │   └── krylov.png
│   │   └── main.py
│   └── ToyExamples
│       └── main.py
├── LICENSE
├── README.md
├── setup.py
└── .gitignore

--------------------------------------------------------------------------------
/torchtrustncg/__init__.py:
--------------------------------------------------------------------------------
from .trust_region import TrustRegion

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch>=1.9.0
numpy>=1.17.0

setuptools~=57.0.0
matplotlib~=3.3.4
torchvision~=0.11.2

--------------------------------------------------------------------------------
/examples/SoftMaxRegression/Results/cg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vchoutas/torch-trust-ncg/HEAD/examples/SoftMaxRegression/Results/cg.png

--------------------------------------------------------------------------------
/examples/SoftMaxRegression/Results/krylov.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vchoutas/torch-trust-ncg/HEAD/examples/SoftMaxRegression/Results/krylov.png

--------------------------------------------------------------------------------
/torchtrustncg/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vchoutas/torch-trust-ncg/HEAD/torchtrustncg/__pycache__/utils.cpython-39.pyc

--------------------------------------------------------------------------------
/torchtrustncg/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vchoutas/torch-trust-ncg/HEAD/torchtrustncg/__pycache__/__init__.cpython-39.pyc

--------------------------------------------------------------------------------
/torchtrustncg/__pycache__/trust_region.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vchoutas/torch-trust-ncg/HEAD/torchtrustncg/__pycache__/trust_region.cpython-39.pyc

--------------------------------------------------------------------------------
/torchtrustncg/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

# @author Vasileios Choutas
# Contact: vassilis.choutas@tuebingen.mpg.de

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch


def rosenbrock(tensor, alpha=1.0, beta=100):
    # 2-D Rosenbrock test function; its global minimum lies at (alpha, alpha ** 2)
    x, y = tensor[..., 0], tensor[..., 1]
    return (alpha - x) ** 2 + beta * (y - x ** 2) ** 2


def branin(tensor, **kwargs):
    # 2-D Branin-style test surface
    x, y = tensor[..., 0], tensor[..., 1]
    loss = ((y - 0.129 * x ** 2 + 1.6 * x - 6) ** 2 + 6.07 * torch.cos(x) + 10)
    return loss

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Vassilis Choutas

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Torch-TrustNCG

PyTorch implementation of a Trust Region Newton method with Conjugate Gradient
and Generalized Lanczos (Krylov) solvers for the trust-region subproblem.

## Installation

To install the package, follow these steps in the specified order:
1. Clone this repository:
```Shell
git clone https://github.com/vchoutas/torch-trust-ncg.git
```
2. If you do not wish to modify the optimizer, run:
```Shell
python setup.py install
```
3. If you want to be able to modify the optimizer, install it in development
mode:
```Shell
python setup.py build develop
```

## Usage

To create the optimizer simply run:
```Python
from torchtrustncg import TrustRegion

optimizer = TrustRegion(parameter_list)
```
where *parameter_list* is the list of parameters you wish to optimize. To
perform one optimization step, call the step function and pass a closure that
computes the loss and the gradients. Note that the closure must accept a
boolean argument named *backward*, so that the optimizer can avoid unnecessary
backward passes.

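A minimal optimization loop then looks as follows. This sketch is adapted from
*examples/ToyExamples/main.py*; `compute_loss` and `num_iters` stand in for
your own loss computation and iteration budget:
```Python
def closure(backward=True):
    if backward:
        optimizer.zero_grad()
    loss = compute_loss()
    if backward:
        # create_graph=True keeps the gradient differentiable, so the
        # optimizer can compute the Hessian-vector products it needs
        loss.backward(create_graph=True)
    return loss

for n in range(num_iters):
    loss = optimizer.step(closure)
```
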
For a simple example see *examples/ToyExamples/main.py*. To run it for the
Rosenbrock function execute the following command:
```Shell
python examples/ToyExamples/main.py --function rosenbrock
```

## Citation

For more details see Section 7.2 of "Numerical Optimization" by Nocedal and
Wright:

```
@Book{NoceWrig06,
  Title = {Numerical Optimization},
  Author = {Jorge Nocedal and Stephen J. Wright},
  Publisher = {Springer},
  Year = {2006},
  Address = {New York, NY, USA},
  Edition = {second}
}
```

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

# @author Vasileios Choutas
# Contact: vassilis.choutas@tuebingen.mpg.de


import io
import os

from setuptools import find_packages, setup

# Package meta-data.
NAME = 'torch_trust_ncg'
DESCRIPTION = 'PyTorch Trust Region Newton Conjugate Gradient method'
URL = ''
EMAIL = 'vassilis.choutas@tuebingen.mpg.de'
AUTHOR = 'Vassilis Choutas'
REQUIRES_PYTHON = '>=3.6.0'
VERSION = '0.1.0'

here = os.path.abspath(os.path.dirname(__file__))

try:
    FileNotFoundError
except NameError:
    FileNotFoundError = IOError

# Import the README and use it as the long-description.
# Note: this will only work if 'README.md' is present in your MANIFEST.in file!
try:
    with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
        long_description = '\n' + f.read()
except FileNotFoundError:
    long_description = DESCRIPTION

# Load the package's __version__.py module as a dictionary.
about = {}
if not VERSION:
    with open(os.path.join(here, NAME, '__version__.py')) as f:
        exec(f.read(), about)
else:
    about['__version__'] = VERSION

setup(name=NAME,
      version=about['__version__'],
      description=DESCRIPTION,
      long_description=long_description,
      long_description_content_type='text/markdown',
      author=AUTHOR,
      author_email=EMAIL,
      python_requires=REQUIRES_PYTHON,
      url=URL,
      packages=find_packages(),
      classifiers=[
          "License :: OSI Approved :: MIT License",
          "Environment :: Console",
          "Programming Language :: Python",
          "Programming Language :: Python :: 3.6",
          "Programming Language :: Python :: 3.7"],
      install_requires=[
          'torch>=1.9.0',
      ],
      )

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
#### joe made this: http://goel.io/joe

#### python ####
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

--------------------------------------------------------------------------------
/examples/ToyExamples/main.py:
--------------------------------------------------------------------------------
import torch
from torchtrustncg import TrustRegion
from torchtrustncg.utils import rosenbrock, branin

import numpy as np
import matplotlib.pyplot as plt

if __name__ == "__main__":
    import sys
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--function', type=str, default='rosenbrock',
                        choices=['rosenbrock', 'branin'],
                        help='Test function to optimize')
    parser.add_argument('--num-iters', dest='num_iters', default=200, type=int,
                        help='Number of iterations to use')
    parser.add_argument('--gtol', default=1e-4, type=float,
                        help='Gradient tolerance')
    parser.add_argument('--init-point', nargs=2, default=[2.5, 2.5],
                        type=float, dest='init_point',
                        help='Point to initialize the optimization')

    parser.add_argument('--plot',
                        type=lambda x: x.lower() in ['true', '1'],
                        help='Plot the function trajectory')

    args = parser.parse_args()
    function = args.function
    num_iters = args.num_iters
    gtol = args.gtol
    init_point = args.init_point
    plot = args.plot

    if function == 'rosenbrock':
        loss_func = rosenbrock
    elif function == 'branin':
        loss_func = branin
    else:
        print(f'Unknown loss function {function}')
        sys.exit(-1)

    variable = torch.empty([1, 2], requires_grad=True)
    with torch.no_grad():
        variable[0, 0] = init_point[0]
        variable[0, 1] = init_point[1]

    optimizer = TrustRegion([variable], opt_method='krylov')

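    # The optimizer calls this closure more than once per step: with
    # backward=True when it needs the loss and a differentiable gradient, and
    # with backward=False when it only needs the loss value, e.g. when
    # measuring the improvement ratio of a trial step.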
    def closure(backward=True):
        if backward:
            optimizer.zero_grad()
        loss = loss_func(variable)
        if backward:
            loss.backward(create_graph=True)
        return loss

    points = []
    values = []

    for n in range(num_iters):
        loss = optimizer.step(closure)

        np_var = variable.detach().cpu().numpy().squeeze()

        if plot:
            points.append(np_var.copy())
            values.append(loss.item())

        if torch.norm(variable.grad).item() < gtol:
            break
        if torch.norm(optimizer.param_step, dim=-1).lt(gtol).all():
            break

        print(
            f'[{n:04d}]: ' +
            f'Loss at ({variable[0, 0]:.4f}, {variable[0, 1]:.4f}) = ' +
            f'{loss.item():.4f}')

    if plot:
        N = 100
        X, Y = np.meshgrid(
            np.linspace(-4 + init_point[0], init_point[0] + 4, N),
            np.linspace(-4 + init_point[1], init_point[1] + 4, N)
        )

        grid_points = np.stack([X, Y], axis=-1).reshape(-1, 2)
        Z = loss_func(torch.from_numpy(grid_points)).detach().numpy()

        cs = plt.contourf(X, Y, Z.reshape(N, N))

        points = np.stack(points, axis=0)

        plt.plot(points[:, 0], points[:, 1], 'x-', color='black')
        plt.show()

--------------------------------------------------------------------------------
/examples/SoftMaxRegression/main.py:
--------------------------------------------------------------------------------
import torch

import matplotlib.pyplot as plt

import torchvision.datasets as dsets
import torchvision.transforms as transforms

from torchtrustncg import TrustRegion


class SoftMaxRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(SoftMaxRegression, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)

    def forward(self, x):
        outputs = self.linear(x)
        return outputs


def acc(model, data_loader, device):
    with torch.no_grad():
        correct = 0
        total = 0
        for samples, labels in data_loader:
            #######################
            #  USE GPU FOR MODEL  #
            #######################
            samples = samples.view(-1, 28 * 28).to(device)
            labels = labels.to(dtype=torch.float32).to(device)
            outputs = model(samples)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            # Total correct predictions
            if torch.cuda.is_available():
                correct += (predicted.cpu() == labels.cpu()).sum()
            else:
                correct += (predicted == labels).sum()

        accuracy = 100 * correct / total
        return accuracy


def main():
    #######################
    #  USE GPU FOR MODEL  #
    #######################
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device = 'cpu'  # Fixed to CPU for this example
    # print('Device to run: {}'.format(device))

    model = SoftMaxRegression(input_dim, output_dim)
    model.to(device)

    # Computes the softmax and then the cross entropy
    criterion = torch.nn.CrossEntropyLoss()

    opt_method = 'krylov'
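    # Trust-region hyperparameters (see TrustRegion.__init__): the radius
    # starts at initial_trust_radius and is capped at max_trust_radius, eta is
    # the minimum improvement ratio for accepting a step, and kappa_easy,
    # max_newton_iter, max_krylov_dim and lanczos_tol control the Generalized
    # Lanczos ('krylov') subproblem solver.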
    optimizer = TrustRegion(
        model.parameters(), max_trust_radius=1000, initial_trust_radius=.005,
        eta=0.15, kappa_easy=0.01, max_newton_iter=150, max_krylov_dim=150,
        lanczos_tol=1e-5, gtol=1e-05, hutchinson_approx=True,
        opt_method=opt_method)

    tr_losses = []
    tr_accuracies, tst_accuracies = [], []
    n_iter = 0
    best_acc = 0.0
    n_runs = 1  # Increase to average over multiple runs
    for r in range(n_runs):
        for n_epoch in range(int(epochs)):
            running_loss = 0
            running_samples = 0
            for i, (samples, labels) in enumerate(train_loader):
                samples = samples.view(-1, 28 * 28).to(device)
                labels = labels.to(device)

                def closure(backward=True):
                    if torch.is_grad_enabled() and backward:
                        optimizer.zero_grad()
                    model_outputs = model(samples)
                    cri_loss = criterion(model_outputs, labels)
                    if cri_loss.requires_grad and backward:
                        cri_loss.backward(retain_graph=True, create_graph=True)
                    return cri_loss

                tr_loss = optimizer.step(closure=closure)

                batch_loss = tr_loss.detach().cpu()
                running_loss += batch_loss * train_loader.batch_size
                running_samples += train_loader.batch_size

                n_iter += 1
                # if n_iter % 500 == 0:
                #     _tst_acc = acc(model, test_loader, device)
                #     print("n_iteration: {}. Tr. Loss: {}. Tst. Accuracy: {}.".format(
                #         n_iter, running_loss / running_samples, _tst_acc))

            tr_acc = acc(model, train_loader, device)
            tst_acc = acc(model, test_loader, device)

            print("n_iteration: {} - n_epoch {} - Tr. Loss: {} - Tr. Accuracy: {} - Tst. Accuracy: {}".
                  format(n_iter, n_epoch + 1, running_loss / running_samples, tr_acc, tst_acc))

            if tst_acc > best_acc:
                best_acc = tst_acc

            tr_losses.append(running_loss / len(train_loader.sampler))
            tr_accuracies.append(tr_acc)
            tst_accuracies.append(tst_acc)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    ax1.plot(tr_losses, 'o-', label='Train Loss')
    ax1.title.set_text('Loss using {}'.format(opt_method))
    ax1.set_yscale('log')
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('log Loss')
    ax1.legend()
    ax1.grid()

    ax2.plot(tr_accuracies, 'o-', label='Train acc.')
    ax2.plot(tst_accuracies, 'o-', label='Test acc.')
    ax2.title.set_text('Acc using {}'.format(opt_method))
    ax2.set_xlabel('Epochs')
    ax2.set_ylabel('Accuracy')
    ax2.legend()
    ax2.grid()

    plt.savefig('Results/{}.png'.format(opt_method))
    plt.show()


if __name__ == "__main__":
    batch_size = 512
    epochs = 15
    input_dim = 784
    output_dim = 10
    lr_rate = 0.01

    train_dataset = dsets.MNIST(
        root='../../Datasets', train=True, transform=transforms.ToTensor(),
        download=True)
    test_dataset = dsets.MNIST(
        root='../../Datasets', train=False, transform=transforms.ToTensor())

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=batch_size, shuffle=False)

    main()

--------------------------------------------------------------------------------
/torchtrustncg/trust_region.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

# @author Vasileios Choutas
# Contact: vassilis.choutas@tuebingen.mpg.de


from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

from typing import NewType, List, Tuple

import torch
from torch import norm
import torch.optim as optim
import torch.autograd as autograd

import math

Tensor = NewType('Tensor', torch.Tensor)


def eye_like(tensor, device):
    return torch.eye(*tensor.size(), out=torch.empty_like(tensor, device=device), device=device)


class TrustRegion(optim.Optimizer):

    def __init__(
            self,
            params: List[Tensor],
            max_trust_radius: float = 1000,
            initial_trust_radius: float = 0.5,
            eta: float = 0.15,
            gtol: float = 1e-05,
            kappa_easy: float = 0.1,
            max_newton_iter: int = 50,
            max_krylov_dim: int = 15,
            lanczos_tol: float = 1e-4,
            opt_method: str = 'cg',
            epsilon: float = 1.0e-09,
            **kwargs
    ) -> None:
        """ Trust Region optimizer

        Newton Conjugate Gradient:
            Uses the Conjugate Gradient algorithm to find the solution of the
            trust region sub-problem. For more details see Algorithm 7.2 of
            "Numerical Optimization", Nocedal and Wright.
        Generalized Lanczos Method:
            Uses the Generalized Lanczos algorithm to find the solution of the
            trust region sub-problem. For more details see Algorithm 7.5.2 of
            "Trust Region Methods", Conn et al.

        Arguments:
            params (iterable): A list or iterable of tensors that will be
                optimized
            max_trust_radius: float
                The maximum value for the trust radius
            initial_trust_radius: float
                The initial value for the trust region
            eta: float
                Minimum improvement ratio for accepting a step
            kappa_easy: float
                Parameter related to the convergence of the Krylov method,
                see Lemma 7.3.5, Conn et al.
            max_newton_iter: int
                Maximum number of Newton iterations for the root finding
            max_krylov_dim: int
                Maximum Krylov subspace dimension
            lanczos_tol: float
                Approximation error of the optimizer in the Krylov subspace,
                see Theorem 7.5.10, Conn et al.
            opt_method: string
                The method used to solve the subproblem, either 'cg' or
                'krylov'
            gtol: float
                Gradient tolerance for stopping the optimization
        """
        defaults = dict()

        super(TrustRegion, self).__init__(params, defaults)

        self.steps = 0
        self.max_trust_radius = max_trust_radius
        self.initial_trust_radius = initial_trust_radius
        self.eta = eta
        self.gtol = gtol
        self._params = self.param_groups[0]['params']

        self.kappa_easy = kappa_easy
        self.opt_method = opt_method
        self.lanczos_tol = lanczos_tol
        self.max_krylov_dim = max_krylov_dim
        self.max_newton_iter = max_newton_iter
        self.kwargs = kwargs

        self.epsilon = epsilon

        # (T + lambda * I), the shifted tridiagonal Lanczos matrix
        self.T_lambda = lambda _lambda, T_x, device: T_x.to(
            device) + _lambda * eye_like(T_x, device)
        self.lambda_const = lambda lambda_k: (
            1 + lambda_k) * torch.sqrt(torch.tensor(torch.finfo(torch.float32).eps))

        if not (opt_method == 'cg' or opt_method == 'krylov'):
            raise ValueError('opt_method should be "cg" or "krylov"')

    @torch.enable_grad()
    def _compute_hessian_vector_product(
            self,
            gradient: Tensor,
            p: Tensor) -> Tensor:

        hess_vp = autograd.grad(
            torch.sum(gradient * p, dim=-1), self._params,
            only_inputs=True, retain_graph=True, allow_unused=True)
        return torch.cat([torch.flatten(vp) for vp in hess_vp], dim=-1)

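    # Note: the Hessian-vector product above is computed with the Pearlmutter
    # trick, H p = d(g^T p)/d(params), which is why closures must call
    # loss.backward(create_graph=True): the gradient itself has to remain
    # differentiable.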
    def _gather_flat_grad(self) -> Tensor:
        """ Concatenates all gradients into a single gradient vector
        """
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.data.new(p.data.numel()).zero_()
            elif p.grad.data.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            views.append(view)
        output = torch.cat(views, 0)
        return output

    @torch.no_grad()
    def _improvement_ratio(self, p, start_loss, gradient, closure):
        """ Calculates the ratio of the actual to the expected improvement

        Arguments:
            p (torch.tensor): The update vector for the parameters
            start_loss (torch.tensor): The value of the loss function
                before applying the optimization step
            gradient (torch.tensor): The flattened gradient vector of the
                parameters
            closure (callable): The function that evaluates the loss for
                the current values of the parameters
        Returns:
            The ratio of the actual improvement of the loss to the expected
            improvement, as predicted by the local quadratic model
        """

        # Hessian-vector product at the starting point, needed to evaluate
        # the local quadratic model
        hess_vp = self._compute_hessian_vector_product(gradient, p)

        # Apply the update to the parameter vectors.
        # Use a torch.no_grad() context since we are updating the parameters
        # in place
        with torch.no_grad():
            start_idx = 0
            for param in self._params:
                num_els = param.numel()
                curr_upd = p[start_idx:start_idx + num_els]
                param.data.add_(curr_upd.view_as(param))
                start_idx += num_els

        # No need to backpropagate since we only need the value of the loss at
        # the new point to find the ratio of the actual and the expected
        # improvement
        new_loss = closure(backward=False)
        # The numerator represents the actual loss decrease
        numerator = start_loss - new_loss

        new_quad_val = self._quad_model(p, start_loss, gradient, hess_vp)

        # The denominator is the decrease predicted by the quadratic model
        denominator = start_loss - new_quad_val

        # TODO: Convert to epsilon, print warning
        ratio = numerator / (denominator + 1e-20)
        return ratio

    @torch.no_grad()
    def _quad_model(
            self,
            p: Tensor,
            loss: float,
            gradient: Tensor,
            hess_vp: Tensor) -> float:
        """ Returns the value of the local quadratic approximation
        """
        return (loss + torch.flatten(gradient * p).sum(dim=-1) +
                0.5 * torch.flatten(hess_vp * p).sum(dim=-1))

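    # calc_boundaries below finds the offsets t at which the line
    # iterate + t * direction crosses the trust-region boundary, i.e. the
    # roots of the scalar quadratic
    #   ||d||^2 t^2 + 2 (x . d) t + (||x||^2 - radius^2) = 0.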
    @torch.no_grad()
    def calc_boundaries(
            self,
            iterate: Tensor,
            direction: Tensor,
            trust_radius: float) -> Tuple[Tensor, Tensor]:
        """ Calculates the offsets to the boundaries of the trust region
        """

        a = torch.sum(direction ** 2)
        b = 2 * torch.sum(direction * iterate)
        c = torch.sum(iterate ** 2) - trust_radius ** 2
        sqrt_discriminant = torch.sqrt(b * b - 4 * a * c)
        ta = (-b + sqrt_discriminant) / (2 * a)
        tb = (-b - sqrt_discriminant) / (2 * a)
        if ta.item() < tb.item():
            return [ta, tb]
        else:
            return [tb, ta]

    @torch.no_grad()
    def _solve_subproblem_cg(
            self,
            loss: float,
            flat_grad: Tensor,
            trust_radius: float) -> Tuple[Tensor, bool]:
        ''' Solves the quadratic subproblem in the trust region
        '''

        # The iterate vector that contains the increment from the starting
        # point
        iterate = torch.zeros_like(flat_grad, requires_grad=False)

        # The residual of the CG algorithm
        residual = flat_grad.detach()
        # The first direction of descent
        direction = -residual

        jac_mag = torch.norm(flat_grad).item()
        # Tolerance defined in Nocedal & Wright, chapter 7.1
        tolerance = min(0.5, math.sqrt(jac_mag)) * jac_mag

        # If the magnitude of the gradients is smaller than the tolerance then
        # exit
        if jac_mag <= tolerance:
            return iterate, False

        # Iterate to solve the subproblem
        while True:
            # Calculate the Hessian-vector product
            hessian_vec_prod = self._compute_hessian_vector_product(
                flat_grad, direction)

            # This term is equal to p^T * H * p
            hevp_dot_prod = torch.sum(hessian_vec_prod * direction)

            # If non-positive curvature
            if hevp_dot_prod.item() <= 0:
                # Find the boundary crossings and select the minimum
                ta, tb = self.calc_boundaries(iterate, direction, trust_radius)
                pa = iterate + ta * direction
                pb = iterate + tb * direction

                # Calculate the point on the boundary with the smallest value
                bound1_val = self._quad_model(pa, loss, flat_grad,
                                              hessian_vec_prod)
                bound2_val = self._quad_model(pb, loss, flat_grad,
                                              hessian_vec_prod)
                if bound1_val.item() < bound2_val.item():
                    return pa, True
                else:
                    return pb, True

            # The squared euclidean norm of the residual needed for the CG
            # update
            residual_sq_norm = torch.sum(residual * residual, dim=-1)

            # Compute the step size for the CG algorithm
            cg_step_size = residual_sq_norm / hevp_dot_prod

            # Update the point
            next_iterate = iterate + cg_step_size * direction

            iterate_norm = torch.norm(next_iterate, dim=-1)

            # If the point is outside of the trust region project it on the
            # border and return
            if iterate_norm.item() >= trust_radius:
                ta, tb = self.calc_boundaries(iterate, direction, trust_radius)
                p_boundary = iterate + tb * direction
                return p_boundary, True

            # Update the residual
            next_residual = residual + cg_step_size * hessian_vec_prod
            # If the residual is small enough, exit
            if torch.norm(next_residual, dim=-1).item() < tolerance:
                return next_iterate, False

            beta = torch.sum(next_residual ** 2, dim=-1) / residual_sq_norm
            # Compute the new search direction
            direction = (-next_residual + beta * direction).squeeze()
            if torch.isnan(direction).sum() > 0:
                raise RuntimeError

            iterate = next_iterate
            residual = next_residual

    @torch.no_grad()
    def _converged(self, s, trust_radius):
        if abs(norm(s) - trust_radius) <= self.kappa_easy * trust_radius:
            return True
        else:
            return False

    @torch.no_grad()
    def _lambda_one_plus(self, T, device):
        eigen_pairs = torch.linalg.eigh(T)

        Lambda, U = eigen_pairs.eigenvalues, eigen_pairs.eigenvectors
        lambda_n, u_n = Lambda[0].to(device=device), U[:, 0].to(device=device)

        return torch.maximum(-lambda_n, torch.tensor([0], device=device)), lambda_n, u_n[:, None]

    @torch.no_grad()
    def _quad_model_krylov(
            self,
            lanczos_g: Tensor,
            loss: float,
            s_x: Tensor,
            T_x: Tensor) -> float:
        """ Returns the value of the local quadratic approximation in the
        Krylov subspace
        """
        return (loss + torch.sum(lanczos_g * s_x) + 1 / 2 * torch.sum(T_x.mm(s_x) * s_x)).item()

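    # _root_finder below applies a safeguarded Newton iteration (in the spirit
    # of More-Sorensen, cf. Section 7.3 of Conn et al.) to the secular
    # equation phi(lambda) = 1 / ||s(lambda)|| - 1 / trust_radius = 0, where
    # s(lambda) solves (T + lambda * I) s = -g in the Krylov subspace.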
    def _root_finder(self, trust_radius, T_x, lanczos_g, loss, device):

        n_iter_nu = 0
        lambda_k, lambda_n, u_n = self._lambda_one_plus(T_x, device)
        lambda_const = self.lambda_const(lambda_k).to(device=device)
        if lambda_k == 0:  # T_x is positive definite
            _lambda = torch.tensor(
                [0], dtype=torch.float32, device=device)  # + lambda_const
        else:
            _lambda = lambda_k + lambda_const

        s, L = self._compute_s(_lambda=_lambda, lambda_const=lambda_const,
                               lanczos_g=lanczos_g, T_x=T_x, device=device)

        if norm(s) <= trust_radius:

            if _lambda == 0 or norm(s) == trust_radius:
                return s
            else:
                # Hard case: move to the boundary along the eigenvector u_n
                ta, tb = self.calc_boundaries(
                    iterate=s, direction=u_n, trust_radius=trust_radius)
                pa = s + ta * u_n
                pb = s + tb * u_n

                # Calculate the point on the boundary with the smallest value
                bound1_val = self._quad_model_krylov(lanczos_g, loss, pa, T_x)
                bound2_val = self._quad_model_krylov(lanczos_g, loss, pb, T_x)

                if bound1_val < bound2_val:
                    return pa
                else:
                    return pb

        while True:
            if self._converged(s, trust_radius) or norm(s) < torch.finfo(float).eps:
                break

            # Solve L w = s with the lower-triangular Cholesky factor L
            w = torch.linalg.solve_triangular(
                L.to(device=device), s, upper=False)
            _lambda = self._nu_next(_lambda, trust_radius, s, w)

            s, L = self._compute_s(_lambda, lambda_const,
                                   lanczos_g, T_x, device)

            n_iter_nu += 1
            if n_iter_nu > self.max_newton_iter - 1:
                print(RuntimeWarning(
                    'Maximum number of newton iterations exceeded for _lambda: {}'.format(_lambda)))
                break

        return s

    @torch.no_grad()
    def _nu_next(self, _lambda, trust_radius, s, w):

        norm_s = norm(s)
        norm_w = norm(w)

        # Newton step for phi(lambda) = 1 / ||s|| - 1 / trust_radius
        phi = 1 / norm_s - 1 / trust_radius
        phi_prime = norm_w ** 2 / norm_s ** 3

        return _lambda - phi / phi_prime

    @torch.no_grad()
    def _compute_s(self, _lambda, lambda_const, lanczos_g, T_x, device):
        try:
            L = torch.linalg.cholesky(self.T_lambda(_lambda, T_x, device))
        except RuntimeError:
            # The shifted matrix is not positive definite yet: double the
            # shift increment and retry
            lambda_const *= 2
            s, L = self._compute_s(
                _lambda + lambda_const, lambda_const, lanczos_g, T_x, device)

        # L is the lower-triangular factor, so solve with upper=False
        s = torch.cholesky_solve(-lanczos_g[:, None],
                                 L.to(device=device), upper=False)
        return s, L

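    # _solve_subproblem_krylov below builds a Lanczos basis Q of the Krylov
    # subspace span{g, Hg, H^2 g, ...} and the tridiagonal projection
    # T = Q^T H Q, solves the small trust-region problem in that subspace via
    # _root_finder, and maps the solution back to parameter space as h = Q s.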
    @torch.no_grad()
    def _solve_subproblem_krylov(
            self,
            loss: float,
            flat_grad: Tensor,
            trust_radius: float) -> Tuple[Tensor, bool]:
        """
        Solves the quadratic subproblem in the trust region using the
        Generalized Lanczos Method, see Algorithm 7.5.2, Conn et al.
        """
        INTERIOR_FLAG = True
        Q, diagonals, off_diagonals = [], [], []

        flat_grads_detached = flat_grad.detach()
        n_features = len(flat_grads_detached)
        h = torch.zeros_like(flat_grads_detached, requires_grad=False)
        q, p = flat_grads_detached, -flat_grads_detached

        gamma0 = torch.norm(q)

        krylov_dim, sigma = 0, 1

        device = flat_grad.device
        targs = {'device': device, 'dtype': flat_grad.dtype}

        while True:
            Hp = self._compute_hessian_vector_product(flat_grad, p)
            ptHp = torch.sum(Hp * p)
            alpha = torch.norm(q) ** 2 / ptHp
            # if alpha == 0:
            #     print('hard case')
            if krylov_dim == 0:
                diagonals.append(1. / alpha.clamp_(min=self.epsilon).item())
                off_diagonals.append(float('inf'))  # dummy value
                Q.append(sigma * q / norm(q))
                T_x = torch.tensor([diagonals], **targs)
                alpha_prev = alpha
            else:
                diagonals.append(1. / alpha.item() +
                                 beta.item() / alpha_prev.item())
                sigma = -torch.sign(alpha_prev) * sigma
                Q.append(sigma * q / norm(q))
                T_x = (torch.diag(torch.tensor(diagonals, **targs), 0)
                       + torch.diag(torch.tensor(off_diagonals[1:], **targs), -1)
                       + torch.diag(torch.tensor(off_diagonals[1:], **targs), 1))
                alpha_prev = alpha

            if (INTERIOR_FLAG and alpha < 0) or torch.norm(h + alpha * p) >= trust_radius:
                INTERIOR_FLAG = False

            if INTERIOR_FLAG:
                h = h + alpha * p
            else:
                # Lanczos Step 2: solve the problem in the subspace
                e_1 = torch.eye(1, krylov_dim + 1,
                                device=flat_grad.device).flatten()
                lanczos_g = gamma0 * e_1
                s = self._root_finder(trust_radius=trust_radius,
                                      T_x=T_x, lanczos_g=lanczos_g,
                                      loss=loss, device=flat_grad.device)
                s = s.to(flat_grad.device)

            q_next = q + alpha * Hp

            # Test for convergence
            if INTERIOR_FLAG and norm(q_next) ** 2 < self.lanczos_tol:
                break
            if not INTERIOR_FLAG and torch.norm(q_next) * abs(s[-1]) < self.lanczos_tol:
                break

            if krylov_dim == n_features:
                # The Krylov subspace spans the full space; stop expanding
                break

            if krylov_dim > self.max_krylov_dim:
                # Maximum Krylov dimension reached
                break

            beta = torch.dot(q_next, q_next) / torch.dot(q, q)
            off_diagonals.append(torch.sqrt(beta) / torch.abs(alpha_prev))
            p = -q_next + beta * p
            q = q_next
            krylov_dim = krylov_dim + 1

        if not INTERIOR_FLAG:
            # Return to the original space
            Q = torch.vstack(Q).T
            h = torch.sum(Q * torch.squeeze(s), dim=1)

        return h, not INTERIOR_FLAG  # INTERIOR_FLAG False means the boundary was hit

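    # step() below follows the standard trust-region schedule (cf. Algorithm
    # 4.1 of Nocedal & Wright): shrink the radius by a factor of 4 when the
    # improvement ratio is below 0.25, double it (capped at max_trust_radius)
    # when the ratio exceeds 0.75 and the step hit the boundary, and undo the
    # parameter update whenever the ratio does not exceed eta.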
    def step(self, closure=None) -> float:
        """ Performs a single optimization step

        Arguments:
            closure (callable): Evaluates the loss. It must accept a boolean
                argument named `backward` and, when it is True, call
                loss.backward(create_graph=True).
        """
        starting_loss = closure(backward=True)

        flat_grad = self._gather_flat_grad()

        state = self.state
        if len(state) == 0:
            state['trust_radius'] = torch.full([1],
                                               self.initial_trust_radius,
                                               dtype=flat_grad.dtype,
                                               device=flat_grad.device)
        trust_radius = state['trust_radius']

        if self.opt_method == 'cg':
            param_step, hit_boundary = self._solve_subproblem_cg(
                starting_loss, flat_grad, trust_radius)
        else:
            param_step, hit_boundary = self._solve_subproblem_krylov(
                starting_loss, flat_grad, trust_radius)

        self.param_step = param_step

        if torch.norm(param_step).item() <= self.gtol:
            return starting_loss

        improvement_ratio = self._improvement_ratio(
            param_step, starting_loss, flat_grad, closure)

        if improvement_ratio.item() < 0.25:
            trust_radius.mul_(0.25)
        else:
            if improvement_ratio.item() > 0.75 and hit_boundary:
                trust_radius.mul_(2).clamp_(0.0, self.max_trust_radius)

        if improvement_ratio.item() <= self.eta:
            # If the improvement is not sufficient, then undo the update
            start_idx = 0
            for param in self._params:
                num_els = param.numel()
                curr_upd = param_step[start_idx:start_idx + num_els]
                param.data.add_(-curr_upd.view_as(param))
                start_idx += num_els

        self.steps += 1
        return starting_loss

--------------------------------------------------------------------------------