├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── SECURITY.md ├── dist ├── statopt-0.2-py3-none-any.whl └── statopt-0.2.tar.gz ├── examples ├── cifar_example.ipynb └── cifar_example.py ├── setup.py └── statopt ├── __init__.py ├── bucket.py ├── qhm.py ├── salsa.py ├── sasa.py ├── slope.py └── ssls.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | # dist/ 14 | data/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributing 3 | 4 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 5 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 6 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 7 | 8 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 9 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 10 | provided by the bot. You will only need to do this once across all repos using our CLA. 11 | 12 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 13 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 14 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Statistical Adaptive Stochastic Gradient Methods 2 | 3 | A package of PyTorch optimizers that can automatically schedule learning rates based on online statistical tests. 4 | 5 | * main algorithms: SALSA and SASA 6 | * auxiliary codes: QHM and SSLS 7 | 8 | Companion paper: [Statistical Adaptive Stochastic Gradient Methods](https://www.microsoft.com/en-us/research/publication/statistical-adaptive-stochastic-gradient-methods) by Zhang, Lang, Liu and Xiao, 2020. 
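All optimizers in this package follow the standard PyTorch `Optimizer` interface. As a minimal sketch (based on the docstring in [statopt/qhm.py](statopt/qhm.py); the model `net`, data loader `trainloader`, and loss function `loss_func` are placeholders), the auxiliary QHM optimizer works as a drop-in replacement for SGD with momentum:

```python
import statopt

optimizer = statopt.QHM(net.parameters(), lr=0.1, momentum=0.9,
                        qhm_nu=1, weight_decay=5e-4)

for images, labels in trainloader:
    optimizer.zero_grad()                      # clear previous gradients
    loss_func(net(images), labels).backward()  # mini-batch gradient
    optimizer.step()                           # QHM update, no closure needed
```

SALSA and SASA are used the same way, except that SALSA's `step()` takes a closure for its line search, as shown below.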
9 | 10 | ## Install 11 | 12 | pip install statopt 13 | 14 | Or from Github: 15 | 16 | pip install git+git://github.com/microsoft/statopt.git#egg=statopt 17 | 18 | 19 | ## Usage of SALSA and SASA 20 | 21 | Here we outline the key steps on CIFAR10. 22 | Complete Python code is given in [examples/cifar_example.py](examples/cifar_example.py). 23 | 24 | ### Common setups 25 | 26 | First, choose a batch size and prepare the dataset and data loader as in [this PyTorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html): 27 | 28 | ```python 29 | import torch, torchvision 30 | 31 | batch_size = 128 32 | trainset = torchvision.datasets.CIFAR10(root='../data', train=True, ...) 33 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, ...) 34 | ``` 35 | 36 | Choose device, network model, and loss function: 37 | 38 | ```python 39 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 40 | net = torchvision.models.resnet18().to(device) 41 | loss_func = torch.nn.CrossEntropyLoss() 42 | ``` 43 | 44 | ### SALSA 45 | Import ```statopt```, and initialize SALSA with a small learning rate and two extra parameters: 46 | 47 | ```python 48 | import statopt 49 | 50 | gamma = math.sqrt(batch_size/len(trainset)) # smoothing parameter for line search 51 | testfreq = min(1000, len(trainloader)) # frequency to perform statistical test 52 | 53 | optimizer = statopt.SALSA(net.parameters(), lr=1e-3, # any small initial learning rate 54 | momentum=0.9, weight_decay=5e-4, # common choices for CIFAR10/100 55 | gamma=gamma, testfreq=testfreq) # two extra parameters for SALSA 56 | ``` 57 | 58 | Training code using SALSA: 59 | 60 | ```python 61 | for epoch in range(100): 62 | for (images, labels) in trainloader: 63 | net.train() # always switch to train() mode 64 | 65 | # Compute model outputs and loss function 66 | images, labels = images.to(device), labels.to(device) 67 | loss = loss_func(net(images), labels) 68 | 69 | # Compute gradient with back-propagation 70 | optimizer.zero_grad() 71 | loss.backward() 72 | 73 | # SALSA requires a closure function for line search 74 | def eval_loss(eval_mode=True): 75 | if eval_mode: 76 | net.eval() 77 | with torch.no_grad(): 78 | loss = loss_func(net(images), labels) 79 | return loss 80 | 81 | optimizer.step(closure=eval_loss) 82 | 83 | ``` 84 | 85 | ### SASA 86 | 87 | SASA requires a good (hand-tuned) initial learning rate like most other optimizers, but does not use line search: 88 | 89 | ```python 90 | optimizer = statopt.SASA(net.parameters(), lr=1.0, # need a good initial learning rate 91 | momentum=0.9, weight_decay=5e-4, # common choices for CIFAR10/100 92 | testfreq=testfreq) # frequency for statistical tests 93 | ``` 94 | 95 | Within the training loop: ```optimizer.step()``` does NOT need any closure function. -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 
6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 
40 | 41 | 42 | -------------------------------------------------------------------------------- /dist/statopt-0.2-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/statopt/dc99aa2843ed477aef0f409959b68ce6b1411663/dist/statopt-0.2-py3-none-any.whl -------------------------------------------------------------------------------- /dist/statopt-0.2.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/statopt/dc99aa2843ed477aef0f409959b68ce6b1411663/dist/statopt-0.2.tar.gz -------------------------------------------------------------------------------- /examples/cifar_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import math\n", 10 | "import torch\n", 11 | "import torchvision\n", 12 | "import statopt" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "name": "stderr", 22 | "output_type": "stream", 23 | "text": [ 24 | "\r", 25 | "0it [00:00, ?it/s]" 26 | ] 27 | }, 28 | { 29 | "name": "stdout", 30 | "output_type": "stream", 31 | "text": [ 32 | "Preparing data ...\n", 33 | "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz\n" 34 | ] 35 | }, 36 | { 37 | "name": "stderr", 38 | "output_type": "stream", 39 | "text": [ 40 | "170500096it [00:04, 39303301.73it/s] \n" 41 | ] 42 | }, 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": [ 47 | "Files already downloaded and verified\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "#----------------------------------------------------\n", 53 | "# Prepare datasets, download to the directory ../data \n", 54 | "print('Preparing data ...')\n", 55 | "batch_size = 128\n", 56 | "normalizer = torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),\n", 57 | " (0.2023, 0.1994, 0.2010))\n", 58 | "transform_train = torchvision.transforms.Compose(\n", 59 | " [torchvision.transforms.RandomCrop(32, padding=4),\n", 60 | " torchvision.transforms.RandomHorizontalFlip(),\n", 61 | " torchvision.transforms.ToTensor(), normalizer,])\n", 62 | "trainset = torchvision.datasets.CIFAR10(root='../data', train=True,\n", 63 | " download=True, \n", 64 | " transform=transform_train)\n", 65 | "sampler = torch.utils.data.sampler.RandomSampler(trainset)\n", 66 | "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n", 67 | " sampler=sampler, num_workers=4)\n", 68 | "transform_test = torchvision.transforms.Compose(\n", 69 | " [torchvision.transforms.ToTensor(), normalizer,])\n", 70 | "testset = torchvision.datasets.CIFAR10(root='../data', train=False,\n", 71 | " download=True, \n", 72 | " transform=transform_test)\n", 73 | "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n", 74 | " shuffle=False, num_workers=4)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 3, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "#-----------------------------------------------\n", 84 | "# Choose device, network model and loss function\n", 85 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 86 | "net = torchvision.models.resnet18(num_classes=10)\n", 87 | "\n", 88 | "cifarify = True\n", 89 | "if cifarify:\n", 90 | " 
class Identity(torch.nn.Module):\n", 91 | " def __init__(self):\n", 92 | " super(Identity, self).__init__()\n", 93 | "\n", 94 | " def forward(self, x):\n", 95 | " return x\n", 96 | "\n", 97 | " net.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)\n", 98 | " net.maxpool = Identity()\n", 99 | "\n", 100 | "net = net.to(device)\n", 101 | "loss_func = torch.nn.CrossEntropyLoss()" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 10, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "#--------------------------------------------------------\n", 111 | "# Choose optimizer from the list ['sgd', 'sasa', 'salsa']\n", 112 | "optimizer_name = 'sgd'\n", 113 | "print('Using optimier {}'.format(optimizer_name))\n", 114 | "\n", 115 | "if optimizer_name == 'sasa':\n", 116 | " testfreq = min(1000, len(trainloader))\n", 117 | " optimizer = statopt.SASA(net.parameters(), lr=1.0, \n", 118 | " momentum=0.9, weight_decay=5e-4, \n", 119 | " testfreq=testfreq)\n", 120 | "elif optimizer_name == 'salsa':\n", 121 | " gamma = math.sqrt(batch_size/len(trainset)) \n", 122 | " testfreq = min(1000, len(trainloader))\n", 123 | " optimizer = statopt.SALSA(net.parameters(), lr=1e-3, \n", 124 | " momentum=0.9, weight_decay=5e-4, \n", 125 | " gamma=gamma, testfreq=testfreq)\n", 126 | "else:\n", 127 | " optimizer_name = 'sgd' # SGD with a Step learning rate scheduler\n", 128 | " optimizer = torch.optim.SGD(net.parameters(), lr=0.1,\n", 129 | " momentum=0.9, weight_decay=5e-4)\n", 130 | " scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50*len(trainloader),\n", 131 | " gamma=0.1, last_epoch=-1)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 11, 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "name": "stdout", 141 | "output_type": "stream", 142 | "text": [ 143 | "Start training ...\n", 144 | " epoch 1: average loss 0.014\n", 145 | " epoch 2: average loss 0.013\n", 146 | "Finished training.\n" 147 | ] 148 | } 149 | ], 150 | "source": [ 151 | "#----------------------------------\n", 152 | "# Training the neural network model\n", 153 | "print('Start training ...')\n", 154 | "\n", 155 | "for epoch in range(250):\n", 156 | " # Reset accumulative running loss at beginning or each epoch\n", 157 | " running_loss = 0.0\n", 158 | "\n", 159 | " for (images, labels) in trainloader:\n", 160 | " # switch to train mode each time due to potential use of eval mode\n", 161 | " net.train()\n", 162 | " \n", 163 | " # Compute model outputs and loss function \n", 164 | " images, labels = images.to(device), labels.to(device)\n", 165 | " outputs = net(images)\n", 166 | " loss = loss_func(outputs, labels)\n", 167 | " \n", 168 | " # Compute gradient with back-propagation \n", 169 | " optimizer.zero_grad()\n", 170 | " loss.backward()\n", 171 | " \n", 172 | " # Call the step() method of different optimizers\n", 173 | " if optimizer_name == 'sgd':\n", 174 | " optimizer.step()\n", 175 | " scheduler.step()\n", 176 | " elif optimizer_name == 'sasa':\n", 177 | " optimizer.step()\n", 178 | " elif optimizer_name == 'salsa':\n", 179 | " def eval_loss(eval_mode=True):\n", 180 | " if eval_mode:\n", 181 | " net.eval()\n", 182 | " with torch.no_grad():\n", 183 | " loss = loss_func(net(images), labels)\n", 184 | " return loss\n", 185 | " optimizer.step(closure=eval_loss)\n", 186 | "\n", 187 | " # Accumulate running loss during each epoch\n", 188 | " running_loss += loss.item()\n", 189 | " print(' epoch {:3d}: average loss {:.3f}'.format(\n", 190 | " epoch + 
1, running_loss / len(trainset))) \n", 191 | "\n", 192 | "print('Finished training.')" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 12, 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "name": "stdout", 202 | "output_type": "stream", 203 | "text": [ 204 | "Accuracy of the model on 10000 test images: 38.73 %\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "#-------------------------------------\n", 210 | "# Compute accuracy on the test dataset\n", 211 | "n_correct = 0\n", 212 | "n_testset = 0\n", 213 | "with torch.no_grad():\n", 214 | " net.eval()\n", 215 | " for (images, labels) in testloader:\n", 216 | " images, labels = images.to(device), labels.to(device)\n", 217 | " outputs = net(images)\n", 218 | " _, predicted = torch.max(outputs.data, 1)\n", 219 | " n_testset += labels.size(0)\n", 220 | " n_correct += (predicted == labels).sum().item()\n", 221 | "\n", 222 | "print('Accuracy of the model on {} test images: {} %'.format(\n", 223 | " n_testset, 100 * n_correct / n_testset))" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [] 232 | } 233 | ], 234 | "metadata": { 235 | "kernelspec": { 236 | "display_name": "Python 3", 237 | "language": "python", 238 | "name": "python3" 239 | }, 240 | "language_info": { 241 | "codemirror_mode": { 242 | "name": "ipython", 243 | "version": 3 244 | }, 245 | "file_extension": ".py", 246 | "mimetype": "text/x-python", 247 | "name": "python", 248 | "nbconvert_exporter": "python", 249 | "pygments_lexer": "ipython3", 250 | "version": "3.6.8" 251 | }, 252 | "pycharm": { 253 | "stem_cell": { 254 | "cell_type": "raw", 255 | "source": [], 256 | "metadata": { 257 | "collapsed": false 258 | } 259 | } 260 | } 261 | }, 262 | "nbformat": 4, 263 | "nbformat_minor": 2 264 | } -------------------------------------------------------------------------------- /examples/cifar_example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
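# Example invocations (illustrative only; CIFAR-10 is downloaded to ../data on first run):
#   python cifar_example.py --opt salsa --cifarify 1   # SALSA on a CIFAR-adapted ResNet-18
#   python cifar_example.py --opt sasa                 # SASA (needs a good initial learning rate)
#   python cifar_example.py                            # default: SGD with a StepLR schedule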
3 | 4 | import math 5 | import torch 6 | import torchvision 7 | import statopt 8 | import argparse 9 | 10 | parser = argparse.ArgumentParser(description='PyTorch Cifar10 Training') 11 | parser.add_argument('--opt', choices=['sgd', 'sasa', 'salsa'], default='sgd') 12 | parser.add_argument('--cifarify', type=int, default=0) 13 | args = parser.parse_args() 14 | 15 | #---------------------------------------------------- 16 | # Prepare datasets, download to the directory ../data 17 | print('Preparing data ...') 18 | batch_size = 128 19 | normalizer = torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), 20 | (0.2023, 0.1994, 0.2010)) 21 | transform_train = torchvision.transforms.Compose( 22 | [torchvision.transforms.RandomCrop(32, padding=4), 23 | torchvision.transforms.RandomHorizontalFlip(), 24 | torchvision.transforms.ToTensor(), normalizer,]) 25 | trainset = torchvision.datasets.CIFAR10(root='../data', train=True, 26 | download=True, 27 | transform=transform_train) 28 | sampler = torch.utils.data.sampler.RandomSampler(trainset) 29 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, 30 | sampler=sampler, num_workers=4) 31 | transform_test = torchvision.transforms.Compose( 32 | [torchvision.transforms.ToTensor(), normalizer,]) 33 | testset = torchvision.datasets.CIFAR10(root='../data', train=False, 34 | download=True, 35 | transform=transform_test) 36 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, 37 | shuffle=False, num_workers=4) 38 | 39 | #----------------------------------------------- 40 | # Choose device, network model and loss function 41 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 42 | net = torchvision.models.resnet18(num_classes=10) 43 | 44 | if args.cifarify: 45 | class Identity(torch.nn.Module): 46 | def __init__(self): 47 | super(Identity, self).__init__() 48 | 49 | def forward(self, x): 50 | return x 51 | 52 | net.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 53 | net.maxpool = Identity() 54 | 55 | net = net.to(device) 56 | loss_func = torch.nn.CrossEntropyLoss() 57 | 58 | #-------------------------------------------------------- 59 | # Choose optimizer from the list ['sgd', 'sasa', 'salsa'] 60 | optimizer_name = args.opt 61 | print('Using optimizer {}'.format(optimizer_name)) 62 | 63 | if optimizer_name == 'sasa': 64 | testfreq = min(1000, len(trainloader)) 65 | optimizer = statopt.SASA(net.parameters(), lr=1.0, 66 | momentum=0.9, weight_decay=5e-4, 67 | testfreq=testfreq) 68 | elif optimizer_name == 'salsa': 69 | gamma = math.sqrt(batch_size/len(trainset)) 70 | testfreq = min(1000, len(trainloader)) 71 | optimizer = statopt.SALSA(net.parameters(), lr=1e-3, 72 | momentum=0.9, weight_decay=5e-4, 73 | gamma=gamma, testfreq=testfreq) 74 | else: 75 | optimizer_name = 'sgd' # SGD with a Step learning rate scheduler 76 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1, 77 | momentum=0.9, weight_decay=5e-4) 78 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50*len(trainloader), 79 | gamma=0.1, last_epoch=-1) 80 | 81 | #---------------------------------- 82 | # Training the neural network model 83 | print('Start training ...') 84 | 85 | for epoch in range(250): 86 | # Reset cumulative running loss at the beginning of each epoch 87 | running_loss = 0.0 88 | 89 | for (images, labels) in trainloader: 90 | # switch to train mode each time due to potential use of eval mode 91 | net.train() 92 | 93 | # Compute model outputs and loss function 94 | images, labels = 
images.to(device), labels.to(device) 95 | outputs = net(images) 96 | loss = loss_func(outputs, labels) 97 | 98 | # Compute gradient with back-propagation 99 | optimizer.zero_grad() 100 | loss.backward() 101 | 102 | # Call the step() method of different optimizers 103 | if optimizer_name == 'sgd': 104 | optimizer.step() 105 | scheduler.step() 106 | elif optimizer_name == 'sasa': 107 | optimizer.step() 108 | elif optimizer_name == 'salsa': 109 | def eval_loss(eval_mode=True): 110 | if eval_mode: 111 | net.eval() 112 | with torch.no_grad(): 113 | loss = loss_func(net(images), labels) 114 | return loss 115 | optimizer.step(closure=eval_loss) 116 | 117 | # Accumulate running loss during each epoch 118 | running_loss += loss.item() 119 | print(' epoch {:3d}: average loss {:.3f}'.format( 120 | epoch + 1, running_loss / len(trainset))) 121 | 122 | print('Finished training.') 123 | 124 | #------------------------------------- 125 | # Compute accuracy on the test dataset 126 | n_correct = 0 127 | n_testset = 0 128 | with torch.no_grad(): 129 | net.eval() 130 | for (images, labels) in testloader: 131 | images, labels = images.to(device), labels.to(device) 132 | outputs = net(images) 133 | _, predicted = torch.max(outputs.data, 1) 134 | n_testset += labels.size(0) 135 | n_correct += (predicted == labels).sum().item() 136 | 137 | print('Accuracy of the model on {} test images: {} %'.format( 138 | n_testset, 100 * n_correct / n_testset)) 139 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='statopt', 5 | version='0.2', 6 | packages=find_packages(exclude=['tests*']), 7 | license='MIT', 8 | description='Statistical adaptive stochastic optimization methods', 9 | long_description=open('README.md').read(), 10 | long_description_content_type='text/markdown', 11 | install_requires=['numpy', 'scipy', 'torch'], 12 | url='https://github.com/microsoft/statopt' 13 | ) 14 | -------------------------------------------------------------------------------- /statopt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .qhm import QHM 5 | from .sasa import SASA 6 | from .ssls import SSLS 7 | from .salsa import SALSA 8 | from .slope import SLOPE 9 | -------------------------------------------------------------------------------- /statopt/bucket.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import math 5 | from scipy import stats 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | # Use a Leaky Bucket to store a fraction of most recent statistics 10 | # Wikepedia article: https://en.wikipedia.org/wiki/Leaky_bucket 11 | class LeakyBucket(object): 12 | def __init__(self, size, ratio, dtype, device, fixed_len=-1): 13 | ''' 14 | size: size of allocated memory buffer to keep the leaky bucket queue, 15 | which will be doubled whenever the memory is full 16 | ratio: integer ratio of total number of samples to numbers to be kept: 17 | 1 - keep all, 18 | 2 - keep most recent 1/2, 19 | 3 - keep most recent 1/3, 20 | ... 
21 | fixed_len: fixed length to keep, ratio >=1 becomes irrelevant 22 | ''' 23 | self.size = size 24 | self.ratio = int(ratio) 25 | self.fixed_len = int(fixed_len) 26 | 27 | self.buffer = torch.zeros(size, dtype=dtype, device=device) 28 | self.count = 0 # number of elements kept in queue (excluding leaked) 29 | self.start = 0 # count = end - start 30 | self.end = 0 31 | self.total_count = 0 # total number of elements added (including leaked) 32 | 33 | def reset(self): 34 | self.buffer.zero_() 35 | self.count = 0 36 | self.start = 0 37 | self.end = 0 38 | self.total_count = 0 39 | 40 | def double_size(self): 41 | self.size *= 2 42 | self.buffer.resize_(self.size) 43 | 44 | def add(self, val): 45 | if self.end == self.size: # when the end index reach size 46 | self.double_size() # double the size of buffer 47 | 48 | self.buffer[self.end] = val # always put new value at the end 49 | self.end += 1 # and increase end index by one 50 | 51 | if self.fixed_len > 0: 52 | if self.count == self.fixed_len: 53 | self.start += 1 54 | else: 55 | self.count += 1 56 | else: 57 | if self.total_count % self.ratio == 0: # if leaky_count is multiple of ratio 58 | self.count += 1 # increase count in queue by one 59 | else: # otherwise leak and keep same count 60 | self.start += 1 # increase start index by one 61 | 62 | self.total_count += 1 # always increase total_count by one 63 | 64 | # reset start index to 0 and end index to count to save space 65 | if self.start >= self.count: 66 | self.buffer[0:self.count] = self.buffer[self.start:self.end] 67 | self.start = 0 68 | self.end = self.count 69 | 70 | # ! Need to add safeguard to allow compute only if there are enough entries 71 | def mean_std(self, mode='bm'): 72 | mean = torch.mean(self.buffer[self.start:self.end]).item() 73 | 74 | if mode == 'bm': # batch mean variance 75 | b_n = int(math.floor(math.sqrt(self.count))) 76 | Yks = F.avg_pool1d(self.buffer[self.start:self.end].unsqueeze(0).unsqueeze(0), kernel_size=b_n, stride=b_n).view(-1) 77 | diffs = Yks - mean 78 | std = math.sqrt(b_n /(len(Yks)-1))*torch.norm(diffs).item() 79 | dof = b_n - 1 80 | elif mode == 'olbm': # overlapping batch mean 81 | b_n = int(math.floor(math.sqrt(self.count))) 82 | Yks = F.avg_pool1d(self.buffer[self.start:self.end].unsqueeze(0).unsqueeze(0), kernel_size=b_n, stride=1).view(-1) 83 | diffs = Yks - mean 84 | std = math.sqrt(b_n*self.count/(len(Yks)*(len(Yks)-1)))*torch.norm(diffs).item() 85 | dof = self.count - b_n 86 | else: # otherwise use mode == 'iid' 87 | std = torch.std(self.buffer[self.start:self.end]).item() 88 | dof = self.count - 1 89 | 90 | return mean, std, dof 91 | 92 | def stats_test(self, sigma, mode='bm', composite_test=False): 93 | mean, std, dof = self.mean_std(mode=mode) 94 | 95 | # confidence interval 96 | t_sigma_dof = stats.t.ppf(1-sigma/2., dof) 97 | half_width = std * t_sigma_dof / math.sqrt(self.count) 98 | lower = mean - half_width 99 | upper = mean + half_width 100 | # The simple confidence interval test 101 | # stationarity = lower < 0 and upper > 0 102 | 103 | # A more stable test is to also check if two half-means are of the same sign 104 | half_point = self.start + int(math.floor(self.count / 2)) 105 | mean1 = torch.mean(self.buffer[self.start : half_point]).item() 106 | mean2 = torch.mean(self.buffer[half_point : self.end]).item() 107 | stationarity = (lower < 0 and upper > 0) and (mean1 * mean2 > 0) 108 | 109 | if composite_test: 110 | # Use two half tests to avoid false positive caused by crossing 0 in transient phase 111 | lb1 = mean1 - 
half_width 112 | ub1 = mean1 + half_width 113 | lb2 = mean2 - half_width 114 | ub2 = mean2 + half_width 115 | stationarity = (lb1 * ub1 < 0) and (lb2 * ub2 < 0) and (mean1 * mean2 > 0) 116 | 117 | return stationarity, mean, lower, upper 118 | 119 | # method to test if average loss after line search is no longer decreasing 120 | def rel_reduction(self): 121 | if self.count < 4: 122 | return 0.5 123 | half_point = self.start + int(math.floor(self.count / 2)) 124 | mean1 = torch.mean(self.buffer[self.start : half_point]).item() 125 | mean2 = torch.mean(self.buffer[half_point : self.end]).item() 126 | return (mean1 - mean2) / mean1 127 | 128 | # method to test if average loss after line search is no longer decreasing 129 | def is_decreasing(self, min_cnt=1000, dec_rate=0.01): 130 | if self.count < min_cnt: 131 | return True 132 | half_point = self.start + int(math.floor(self.count / 2)) 133 | mean1 = torch.mean(self.buffer[self.start : half_point]).item() 134 | mean2 = torch.mean(self.buffer[half_point : self.end]).item() 135 | return (mean1 - mean2) / mean1 > dec_rate 136 | 137 | def linregress(self, sigma, mode='linear'): 138 | """ 139 | calculate a linear regression 140 | sigma: the confidence of the one-side test 141 | H0: slope >= 0 vs H1: slope < 0 142 | mode: whether log scale the x axis 143 | """ 144 | TINY = 1.0e-20 145 | x = torch.arange(self.total_count-self.count, self.total_count, 146 | dtype=self.buffer.dtype, device=self.buffer.device) 147 | if mode == 'log': 148 | x = torch.log(x) 149 | # both x and y has dimension (self.count,) 150 | xy = torch.cat([x.view(1, -1), 151 | self.buffer[self.start:self.end].view(1, -1)], 152 | dim=0) 153 | # compute covariance matrix 154 | fact = 1.0 / self.count 155 | xy -= torch.mean(xy, dim=1, keepdim=True) 156 | xyt = xy.t() 157 | cov = fact * xy.matmul(xyt).squeeze() 158 | # compute the t-statistics 159 | r_num = cov[0, 1].item() 160 | r_den = torch.sqrt(cov[0, 0]*cov[1, 1]).item() 161 | if r_den == 0.0: 162 | r = 0.0 163 | else: 164 | r = r_num / r_den 165 | # test for numerical error propagation 166 | if r > 1.0: 167 | r = 1.0 168 | elif r < -1.0: 169 | r = -1.0 170 | 171 | df = self.count - 2 172 | t = r * math.sqrt(df / ((1.0 - r + TINY) * (1.0 + r + TINY))) 173 | # one-sided test for decreasing 174 | prob = stats.t.cdf(t, df) 175 | is_decreasing = prob < sigma 176 | # slop 177 | slope = r_num / cov[0, 0].item() 178 | return is_decreasing, slope, prob 179 | -------------------------------------------------------------------------------- /statopt/qhm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import torch 5 | from torch.optim import Optimizer 6 | 7 | 8 | class QHM(Optimizer): 9 | r""" 10 | Stochastic gradient method with Quasi-Hyperbolic Momentum (QHM): 11 | 12 | h(k) = (1 - \beta) * g(k) + \beta * h(k-1) 13 | d(k) = (1 - \nu) * g(k) + \nu * h(k) 14 | x(k+1) = x(k) - \alpha * d(k) 15 | 16 | "Quasi-hyperbolic momentum and Adam for deep learning" 17 | by Jerry Ma and Denis Yarats, ICLR 2019 18 | 19 | optimizer = QHM(params, lr=-1, momentum=0, qhm_nu=1, weight_decay=0) 20 | 21 | Args: 22 | params (iterable): iterable params to optimize or dict of param groups 23 | lr (float): learning rate, \alpha in QHM update (default:-1 need input) 24 | momentum (float, optional): \beta in QHM update, range[0,1) (default:0) 25 | qhm_nu (float, optional): \nu in QHM update, range[0,1] (default: 1) 26 | \nu = 0: SGD without momentum (\beta is ignored) 27 | \nu = 1: SGD with momentum \beta and dampened gradient (1-\beta) 28 | \nu = \beta: SGD with "Nesterov momentum" \beta 29 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 30 | 31 | Example: 32 | >>> optimizer = torch.optim.QHM(model.parameters(), lr=0.1, momentum=0.9) 33 | >>> optimizer.zero_grad() 34 | >>> loss_fn(model(input), target).backward() 35 | >>> optimizer.step() 36 | 37 | """ 38 | 39 | def __init__(self, params, lr=-1, momentum=0, qhm_nu=1, weight_decay=0): 40 | # nu can take values outside of the interval [0,1], but no guarantee of convergence? 41 | if lr <= 0: 42 | raise ValueError("Invalid value for learning rate (>0): {}".format(lr)) 43 | if momentum < 0 or momentum > 1: 44 | raise ValueError("Invalid value for momentum [0,1): {}".format(momentum)) 45 | if weight_decay < 0: 46 | raise ValueError("Invalid value for weight_decay (>=0): {}".format(weight_decay)) 47 | 48 | defaults = dict(lr=lr, momentum=momentum, qhm_nu=qhm_nu, weight_decay=weight_decay) 49 | super(QHM, self).__init__(params, defaults) 50 | 51 | # extra_buffer == True only in SSLS with momentum > 0 and nu != 1 52 | self.state['allocate_step_buffer'] = False 53 | 54 | def step(self, closure=None): 55 | """ 56 | Performs a single optimization step. 57 | Arguments: 58 | closure (callable, optional): A closure that reevaluates model and returns loss. 
59 | """ 60 | loss = None 61 | if closure is not None: 62 | loss = closure() 63 | 64 | self.add_weight_decay() 65 | self.qhm_direction() 66 | self.qhm_update() 67 | 68 | return loss 69 | 70 | def add_weight_decay(self): 71 | # weight_decay is the same as adding L2 regularization 72 | for group in self.param_groups: 73 | weight_decay = group['weight_decay'] 74 | for p in group['params']: 75 | if p.grad is None: 76 | continue 77 | if weight_decay > 0: 78 | p.grad.data.add_(weight_decay, p.data) 79 | 80 | def qhm_direction(self): 81 | 82 | for group in self.param_groups: 83 | momentum = group['momentum'] 84 | qhm_nu = group['qhm_nu'] 85 | 86 | for p in group['params']: 87 | if p.grad is None: 88 | continue 89 | x = p.data # Optimization parameters 90 | g = p.grad.data # Stochastic gradient 91 | 92 | # Compute the (negative) step directoin d and necessary momentum 93 | state = self.state[p] 94 | if abs(momentum) < 1e-12 or abs(qhm_nu) < 1e-12: # simply SGD if beta=0 or nu=0 95 | d = state['step_buffer'] = g 96 | else: 97 | if 'momentum_buffer' not in state: 98 | h = state['momentum_buffer'] = torch.zeros_like(x) 99 | else: 100 | h = state['momentum_buffer'] 101 | # Update momentum buffer: h(k) = (1 - \beta) * g(k) + \beta * h(k-1) 102 | h.mul_(momentum).add_(1 - momentum, g) 103 | 104 | if abs(qhm_nu - 1) < 1e-12: # if nu=1, then same as SGD with momentum 105 | d = state['step_buffer'] = h 106 | else: 107 | if self.state['allocate_step_buffer']: # copy from gradient 108 | if 'step_buffer' not in state: 109 | state['step_buffer'] = torch.zeros_like(g) 110 | d = state['step_buffer'].copy_(g) 111 | else: # use gradient buffer 112 | d = state['step_buffer'] = g 113 | # Compute QHM momentum: d(k) = (1 - \nu) * g(k) + \nu * h(k) 114 | d.mul_(1 - qhm_nu).add_(qhm_nu, h) 115 | 116 | def qhm_update(self): 117 | """ 118 | Perform QHM update, need to call compute_qhm_direction() before calling this. 119 | """ 120 | for group in self.param_groups: 121 | for p in group['params']: 122 | if p.grad is not None: 123 | p.data.add_(-group['lr'], self.state[p]['step_buffer']) 124 | 125 | -------------------------------------------------------------------------------- /statopt/salsa.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from .ssls import SSLS 5 | from .sasa import SASA 6 | from .bucket import LeakyBucket 7 | 8 | 9 | class SALSA(SASA, SSLS): 10 | r""" 11 | SALSA: Statistical Approximation with Line-search and Statistical Adaptation 12 | a combination of the following two methods with automatic switch 13 | SSLS: Smoothed Stochastic Line Search 14 | SASA: Statistical Adaptive Stochastic Approximation 15 | 16 | Stochastic gradient with Quasi-Hyperbolic Momentum (QHM): 17 | h(k) = (1 - \beta) * g(k) + \beta * h(k-1) 18 | d(k) = (1 - \nu) * g(k) + \nu * h(k) 19 | x(k+1) = x(k) - \alpha(k) * d(k) 20 | 21 | How to use it: (same as SSLS, except for a warmup parameter for SASA) 22 | >>> optimizer = SALSA(model.parameters(), lr=1, momentum=0.9, qhm_nu=1, 23 | >>> weight_decay=1e-4, gamma=0.01, warmup=1000) 24 | >>> for input, target in dataset: 25 | >>> model.train() 26 | >>> optimizer.zero_grad() 27 | >>> loss_func(model(input), target).backward() 28 | >>> def eval_loss(eval_mode=True): # closure function for line search 29 | >>> if eval_mode: 30 | >>> model.eval() 31 | >>> with torch.no_grad(): 32 | >>> output = model(input) 33 | >>> loss = loss_func(output, target) 34 | >>> return loss 35 | >>> optimizer.step(eval_loss) 36 | """ 37 | 38 | def __init__(self, params, lr=1e-3, momentum=0, qhm_nu=1, weight_decay=0, gamma=0.1, 39 | ls_sdc=0.05, ls_inc=2.0, ls_dec=0.5, ls_max=2, 40 | ls_evl=1, ls_dir='g', ls_cos=0, # should not change these 3 defaults 41 | auto_switch=1, 42 | warmup=0, drop_factor=10, significance=0.05, var_mode='mb', 43 | leak_ratio=8, minN_stats=1000, testfreq=100, logstats=0): 44 | 45 | SASA.__init__(self, params, lr, momentum, qhm_nu, weight_decay, 46 | warmup, drop_factor, significance, var_mode, 47 | leak_ratio, minN_stats, testfreq, logstats) 48 | # State from QHM: Extra_buffer used only if momentum > 0 and nu != 1 49 | self.state['allocate_step_buffer'] = True 50 | 51 | # Initialize states of SSLS here 52 | self.state['lr'] = float(lr) 53 | self.state['gamma'] = gamma 54 | self.state['ls_sdc'] = ls_sdc 55 | self.state['ls_inc'] = ls_inc 56 | self.state['ls_dec'] = ls_dec 57 | self.state['ls_max'] = int(ls_max) 58 | self.state['ls_evl'] = bool(ls_evl) 59 | self.state['ls_dir'] = ls_dir 60 | self.state['ls_cos'] = bool(ls_cos) 61 | 62 | # state for tracking cosine of angle between g and d. 63 | self.state['cosine'] = 0.0 64 | self.state['ls_eta'] = 0.0 65 | self.state['ls_cnt'] = 0 66 | 67 | self.state['auto_switch'] = bool(auto_switch) 68 | self.state['switched'] = False 69 | 70 | # State initialization: leaky bucket to store mini-batch loss values 71 | p = self.param_groups[0]['params'][0] 72 | self.state['ls_bucket'] = LeakyBucket(1000, leak_ratio, p.dtype, p.device) 73 | 74 | def step(self, closure): 75 | """ 76 | Performs a single optimization step. 77 | Arguments: 78 | closure (callable): A closure that reevaluates model and returns loss. 
79 | """ 80 | 81 | if self.state['auto_switch']: 82 | self.step_auto_switch(closure) 83 | else: 84 | self.step_manual_switch(closure) 85 | 86 | return None 87 | 88 | 89 | def step_auto_switch(self, closure): 90 | 91 | self.add_weight_decay() 92 | self.qhm_direction() 93 | 94 | if not self.state['switched']: 95 | _, fval = self.line_search(closure) 96 | self.state['ls_bucket'].add(fval) 97 | if self.state['ls_bucket'].count > self.state['minN_stats']: 98 | is_decreasing = self.state['ls_bucket'].linregress(self.state['significance'])[0] 99 | if not is_decreasing: 100 | self.state['switched'] = True 101 | print("SALSA: auto switch due to non-decreasing training loss.") 102 | 103 | self.qhm_update() 104 | 105 | self.state['nSteps'] += 1 106 | self.stats_adaptation() 107 | if self.state['stats_test'] and self.state['stats_stationary']: 108 | self.state['switched'] = True 109 | print("SALSA: auto switch due to stationarity test") 110 | 111 | 112 | def step_manual_switch(self, closure): 113 | 114 | if self.state['nSteps'] < self.state['warmup']: 115 | SSLS.step(self, closure) 116 | self.state['nSteps'] += 1 117 | else: 118 | SASA.step(self) 119 | 120 | -------------------------------------------------------------------------------- /statopt/sasa.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import math 5 | import torch 6 | from torch.optim import Optimizer 7 | import torch.nn.functional as F 8 | from .qhm import QHM 9 | from .bucket import LeakyBucket 10 | 11 | 12 | class SASA(QHM): 13 | r""" 14 | Statistical Adaptive Stochastic Approximation (SASA+) with master condition. 15 | 16 | optimizer = SASA(params, lr=-1, momentum=0, qhm_nu=1, weight_decay=0, 17 | warmup=0, drop_factor=2, significance=0.02, var_mode='bm', 18 | leak_ratio=4, minN_stats=400, testfreq=100, logstats=0) 19 | 20 | Stochastic gradient with Quasi-Hyperbolic Momentum (QHM): 21 | 22 | h(k) = (1 - \beta) * g(k) + \beta * h(k-1) 23 | d(k) = (1 - \nu) * g(k) + \nu * h(k) 24 | x(k+1) = x(k) - \alpha * d(k) 25 | 26 | Stationary criterion: 27 | E[ <x(k), d(k)> - (\alpha / 2) * ||d(k)||^2 ] = 0 28 | or equivalently, 29 | E[ <x(k+1), d(k)> + (\alpha / 2) * ||d(k)||^2 ] = 0 30 | 31 | Args: 32 | params (iterable): iterable params to optimize or dict of param groups 33 | lr (float): learning rate, \alpha in QHM update (default:-1 need input) 34 | momentum (float, optional): \beta in QHM update, range(0,1) (default:0) 35 | qhm_nu (float, optional): \nu in QHM update, range(0,1) (default: 1) 36 | \nu = 0: SGD without momentum (\beta is ignored) 37 | \nu = 1: SGD with momentum and dampened gradient 38 | \nu = \beta: SGD with "Nesterov momentum" 39 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 40 | warmup (int, optional): number of steps before testing (default: 100) 41 | dropfactor (float, optional): factor of drop learning rate (default: 10) 42 | significance (float, optional): test significance level (default:0.05) 43 | var_mode (string, optional): variance computing mode (default: 'mb') 44 | leak_ratio (int, optional): leaky bucket ratio to keep (default: 8) 45 | minN_stats (int, optional): min number of samples for test (default: 1000) 46 | testfreq (int, optional): number of steps between testing (default:100) 47 | logstats (int, optional): number of steps between logs (0 means no log) 48 | 49 | Example: 50 | >>> optimizer = statopt.SASA(model.parameters(), lr=0.1, momentum=0.9, 51 | >>> 
weight_decay=0.0005) 52 | >>> optimizer.zero_grad() 53 | >>> loss_fn(model(input), target).backward() 54 | >>> optimizer.step() 55 | 56 | """ 57 | 58 | def __init__(self, params, lr=-1, momentum=0, qhm_nu=1, weight_decay=0, 59 | warmup=1000, drop_factor=10, significance=0.05, var_mode='mb', 60 | leak_ratio=8, minN_stats=1000, testfreq=100, logstats=0): 61 | 62 | if lr <= 0: 63 | raise ValueError("Invalid value for learning rate (>0): {}".format(lr)) 64 | if momentum < 0 or momentum > 1: 65 | raise ValueError("Invalid value for momentum [0,1): {}".format(momentum)) 66 | if weight_decay < 0: 67 | raise ValueError("Invalid value for weight_decay (>=0): {}".format(weight_decay)) 68 | if drop_factor < 1: 69 | raise ValueError("Invalid value for drop_factor (>=1): {}".format(drop_factor)) 70 | if significance <= 0 or significance >= 1: 71 | raise ValueError("Invalid value for significance (0,1): {}".format(significance)) 72 | if var_mode not in ['mb', 'olbm', 'iid']: 73 | raise ValueError("Invalid value for var_mode ('mb', 'olmb', or 'iid'): {}".format(var_mode)) 74 | if leak_ratio < 1: 75 | raise ValueError("Invalid value for leak_ratio (int, >=1): {}".format(leak_ratio)) 76 | if minN_stats < 100: 77 | raise ValueError("Invalid value for minN_stats (int, >=100): {}".format(minN_stats)) 78 | if warmup < 0: 79 | raise ValueError("Invalid value for warmup (int, >1): {}".format(warmup)) 80 | if testfreq < 1: 81 | raise ValueError("Invalid value for testfreq (int, >=1): {}".format(testfreq)) 82 | 83 | super(SASA, self).__init__(params, lr=lr, momentum=momentum, qhm_nu=qhm_nu, weight_decay=weight_decay) 84 | # New Python3 way to call super() 85 | # super().__init__(params, lr=lr, momentum=momentum, nu=nu, weight_decay=weight_decay) 86 | 87 | # State initialization: leaky bucket belongs to global state. 88 | p = self.param_groups[0]['params'][0] 89 | if 'bucket' not in self.state: 90 | self.state['bucket'] = LeakyBucket(1000, leak_ratio, p.dtype, p.device) 91 | 92 | self.state['lr'] = float(lr) 93 | self.state['drop_factor'] = drop_factor 94 | self.state['significance'] = significance 95 | self.state['var_mode'] = var_mode 96 | self.state['minN_stats'] = int(minN_stats) 97 | self.state['warmup'] = int(warmup) 98 | self.state['testfreq'] = int(testfreq) 99 | self.state['logstats'] = int(logstats) 100 | self.state['composite_test'] = True # first drop use composite test 101 | self.state['nSteps'] = 0 # steps counter +1 every iteration 102 | 103 | # statistics to monitor 104 | self.state['stats_x1d'] = 0 105 | self.state['stats_ld2'] = 0 106 | self.state['stats_val'] = 0 107 | self.state['stats_test'] = 0 108 | self.state['stats_stationary'] = 0 109 | self.state['stats_mean'] = 0 110 | self.state['stats_lb'] = 0 111 | self.state['stats_ub'] = 0 112 | 113 | def step(self, closure=None): 114 | """ 115 | Performs a single optimization step. 116 | Arguments: 117 | closure (callable, optional): A closure that reevaluates model and returns loss. 
118 | """ 119 | loss = None 120 | if closure is not None: 121 | loss = closure() 122 | 123 | self.add_weight_decay() 124 | self.qhm_direction() 125 | self.qhm_update() 126 | self.state['nSteps'] += 1 127 | self.stats_adaptation() 128 | 129 | return loss 130 | 131 | def stats_adaptation(self): 132 | 133 | # compute <x(k+1), d(k)> and ||d(k)||^2 for statistical test 134 | self.state['stats_x1d'] = 0.0 135 | self.state['stats_ld2'] = 0.0 136 | for group in self.param_groups: 137 | for p in group['params']: 138 | if p.grad is None: 139 | continue 140 | xk1 = p.data.view(-1) 141 | dk = self.state[p]['step_buffer'].data.view(-1) # OK after super().step() 142 | self.state['stats_x1d'] += xk1.dot(dk).item() 143 | self.state['stats_ld2'] += dk.dot(dk).item() 144 | self.state['stats_ld2'] *= 0.5 * self.state['lr'] 145 | 146 | # Gather flat buffers can take too much memory for large models 147 | # Compute <x(k+1), d(k)> and ||d(k)||^2 for statistical test 148 | # dk = self._gather_flat_buffer('step_buffer') 149 | # xk1 = self._gather_flat_param() 150 | # self.state['stats_x1d'] = xk1.dot(dk).item() 151 | # self.state['stats_ld2'] = (0.5 * self.state['lr']) * (dk.dot(dk).item()) 152 | 153 | # add statistic to leaky bucket 154 | self.state['stats_val'] = self.state['stats_x1d'] + self.state['stats_ld2'] 155 | bucket = self.state['bucket'] 156 | bucket.add(self.state['stats_val']) 157 | 158 | # check statistics and adjust learning rate 159 | self.state['stats_test'] = 0 160 | self.state['stats_stationary'] = 0 161 | if bucket.count > self.state['minN_stats'] and self.state['nSteps'] % self.state['testfreq'] == 0: 162 | stationary, mean, lb, ub = bucket.stats_test(self.state['significance'], 163 | self.state['var_mode'], 164 | self.state['composite_test']) 165 | self.state['stats_test'] = 1 166 | self.state['stats_stationary'] = int(stationary) 167 | self.state['stats_mean'] = mean 168 | self.state['stats_lb'] = lb 169 | self.state['stats_ub'] = ub 170 | # perform statistical test for stationarity 171 | if self.state['nSteps'] > self.state['warmup'] and stationary: 172 | self.state['lr'] /= self.state['drop_factor'] 173 | for group in self.param_groups: 174 | group['lr'] = self.state['lr'] 175 | self._zero_buffers('momentum_buffer') 176 | self.state['composite_test'] = False 177 | bucket.reset() 178 | 179 | # Log statistics only for debugging. 
Therefore self.state['stats_test'] remains False 180 | if self.state['logstats'] and not self.state['stats_test']: 181 | if bucket.count > bucket.ratio and self.state['nSteps'] % self.state['logstats'] == 0: 182 | stationary, mean, lb, ub = bucket.stats_test(self.state['significance'], 183 | self.state['var_mode'], 184 | self.state['composite_test']) 185 | self.state['stats_stationary'] = int(stationary) 186 | self.state['stats_mean'] = mean 187 | self.state['stats_lb'] = lb 188 | self.state['stats_ub'] = ub 189 | 190 | 191 | # methods for gather flat parameters 192 | def _gather_flat_param(self): 193 | views = [] 194 | for group in self.param_groups: 195 | for p in group['params']: 196 | view = p.data.view(-1) 197 | views.append(view) 198 | return torch.cat(views, 0) 199 | 200 | # method for gathering/initializing flat buffers that are the same shape as the parameters 201 | def _gather_flat_buffer(self, buf_name): 202 | views = [] 203 | for group in self.param_groups: 204 | for p in group['params']: 205 | state = self.state[p] 206 | if buf_name not in state: # init buffer 207 | view = p.data.new(p.data.numel()).zero_() 208 | else: 209 | view = state[buf_name].data.view(-1) 210 | views.append(view) 211 | return torch.cat(views, 0) 212 | 213 | def _zero_buffers(self, buf_name): 214 | for group in self.param_groups: 215 | for p in group['params']: 216 | state = self.state[p] 217 | if buf_name in state: 218 | state[buf_name].zero_() 219 | return None 220 | -------------------------------------------------------------------------------- /statopt/slope.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import math 5 | import torch 6 | from torch.optim import Optimizer 7 | import torch.nn.functional as F 8 | from .qhm import QHM 9 | from .bucket import LeakyBucket 10 | 11 | 12 | class SLOPE(QHM): 13 | r""" 14 | Statistical Adaptive Slope Test. 
(t-test) 15 | 16 | optimizer = SLOPE(params, lr=-1, momentum=0, qhm_nu=1, weight_decay=0, 17 | warmup=0, drop_factor=2, significance=0.02, var_mode='linear', 18 | leak_ratio=8, minN_stats=400, testfreq=100, logstats=0) 19 | 20 | Stochastic gradient with Quasi-Hyperbolic Momentum (QHM): 21 | 22 | h(k) = (1 - \beta) * g(k) + \beta * h(k-1) 23 | d(k) = (1 - \nu) * g(k) + \nu * h(k) 24 | x(k+1) = x(k) - \alpha * d(k) 25 | 26 | Stationary criterion: 27 | H0: slope >= 0 vs H1: slope < 0 28 | 29 | Args: 30 | params (iterable): iterable params to optimize or dict of param groups 31 | lr (float): learning rate, \alpha in QHM update (default:-1 need input) 32 | momentum (float, optional): \beta in QHM update, range(0,1) (default:0) 33 | qhm_nu (float, optional): \nu in QHM update, range(0,1) (default: 1) 34 | \nu = 0: SGD without momentum (\beta is ignored) 35 | \nu = 1: SGD with momentum and dampened gradient 36 | \nu = \beta: SGD with "Nesterov momentum" 37 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 38 | warmup (int, optional): number of steps before testing (default: 100) 39 | dropfactor (float, optional): factor of drop learning rate (default: 2) 40 | significance (float, optional): test significance level (default:0.02) 41 | var_mode (string, optional): linear regression mode (default: 'linear') 42 | leak_ratio (int, optional): leaky bucket ratio to kept (default: 8) 43 | minN_stats (int, optional): min number of samples for test (default: 400) 44 | testfreq (int, optional): number of steps between testing (default:100) 45 | logstats (int, optional): number of steps between logs (0 means no log) 46 | 47 | Example: 48 | >>> optimizer = torch.optim.SLOPE(model.parameters(), lr=0.1, momentum=0.9, 49 | weight_decay=0.0005) 50 | >>> optimizer.zero_grad() 51 | >>> loss_fn(model(input), target).backward() 52 | >>> optimizer.step() 53 | 54 | """ 55 | 56 | def __init__(self, params, lr=-1, momentum=0, qhm_nu=1, weight_decay=0, 57 | warmup=100, drop_factor=2, significance=0.05, var_mode='linear', 58 | leak_ratio=8, minN_stats=1000, testfreq=100, logstats=0): 59 | 60 | if lr <= 0: 61 | raise ValueError("Invalid value for learning rate (>0): {}".format(lr)) 62 | if momentum < 0 or momentum > 1: 63 | raise ValueError("Invalid value for momentum [0,1): {}".format(momentum)) 64 | if weight_decay < 0: 65 | raise ValueError("Invalid value for weight_decay (>=0): {}".format(weight_decay)) 66 | if drop_factor < 1: 67 | raise ValueError("Invalid value for drop_factor (>=1): {}".format(drop_factor)) 68 | if significance <= 0 or significance >= 1: 69 | raise ValueError("Invalid value for significance (0,1): {}".format(significance)) 70 | if var_mode not in ['linear', 'log']: 71 | raise ValueError("Invalid value for var_mode ('linear', 'log'): {}".format(var_mode)) 72 | if leak_ratio < 1: 73 | raise ValueError("Invalid value for leak_ratio (int, >=1): {}".format(leak_ratio)) 74 | if minN_stats < 100: 75 | raise ValueError("Invalid value for minN_stats (int, >=100): {}".format(minN_stats)) 76 | if warmup < 0: 77 | raise ValueError("Invalid value for warmup (int, >1): {}".format(warmup)) 78 | if testfreq < 1: 79 | raise ValueError("Invalid value for testfreq (int, >=1): {}".format(testfreq)) 80 | 81 | super(SLOPE, self).__init__(params, lr=lr, momentum=momentum, qhm_nu=qhm_nu, weight_decay=weight_decay) 82 | # New Python3 way to call super() 83 | # super().__init__(params, lr=lr, momentum=momentum, nu=nu, weight_decay=weight_decay) 84 | 85 | # State initialization: leaky bucket belongs to 
86 |         p = self.param_groups[0]['params'][0]
87 |         if 'bucket' not in self.state:
88 |             self.state['bucket'] = LeakyBucket(1000, leak_ratio, p.dtype, p.device)
89 | 
90 |         self.state['lr'] = float(lr)
91 |         self.state['drop_factor'] = drop_factor
92 |         self.state['significance'] = significance
93 |         self.state['var_mode'] = var_mode
94 |         self.state['minN_stats'] = int(minN_stats)
95 |         self.state['warmup'] = int(warmup)
96 |         self.state['testfreq'] = int(testfreq)
97 |         self.state['logstats'] = int(logstats)
98 |         self.state['nSteps'] = 0
99 | 
100 |         # statistics to monitor
101 |         self.state['stats_test'] = 0
102 |         self.state['stats_stationary'] = 0
103 |         self.state['stats_mean'] = 0
104 |         self.state['conf'] = 0
105 |         self.state['slope'] = 0
106 | 
107 |     def step(self, closure):
108 |         """
109 |         Performs a single optimization step.
110 |         Arguments:
111 |             closure (callable): a closure that reevaluates the model and returns the loss.
112 |         """
113 |         loss = closure()
114 | 
115 |         self.add_weight_decay()
116 |         self.qhm_direction()
117 |         self.qhm_update()
118 |         self.state['nSteps'] += 1
119 |         self.stats_adaptation(loss)
120 | 
121 |         return loss
122 | 
123 |     def stats_adaptation(self, loss):
124 |         # add loss statistic to the leaky bucket
125 |         bucket = self.state['bucket']
126 |         bucket.add(loss.item() + self.L2_regu_loss())
127 | 
128 |         # check statistics and adjust learning rate
129 |         self.state['stats_test'] = 0
130 |         self.state['stats_stationary'] = 0
131 |         if bucket.count > self.state['minN_stats'] and self.state['nSteps'] % self.state['testfreq'] == 0:
132 |             is_decreasing, slope, prob = bucket.linregress(self.state['significance'],
133 |                                                            self.state['var_mode'])
134 |             self.state['stats_test'] = 1
135 |             self.state['stats_stationary'] = 1 - int(is_decreasing)
136 |             self.state['conf'] = prob
137 |             self.state['slope'] = slope
138 |             # perform statistical test for stationarity
139 |             if self.state['nSteps'] > self.state['warmup'] and not is_decreasing:
140 |                 self.state['lr'] /= self.state['drop_factor']
141 |                 for group in self.param_groups:
142 |                     group['lr'] = self.state['lr']
143 |                 self._zero_buffers('momentum_buffer')
144 |                 bucket.reset()
145 | 
146 |         # Log statistics only for debugging; in this branch self.state['stats_test'] remains 0.
147 |         if self.state['logstats'] and not self.state['stats_test']:
148 |             if bucket.count > bucket.ratio and self.state['nSteps'] % self.state['logstats'] == 0:
149 |                 is_decreasing, slope, prob = bucket.linregress(self.state['significance'],
150 |                                                                self.state['var_mode'])
151 |                 self.state['stats_stationary'] = 1 - int(is_decreasing)
152 |                 self.state['conf'] = prob
153 |                 self.state['slope'] = slope
154 | 
155 | 
156 |     def _zero_buffers(self, buf_name):
157 |         for group in self.param_groups:
158 |             for p in group['params']:
159 |                 state = self.state[p]
160 |                 if buf_name in state:
161 |                     state[buf_name].zero_()
162 |         return None
163 | 
164 | 
165 |     def L2_regu_loss(self):
166 |         L2_loss = 0.0
167 |         for group in self.param_groups:
168 |             weight_decay = group['weight_decay']
169 |             for p in group['params']:
170 |                 if p.grad is None:
171 |                     continue
172 |                 x = p.data.view(-1)
173 |                 L2_loss += 0.5 * weight_decay * (x.dot(x)).item()
174 |         return L2_loss
175 | 
--------------------------------------------------------------------------------
/statopt/ssls.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | import math
5 | import torch
6 | from torch.optim import Optimizer
7 | from .qhm import QHM
8 | 
9 | 
10 | class SSLS(QHM):
11 |     r"""
12 |     QHM with Smoothed Stochastic Line Search (SSLS) for tuning learning rates
13 | 
14 |     optimizer = SSLS(params, lr=1e-3, momentum=0, qhm_nu=1, weight_decay=0, gamma=0.1,
15 |                      ls_sdc=0.05, ls_inc=2.0, ls_dec=0.5, ls_max=2,
16 |                      ls_evl=1, ls_dir='g', ls_cos=0)
17 | 
18 |     Stochastic gradient with Quasi-Hyperbolic Momentum (QHM):
19 |         h(k) = (1 - \beta) * g(k) + \beta * h(k-1)
20 |         d(k) = (1 - \nu) * g(k) + \nu * h(k)
21 |         x(k+1) = x(k) - \alpha(k) * d(k)
22 | 
23 |     where \alpha(k) is a smoothed version of the step size \eta(k) obtained by line search
24 |     (the line search is performed on the loss defined by the current mini-batch):
25 | 
26 |         \alpha(k) = (1 - \gamma) * \alpha(k-1) + \gamma * \eta(k)
27 | 
28 |     Suggestion: set the smoothing parameter from batch size b and dataset size n: \gamma = a * b / n.
29 |     Since (1 - \gamma)^(n/b) ≈ exp(-a), the cumulative relative increase or decrease of the learning rate per epoch is roughly (1 - exp(-a)).
30 | 
31 |     Args:
32 |         params (iterable): iterable of params to optimize or dicts of param groups
33 |         lr (float): initial learning rate, \alpha in QHM update (default: 1e-3)
34 |         momentum (float, optional): \beta in QHM update, range [0,1] (default: 0)
35 |         qhm_nu (float, optional): \nu in QHM update, range [0,1] (default: 1)
36 |             \nu = 0: SGD without momentum (\beta is ignored)
37 |             \nu = 1: SGD with momentum and dampened gradient
38 |             \nu = \beta: SGD with "Nesterov momentum"
39 |         weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
40 |         gamma (float, optional): smoothing parameter for line search (default: 0.1)
41 |         The next four arguments can be tuned, but the defaults should work well.
42 |         ls_sdc (float, optional): sufficient decrease coefficient (default: 0.05)
43 |         ls_inc (float, optional): increase factor (>1, default: 2.0)
44 |         ls_dec (float, optional): decrease factor (<1, default: 0.5)
45 |         ls_max (int, optional): maximum number of line-search steps (default: 2)
46 |         The next three arguments are for research purposes only!
47 |         ls_evl (bool, optional): whether or not to use evaluation mode (default: 1)
48 |         ls_dir (char, optional): 'g' for g(k) and 'd' for d(k) (default: 'g')
49 |         ls_cos (bool, optional): whether or not to use the cosine between g and d (default: 0)
50 | 
51 |     How to use it:
52 |     >>> optimizer = SSLS(model.parameters(), lr=1, momentum=0.9, qhm_nu=1,
53 |     >>>                  weight_decay=1e-4, gamma=0.01)
54 |     >>> for input, target in dataset:
55 |     >>>     model.train()
56 |     >>>     optimizer.zero_grad()
57 |     >>>     loss_func(model(input), target).backward()
58 |     >>>     def eval_loss(eval_mode=True):    # closure function for line search
59 |     >>>         if eval_mode:
60 |     >>>             model.eval()
61 |     >>>         with torch.no_grad():
62 |     >>>             output = model(input)
63 |     >>>             loss = loss_func(output, target)
64 |     >>>         return loss
65 |     >>>     optimizer.step(eval_loss)
66 |     """
67 | 
68 |     def __init__(self, params, lr=1e-3, momentum=0, qhm_nu=1, weight_decay=0, gamma=0.1,
69 |                  ls_sdc=0.05, ls_inc=2.0, ls_dec=0.5, ls_max=2,
70 |                  ls_evl=1, ls_dir='g', ls_cos=0):
71 | 
72 |         if lr <= 0:
73 |             raise ValueError("Invalid value for learning rate (>0): {}".format(lr))
74 |         if momentum < 0 or momentum > 1:
75 |             raise ValueError("Invalid value for momentum [0,1]: {}".format(momentum))
76 |         if weight_decay < 0:
77 |             raise ValueError("Invalid value for weight_decay (>=0): {}".format(weight_decay))
78 |         if gamma < 0 or gamma > 1:
79 |             raise ValueError("Invalid value for gamma [0,1]: {}".format(gamma))
80 |         if ls_sdc <= 0 or ls_sdc >= 0.5:
81 |             raise ValueError("Invalid value for ls_sdc (0,0.5): {}".format(ls_sdc))
82 |         if ls_inc < 1:
83 |             raise ValueError("Invalid value for ls_inc (>=1): {}".format(ls_inc))
84 |         if ls_dec <= 0 or ls_dec >= 1:
85 |             raise ValueError("Invalid value for ls_dec (0,1): {}".format(ls_dec))
86 |         if ls_max < 1:
87 |             raise ValueError("Invalid value for ls_max (>=1): {}".format(ls_max))
88 |         if ls_dir not in ['g', 'd']:
89 |             raise ValueError("Invalid value for ls_dir ('g' or 'd'): {}".format(ls_dir))
90 | 
91 |         super(SSLS, self).__init__(params, lr=lr, momentum=momentum, qhm_nu=qhm_nu, weight_decay=weight_decay)
92 |         # The extra step buffer is actually used only if momentum > 0 and qhm_nu != 1, even though the flag is set to True here.
93 |         self.state['allocate_step_buffer'] = True
94 | 
95 |         self.state['lr'] = float(lr)
96 |         self.state['gamma'] = gamma
97 |         self.state['ls_sdc'] = ls_sdc
98 |         self.state['ls_inc'] = ls_inc
99 |         self.state['ls_dec'] = ls_dec
100 |         self.state['ls_max'] = int(ls_max)
101 |         self.state['ls_evl'] = bool(ls_evl)
102 |         self.state['ls_dir'] = ls_dir
103 |         self.state['ls_cos'] = bool(ls_cos)
104 | 
105 |         # state for tracking the cosine of the angle between g and d, the line search result, and the count
106 |         self.state['cosine'] = 0.0
107 |         self.state['ls_eta'] = 0.0
108 |         self.state['ls_cnt'] = 0
109 | 
110 |     def step(self, closure):
111 |         """
112 |         Performs a single optimization step.
113 |         Arguments:
114 |             closure (callable): a closure that takes an eval_mode flag, reevaluates the model, and returns the loss.
115 |         """
116 | 
117 |         self.add_weight_decay()
118 |         self.qhm_direction()
119 |         loss, _ = self.line_search(closure)
120 |         self.qhm_update()
121 | 
122 |         return loss
123 | 
124 |     def line_search(self, closure):
125 |         # need loss values using evaluation mode (or not)
126 |         loss0 = closure(self.state['ls_evl'])
127 | 
128 |         # QHM search direction should be determined before line search and update
129 |         g_dot_d = 0.0
130 |         g_norm2 = 0.0
131 |         d_norm2 = 0.0
132 |         for group in self.param_groups:
133 |             for p in group['params']:
134 |                 if p.grad is None:
135 |                     continue
136 |                 state = self.state[p]
137 | 
138 |                 # first make a copy of the parameters before doing line search
139 |                 if 'ls_buffer' not in state:
140 |                     state['ls_buffer'] = torch.zeros_like(p.data)
141 |                 state['ls_buffer'].copy_(p.data)
142 | 
143 |                 # compute inner product between g and d
144 |                 g = p.grad.data.view(-1)
145 |                 # use (g.dot(d)).item() so that we accumulate Python scalars
146 |                 if self.state['ls_dir'] == 'g':
147 |                     d = g
148 |                 else:
149 |                     d = state['step_buffer'].view(-1)
150 |                 g_dot_d += g.dot(d).item()
151 | 
152 |                 # if self.state['ls_dir'] == 'd' and self.state['ls_cos']:
153 |                 g_norm2 += g.dot(g).item()
154 |                 d_norm2 += d.dot(d).item()
155 | 
156 |         # line search on current mini-batch (not changing input to model)
157 |         f0 = loss0.item() + self.L2_regu_loss()
158 |         self.state['cosine'] = g_dot_d / math.sqrt(g_norm2 * d_norm2)
159 |         # try a large instantaneous step size at the beginning of line search
160 |         if self.state['ls_dir'] == 'd' and self.state['ls_cos']:
161 |             # The following also decreases eta from lr if cosine < 0
162 |             # self.state['cosine'] = g_dot_d / math.sqrt(g_norm2 * d_norm2)
163 |             eta = self.state['lr'] * math.pow(self.state['ls_inc'], self.state['cosine'])
164 |         else:
165 |             eta = self.state['lr'] * self.state['ls_inc']
166 | 
167 |         ls_count = 0
168 |         # while ls_count < self.state['ls_max']:
169 |         while g_dot_d > 0 and ls_count < self.state['ls_max']:
170 |             # update parameters x := x - eta * d
171 |             for group in self.param_groups:
172 |                 for p in group['params']:
173 |                     if p.grad is None:
174 |                         continue
175 |                     if ls_count > 0:
176 |                         p.data.copy_(self.state[p]['ls_buffer'])
177 |                     if self.state['ls_dir'] == 'g':
178 |                         p.data.add_(-eta, p.grad.data)
179 |                     else:
180 |                         p.data.add_(-eta, self.state[p]['step_buffer'])
181 | 
182 |             # evaluate loss at the new parameters
183 |             f1 = closure(self.state['ls_evl']).item() + self.L2_regu_loss()
184 |             # back-tracking line search
185 |             if f1 > f0 - self.state['ls_sdc'] * eta * g_dot_d:
186 |                 eta *= self.state['ls_dec']
187 |             # Goldstein line search: not effective in increasing the learning rate
188 |             # elif f1 < f0 - (1 - self.state['ls_sdc']) * eta * g_dot_d:
189 |             #     eta *= self.state['ls_inc']
190 |             else:
191 |                 break
192 |             ls_count += 1
193 |         else:
194 |             if g_dot_d <= 0 and not self.state['ls_cos']:
195 |             # if g_dot_d <= 0:
196 |                 eta = self.state['lr'] * self.state['ls_dec']
197 |             # if g_dot_d > 0, the while loop ends without break with eta = lr * power(ls_dec, ls_max)
198 | 
199 |         # After line search over the instantaneous step size, update the learning rate by smoothing
200 |         self.state['ls_eta'] = eta
201 |         self.state['ls_cnt'] = ls_count
202 |         self.state['lr'] = (1 - self.state['gamma']) * self.state['lr'] + self.state['gamma'] * eta
203 |         # update lr in parameter groups AND reset weights to their values before line search
204 |         for group in self.param_groups:
205 |             group['lr'] = self.state['lr']
206 |             for p in group['params']:
207 |                 if p.grad is not None:
208 |                     p.data.copy_(self.state[p]['ls_buffer'])
209 | 
210 |         # f0 is always computed, but f1 may not be.
211 |         return loss0, f0
212 | 
213 |     def L2_regu_loss(self):
214 |         L2_loss = 0.0
215 |         for group in self.param_groups:
216 |             weight_decay = group['weight_decay']
217 |             for p in group['params']:
218 |                 if p.grad is None:
219 |                     continue
220 |                 x = p.data.view(-1)
221 |                 L2_loss += 0.5 * weight_decay * (x.dot(x)).item()
222 |         return L2_loss
223 | 
224 |     def gradient_norm(self):
225 |         normsqrd = 0.0
226 |         for group in self.param_groups:
227 |             for p in group['params']:
228 |                 if p.grad is None:
229 |                     continue
230 |                 g = p.grad.data.view(-1)
231 |                 normsqrd += (g.dot(g)).item()
232 |         return math.sqrt(normsqrd)
233 | 
234 |     def buffer_norm(self, buf_name):
235 |         normsqrd = 0.0
236 |         for group in self.param_groups:
237 |             for p in group['params']:
238 |                 state = self.state[p]
239 |                 if buf_name in state:
240 |                     v = state[buf_name].data.view(-1)
241 |                     normsqrd += (v.dot(v)).item()
242 |         return math.sqrt(normsqrd)
243 | 
--------------------------------------------------------------------------------
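
Usage note (not part of the package sources): the sketch below shows one way the SSLS optimizer defined in statopt/ssls.py might be driven from a training loop, including the docstring's suggestion gamma = a * b / n for the smoothing parameter. It is a minimal, self-contained example with a toy model and synthetic data; the names model, criterion, data, a, batch_size and train_size are placeholders chosen for illustration, not part of the library.

import math
import torch
from statopt.ssls import SSLS   # SSLS is defined in statopt/ssls.py

torch.manual_seed(0)
model = torch.nn.Linear(20, 2)                 # toy model standing in for a real network
criterion = torch.nn.CrossEntropyLoss()

# Synthetic stand-in for a real data loader: 10 mini-batches of size 32.
data = [(torch.randn(32, 20), torch.randint(0, 2, (32,))) for _ in range(10)]

# Smoothing parameter per the docstring suggestion: gamma = a * b / n,
# so the learning rate can change by roughly a factor related to (1 - exp(-a)) per epoch.
a, batch_size, train_size = 2.0, 32, 320
gamma = a * batch_size / train_size

optimizer = SSLS(model.parameters(), lr=1.0, momentum=0.9,
                 weight_decay=5e-4, gamma=gamma)

for inputs, targets in data:
    model.train()
    optimizer.zero_grad()
    criterion(model(inputs), targets).backward()

    def eval_loss(eval_mode=True):             # closure re-evaluates the same mini-batch for the line search
        if eval_mode:
            model.eval()
        with torch.no_grad():
            return criterion(model(inputs), targets)

    optimizer.step(eval_loss)                  # line search + smoothed learning-rate update

print('smoothed lr after one pass over the data:', optimizer.state['lr'])

By contrast, the SLOPE optimizer in statopt/slope.py expects a closure that takes no arguments and simply returns the current mini-batch loss (its step() calls closure() directly), so the same loop can be reused with a zero-argument closure.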