├── .gitignore ├── LICENSE ├── README.md ├── activation.png ├── cifar10_resnet.py ├── kan_pde.py ├── lbfgsb.py ├── lbfgsnew.py ├── loss.png └── resnet9.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LBFGS optimizer 2 | This repository provides an improved LBFGS optimizer, and an LBFGS-B optimizer with bound constraints, for PyTorch. Further details are given [in this paper](https://ieeexplore.ieee.org/document/8755567). Also see [this introduction](http://sagecal.sourceforge.net/pytorch/index.html). 3 | 4 | Examples of use: 5 | 6 | * Federated learning: see [these examples](https://github.com/SarodYatawatta/federated-pytorch-test). 7 | 8 | * Calibration and other inverse problems: see [radio interferometric calibration](https://github.com/SarodYatawatta/calibration-pytorch-test). 9 | 10 | * K-harmonic means clustering: see [LOFAR system health management](https://github.com/SarodYatawatta/LSHM). 11 | 12 | * Other problems: see [this example](https://ieeexplore.ieee.org/abstract/document/8588731). 13 | 14 | Files included are: 15 | 16 | ``` lbfgsnew.py ```: New LBFGS optimizer 17 | 18 | ``` lbfgsb.py ```: LBFGS-B optimizer (with bound constraints) 19 | 20 | ``` cifar10_resnet.py ```: CIFAR10 ResNet training example (see figures below) 21 | 22 | ``` kan_pde.py ```: Kolmogorov-Arnold network (KAN) PDE example using LBFGS-B 23 | 24 | *Figure: ResNet18/ResNet101 training loss and training time.* 25 | 26 | The figure above shows the training loss and training time [using Colab](https://colab.research.google.com/notebooks/intro.ipynb) with one GPU, for ResNet18 and ResNet101 models. Test accuracy after 20 epochs: 84% for LBFGS and 82% for Adam. 27 | 28 | Changing the activation from the commonly used ```ReLU``` to alternatives such as ```ELU``` gives faster convergence with LBFGS, as seen in the figure below. 29 | 30 | *Figure: ResNet Wide 50-2 training loss.* 31 | 32 | Here is a comparison of both training error and test accuracy for ResNet9 using LBFGS and Adam. 33 | 34 | *Figure: ResNet9 training loss and test accuracy.* 35 | 36 | Example usage in full batch mode: 37 | 38 | ``` 39 | from lbfgsnew import LBFGSNew 40 | optimizer = LBFGSNew(model.parameters(), history_size=7, max_iter=100, line_search_fn=True, batch_mode=False) 41 | ``` 42 | 43 | Example usage in minibatch mode: 44 | 45 | ``` 46 | from lbfgsnew import LBFGSNew 47 | optimizer = LBFGSNew(model.parameters(), history_size=7, max_iter=2, line_search_fn=True, batch_mode=True) 48 | ``` 49 | 50 | Note: for certain problems the gradient itself can be part of the cost, for example in TV (total variation) regularization. In such situations, pass the option ```cost_use_gradient=True``` to ```LBFGSNew()```. This increases the computational cost, so enable it only when needed.
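
Both modes require a closure that re-evaluates the model and returns the loss, since the optimizer calls it several times within a single ```step()``` during the line search. Below is a minimal sketch of a training step in minibatch mode; the model, data and loss function are placeholders, not part of this repository:

```
import torch
import torch.nn as nn
from lbfgsnew import LBFGSNew

# placeholder model and data -- replace with your own
model = nn.Linear(10, 1)
inputs = torch.randn(32, 10)
targets = torch.randn(32, 1)
criterion = nn.MSELoss()

optimizer = LBFGSNew(model.parameters(), history_size=7, max_iter=2,
                     line_search_fn=True, batch_mode=True)

def closure():
    # gradients may be disabled while the optimizer probes step sizes,
    # hence the guards below
    if torch.is_grad_enabled():
        optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    if loss.requires_grad:
        loss.backward()
    return loss

optimizer.step(closure)
```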
51 | -------------------------------------------------------------------------------- /activation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlesc-dirac/pytorch/283ea6b93785b87ebe409f7fe887401a7ccfb313/activation.png -------------------------------------------------------------------------------- /cifar10_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | 5 | import math 6 | import time 7 | 8 | # (try to) use a GPU for computation? 9 | use_cuda=True 10 | if use_cuda and torch.cuda.is_available(): 11 | mydevice=torch.device('cuda') 12 | else: 13 | mydevice=torch.device('cpu') 14 | 15 | 16 | # try replacing relu with elu 17 | torch.manual_seed(69) 18 | default_batch=128 # no. of batches per epoch 50000/default_batch 19 | batches_for_report=10# 20 | 21 | transform=transforms.Compose( 22 | [transforms.ToTensor(), 23 | transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]) 24 | 25 | 26 | trainset=torchvision.datasets.CIFAR10(root='./torchdata', train=True, 27 | download=True, transform=transform) 28 | 29 | trainloader=torch.utils.data.DataLoader(trainset, batch_size=default_batch, 30 | shuffle=True, num_workers=2) 31 | 32 | testset=torchvision.datasets.CIFAR10(root='./torchdata', train=False, 33 | download=True, transform=transform) 34 | 35 | testloader=torch.utils.data.DataLoader(testset, batch_size=default_batch, 36 | shuffle=False, num_workers=0) 37 | 38 | 39 | classes=('plane', 'car', 'bird', 'cat', 40 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 41 | 42 | 43 | 44 | import matplotlib.pyplot as plt 45 | import numpy as np 46 | 47 | from torch.autograd import Variable 48 | import torch.nn as nn 49 | import torch.nn.functional as F 50 | 51 | 52 | '''ResNet in PyTorch. 53 | Reference: 54 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 55 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 56 | 57 | From: https://github.com/kuangliu/pytorch-cifar 58 | ''' 59 | import torch 60 | import torch.nn as nn 61 | import torch.nn.functional as F 62 | 63 | 64 | class BasicBlock(nn.Module): 65 | expansion = 1 66 | 67 | def __init__(self, in_planes, planes, stride=1): 68 | super(BasicBlock, self).__init__() 69 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 70 | self.bn1 = nn.BatchNorm2d(planes) 71 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 72 | self.bn2 = nn.BatchNorm2d(planes) 73 | 74 | self.shortcut = nn.Sequential() 75 | if stride != 1 or in_planes != self.expansion*planes: 76 | self.shortcut = nn.Sequential( 77 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 78 | nn.BatchNorm2d(self.expansion*planes) 79 | ) 80 | 81 | def forward(self, x): 82 | out = F.elu(self.bn1(self.conv1(x))) 83 | out = self.bn2(self.conv2(out)) 84 | out += self.shortcut(x) 85 | out = F.elu(out) 86 | return out 87 | 88 | 89 | class Bottleneck(nn.Module): 90 | expansion = 4 91 | 92 | def __init__(self, in_planes, planes, stride=1): 93 | super(Bottleneck, self).__init__() 94 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 95 | self.bn1 = nn.BatchNorm2d(planes) 96 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 97 | self.bn2 = nn.BatchNorm2d(planes) 98 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 99 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 100 | 101 | self.shortcut = nn.Sequential() 102 | if stride != 1 or in_planes != self.expansion*planes: 103 | self.shortcut = nn.Sequential( 104 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 105 | nn.BatchNorm2d(self.expansion*planes) 106 | ) 107 | 108 | def forward(self, x): 109 | out = F.elu(self.bn1(self.conv1(x))) 110 | out = F.elu(self.bn2(self.conv2(out))) 111 | out = self.bn3(self.conv3(out)) 112 | out += self.shortcut(x) 113 | out = F.elu(out) 114 | return out 115 | 116 | 117 | class ResNet(nn.Module): 118 | def __init__(self, block, num_blocks, num_classes=10): 119 | super(ResNet, self).__init__() 120 | self.in_planes = 64 121 | 122 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 123 | self.bn1 = nn.BatchNorm2d(64) 124 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 125 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 126 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 127 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 128 | self.linear = nn.Linear(512*block.expansion, num_classes) 129 | 130 | def _make_layer(self, block, planes, num_blocks, stride): 131 | strides = [stride] + [1]*(num_blocks-1) 132 | layers = [] 133 | for stride in strides: 134 | layers.append(block(self.in_planes, planes, stride)) 135 | self.in_planes = planes * block.expansion 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | out = F.elu(self.bn1(self.conv1(x))) 140 | out = self.layer1(out) 141 | out = self.layer2(out) 142 | out = self.layer3(out) 143 | out = self.layer4(out) 144 | out = F.avg_pool2d(out, 4) 145 | out = out.view(out.size(0), -1) 146 | out = self.linear(out) 147 | return out 148 | 149 | def ResNet9(): 150 | return ResNet(BasicBlock, [1,1,1,1]) 151 | 152 | def ResNet18(): 153 | return ResNet(BasicBlock, [2,2,2,2]) 154 | 155 | def 
ResNet34(): 156 | return ResNet(BasicBlock, [3,4,6,3]) 157 | 158 | def ResNet50(): 159 | return ResNet(Bottleneck, [3,4,6,3]) 160 | 161 | def ResNet101(): 162 | return ResNet(Bottleneck, [3,4,23,3]) 163 | 164 | def ResNet152(): 165 | return ResNet(Bottleneck, [3,8,36,3]) 166 | 167 | 168 | # enable this to use wide ResNet 169 | wide_resnet=False 170 | if not wide_resnet: 171 | net=ResNet18().to(mydevice) 172 | else: 173 | # use wide residual net https://arxiv.org/abs/1605.07146 174 | net=torchvision.models.resnet.wide_resnet50_2().to(mydevice) 175 | 176 | 177 | ##################################################### 178 | def verification_error_check(net): 179 | correct=0 180 | total=0 181 | for data in testloader: 182 | images,labels=data 183 | outputs=net(Variable(images).to(mydevice)) 184 | _,predicted=torch.max(outputs.data,1) 185 | correct += (predicted==labels.to(mydevice)).sum() 186 | total += labels.size(0) 187 | 188 | return 100*correct//total 189 | ##################################################### 190 | 191 | lambda1=0.000001 192 | lambda2=0.001 193 | 194 | # loss function and optimizer 195 | import torch.optim as optim 196 | from lbfgsnew import LBFGSNew # custom optimizer 197 | criterion=nn.CrossEntropyLoss() 198 | #optimizer=optim.SGD(net.parameters(), lr=0.001, momentum=0.9) 199 | #optimizer=optim.Adam(net.parameters(), lr=0.001) 200 | optimizer = LBFGSNew(net.parameters(), history_size=7, max_iter=2, line_search_fn=True,batch_mode=True) 201 | 202 | 203 | load_model=False 204 | # update from a saved model 205 | if load_model: 206 | checkpoint=torch.load('./res18.model',map_location=mydevice) 207 | net.load_state_dict(checkpoint['model_state_dict']) 208 | net.train() # initialize for training (BN,dropout) 209 | 210 | start_time=time.time() 211 | use_lbfgs=True 212 | # train network 213 | for epoch in range(20): 214 | running_loss=0.0 215 | for i,data in enumerate(trainloader,0): 216 | # get the inputs 217 | inputs,labels=data 218 | # wrap them in variable 219 | inputs,labels=Variable(inputs).to(mydevice),Variable(labels).to(mydevice) 220 | 221 | if not use_lbfgs: 222 | # zero gradients 223 | optimizer.zero_grad() 224 | # forward+backward optimize 225 | outputs=net(inputs) 226 | loss=criterion(outputs,labels) 227 | loss.backward() 228 | optimizer.step() 229 | else: 230 | if not wide_resnet: 231 | layer1=torch.cat([x.view(-1) for x in net.layer1.parameters()]) 232 | layer2=torch.cat([x.view(-1) for x in net.layer2.parameters()]) 233 | layer3=torch.cat([x.view(-1) for x in net.layer3.parameters()]) 234 | layer4=torch.cat([x.view(-1) for x in net.layer4.parameters()]) 235 | 236 | def closure(): 237 | if torch.is_grad_enabled(): 238 | optimizer.zero_grad() 239 | outputs=net(inputs) 240 | if not wide_resnet: 241 | l1_penalty=lambda1*(torch.norm(layer1,1)+torch.norm(layer2,1)+torch.norm(layer3,1)+torch.norm(layer4,1)) 242 | l2_penalty=lambda2*(torch.norm(layer1,2)+torch.norm(layer2,2)+torch.norm(layer3,2)+torch.norm(layer4,2)) 243 | loss=criterion(outputs,labels)+l1_penalty+l2_penalty 244 | else: 245 | l1_penalty=0 246 | l2_penalty=0 247 | loss=criterion(outputs,labels) 248 | if loss.requires_grad: 249 | loss.backward() 250 | #print('loss %f l1 %f l2 %f'%(loss,l1_penalty,l2_penalty)) 251 | return loss 252 | optimizer.step(closure) 253 | # only for diagnostics 254 | outputs=net(inputs) 255 | loss=criterion(outputs,labels) 256 | running_loss +=loss.data.item() 257 | 258 | if math.isnan(loss.data.item()): 259 | print('loss became nan at %d'%i) 260 | break 261 | 262 | # print statistics 
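# (printed every 'batches_for_report' minibatches, together with the current test accuracy)
# Note on the closure above: LBFGSNew may call closure() several times within a
# single optimizer.step() while searching for a step length, sometimes with
# gradients disabled; this is why closure() guards optimizer.zero_grad() with
# torch.is_grad_enabled() and loss.backward() with loss.requires_grad. The extra
# forward pass after optimizer.step() recomputes the plain (unregularized)
# cross-entropy loss purely for reporting.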
263 | if i%(batches_for_report) == (batches_for_report-1): # after every 'batches_for_report' 264 | print('%f: [%d, %5d] loss: %.5f accuracy: %.3f'% 265 | (time.time()-start_time,epoch+1,i+1,running_loss/batches_for_report, 266 | verification_error_check(net))) 267 | running_loss=0.0 268 | 269 | print('Finished Training') 270 | 271 | 272 | # save model (and other extra items) 273 | torch.save({ 274 | 'model_state_dict':net.state_dict(), 275 | 'epoch':epoch, 276 | 'optimizer_state_dict':optimizer.state_dict(), 277 | 'running_loss':running_loss, 278 | },'./res.model') 279 | 280 | 281 | # whole dataset 282 | correct=0 283 | total=0 284 | for data in trainloader: 285 | images,labels=data 286 | outputs=net(Variable(images).to(mydevice)).cpu() 287 | _,predicted=torch.max(outputs.data,1) 288 | total += labels.size(0) 289 | correct += (predicted==labels).sum() 290 | 291 | print('Accuracy of the network on the %d train images: %d %%'% 292 | (total,100*correct//total)) 293 | 294 | correct=0 295 | total=0 296 | for data in testloader: 297 | images,labels=data 298 | outputs=net(Variable(images).to(mydevice)).cpu() 299 | _,predicted=torch.max(outputs.data,1) 300 | total += labels.size(0) 301 | correct += (predicted==labels).sum() 302 | 303 | print('Accuracy of the network on the %d test images: %d %%'% 304 | (total,100*correct//total)) 305 | 306 | 307 | class_correct=list(0. for i in range(10)) 308 | class_total=list(0. for i in range(10)) 309 | for data in testloader: 310 | images,labels=data 311 | outputs=net(Variable(images).to(mydevice)).cpu() 312 | _,predicted=torch.max(outputs.data,1) 313 | c=(predicted==labels).squeeze() 314 | for i in range(4): 315 | label=labels[i] 316 | class_correct[label] += c[i] 317 | class_total[label] += 1 318 | 319 | for i in range(10): 320 | print('Accuracy of %5s : %2d %%' % 321 | (classes[i],100*float(class_correct[i])/float(class_total[i]))) 322 | -------------------------------------------------------------------------------- /kan_pde.py: -------------------------------------------------------------------------------- 1 | # This is an exmple of training a KAN model, original at 2 | # https://kindxiaoming.github.io/pykan/Examples/Example_6_PDE.html 3 | # using the LBFGS-B optimizer 4 | 5 | from kan import KAN 6 | from lbfgsb import LBFGSB 7 | from lbfgsnew import LBFGSNew 8 | import torch 9 | import matplotlib.pyplot as plt 10 | from torch import autograd 11 | from tqdm import tqdm 12 | import numpy as np 13 | 14 | use_cuda=True 15 | if use_cuda and torch.cuda.is_available(): 16 | mydevice=torch.device('cuda') 17 | else: 18 | mydevice=torch.device('cpu') 19 | 20 | 21 | dim = 2 22 | np_i = 21 # number of interior points (along each dimension) 23 | np_b = 21 # number of boundary points (along each dimension) 24 | ranges = [-1, 1] 25 | 26 | model = KAN(width=[2,2,1], grid=5, k=3, grid_eps=1.0, device=mydevice) 27 | 28 | # get all parameters (all may not be trainable) 29 | n_params = sum([np.prod(p.size()) for p in model.parameters()]) 30 | # lower/upper bounds for parameters 31 | x_l=(torch.ones(n_params)*(-100.0)).to(mydevice) 32 | x_u=(torch.ones(n_params)*(100.0)).to(mydevice) 33 | 34 | def batch_jacobian(func, x, create_graph=False): 35 | # x in shape (Batch, Length) 36 | def _func_sum(x): 37 | return func(x).sum(dim=0) 38 | return autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2) 39 | 40 | # define solution 41 | sol_fun = lambda x: torch.sin(torch.pi*x[:,[0]])*torch.sin(torch.pi*x[:,[1]]) 42 | source_fun = lambda x: 
-2*torch.pi**2 * torch.sin(torch.pi*x[:,[0]])*torch.sin(torch.pi*x[:,[1]]) 43 | 44 | # interior 45 | sampling_mode = 'random' # 'random' or 'mesh' 46 | 47 | x_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i).to(mydevice) 48 | y_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i).to(mydevice) 49 | X, Y = torch.meshgrid(x_mesh, y_mesh, indexing="ij") 50 | if sampling_mode == 'mesh': 51 | #mesh 52 | x_i = torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0) 53 | else: 54 | #random 55 | x_i = torch.rand((np_i**2,2))*2-1 56 | x_i=x_i.to(mydevice) 57 | 58 | # boundary, 4 sides 59 | helper = lambda X, Y: torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0) 60 | xb1 = helper(X[0], Y[0]) 61 | xb2 = helper(X[-1], Y[0]) 62 | xb3 = helper(X[:,0], Y[:,0]) 63 | xb4 = helper(X[:,0], Y[:,-1]) 64 | x_b = torch.cat([xb1, xb2, xb3, xb4], dim=0) 65 | 66 | steps = 20 67 | alpha = 0.1 68 | log = 1 69 | 70 | #torch.autograd.set_detect_anomaly(True) 71 | def train(): 72 | # try running with batch_mode=True and batch_mode=False (both should work) 73 | optimizer = LBFGSB(model.parameters(), lower_bound=x_l, upper_bound=x_u, history_size=10, tolerance_grad=1e-32, tolerance_change=1e-32, batch_mode=True, cost_use_gradient=True) 74 | #optimizer = LBFGSNew(model.parameters(), history_size=10, tolerance_grad=1e-32, tolerance_change=1e-32, batch_mode=True, cost_use_gradient=True) 75 | 76 | pbar = tqdm(range(steps), desc='description') 77 | 78 | for _ in pbar: 79 | def closure(): 80 | global pde_loss, bc_loss 81 | optimizer.zero_grad() 82 | # interior loss 83 | sol = sol_fun(x_i) 84 | sol_D1_fun = lambda x: batch_jacobian(model, x, create_graph=True)[:,0,:] 85 | sol_D1 = sol_D1_fun(x_i) 86 | sol_D2 = batch_jacobian(sol_D1_fun, x_i, create_graph=True)[:,:,:] 87 | lap = torch.sum(torch.diagonal(sol_D2, dim1=1, dim2=2), dim=1, keepdim=True) 88 | source = source_fun(x_i) 89 | pde_loss = torch.mean((lap - source)**2) 90 | 91 | # boundary loss 92 | bc_true = sol_fun(x_b) 93 | bc_pred = model(x_b) 94 | bc_loss = torch.mean((bc_pred-bc_true)**2) 95 | 96 | loss = alpha * pde_loss + bc_loss 97 | loss.backward() 98 | return loss 99 | 100 | if _ % 5 == 0 and _ < 50: 101 | model.update_grid_from_samples(x_i) 102 | 103 | optimizer.step(closure) 104 | sol = sol_fun(x_i) 105 | loss = alpha * pde_loss + bc_loss 106 | l2 = torch.mean((model(x_i) - sol)**2) 107 | 108 | if _ % log == 0: 109 | pbar.set_description("pde loss: %.2e | bc loss: %.2e | l2: %.2e " % (pde_loss.cpu().detach().numpy(), bc_loss.cpu().detach().numpy(), l2.cpu().detach().numpy())) 110 | 111 | train() 112 | -------------------------------------------------------------------------------- /lbfgsb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import reduce 3 | from torch.optim.optimizer import Optimizer 4 | 5 | import math 6 | 7 | be_verbose=False 8 | 9 | class LBFGSB(Optimizer): 10 | """Implements L-BFGS-B algorithm. 11 | Primary reference: 12 | 1) MATLAB code https://github.com/bgranzow/L-BFGS-B by Brian Granzow 13 | Theory based on: 14 | 1) A Limited Memory Algorithm for Bound Constrained Optimization, Byrd et al. 1995 15 | 2) Numerical Optimization, Nocedal and Wright, 2006 16 | 17 | .. warning:: 18 | This optimizer doesn't support per-parameter options and parameter 19 | groups (there can be only one). 20 | 21 | .. note:: 22 | This is still WIP, the saving/restoring of state dict is not fully implemented. 
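    .. note::
        All parameters are flattened internally into a single vector, so lower_bound and
        upper_bound should be one-dimensional tensors whose length equals the total number
        of elements over all parameters (see, for example, how n_params is used to build
        the bounds in kan_pde.py).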
23 | 24 | Arguments: 25 | lower_bound (shape equal to parameter vector): parameters > lower_bound 26 | upper_bound (shape equal to parameter vector): parameters < upper_bound 27 | max_iter (int): maximal number of iterations per optimization step 28 | (default: 10) 29 | tolerance_grad (float): termination tolerance on first order optimality 30 | (default: 1e-5). 31 | tolerance_change (float): termination tolerance on function 32 | value/parameter changes (default: 1e-20). 33 | history_size (int): update history size (default: 7). 34 | batch_mode: True for stochastic version (default: False) 35 | cost_use_gradient: set this to True when the cost function also needs the gradient, for example in TV (total variation) regularization. (default: False) 36 | 37 | Example: 38 | ------ 39 | >>> x=torch.rand(2,requires_grad=True,dtype=torch.float64,device=mydevice) 40 | >>> x_l=torch.ones(2,device=mydevice)*(-1.0) 41 | >>> x_u=torch.ones(2,device=mydevice) 42 | >>> optimizer=LBFGSB([x],lower_bound=x_l, upper_bound=x_u, history_size=7, max_iter=4, batch_mode=True) 43 | >>> def cost_function(): 44 | >>> f=torch.pow(1.0-x[0],2.0)+100.0*torch.pow(x[1]-x[0]*x[0],2.0) 45 | >>> return f 46 | >>> for ci in range(10): 47 | >>> def closure(): 48 | >>> if torch.is_grad_enabled(): 49 | >>> optimizer.zero_grad() 50 | >>> loss=cost_function() 51 | >>> if loss.requires_grad: 52 | >>> loss.backward() 53 | >>> return loss 54 | >>> 55 | >>> optimizer.step(closure) 56 | ------ 57 | """ 58 | 59 | def __init__(self, params, lower_bound, upper_bound, max_iter=10, 60 | tolerance_grad=1e-5, tolerance_change=1e-20, history_size=7, 61 | batch_mode=False, cost_use_gradient=False): 62 | defaults = dict(max_iter=max_iter, 63 | tolerance_grad=tolerance_grad, tolerance_change=tolerance_change, 64 | history_size=history_size, 65 | batch_mode=batch_mode, 66 | cost_use_gradient=cost_use_gradient) 67 | super(LBFGSB, self).__init__(params, defaults) 68 | 69 | if len(self.param_groups) != 1: 70 | raise ValueError("LBFGSB doesn't support per-parameter options " 71 | "(parameter groups)") 72 | 73 | self._params = self.param_groups[0]['params'] 74 | self._numel_cache = None 75 | self._device = self._params[0].device 76 | self._dtype= self._params[0].dtype 77 | self._l=lower_bound.clone(memory_format=torch.contiguous_format).to(self._device) 78 | self._u=upper_bound.clone(memory_format=torch.contiguous_format).to(self._device) 79 | self._m=history_size 80 | self._n=self._numel() 81 | # local storage as matrices (instead of curvature pairs) 82 | self._W=torch.zeros(self._n,self._m*2,dtype=self._dtype).to(self._device) 83 | self._Y=torch.zeros(self._n,self._m,dtype=self._dtype).to(self._device) 84 | self._S=torch.zeros(self._n,self._m,dtype=self._dtype).to(self._device) 85 | self._M=torch.zeros(self._m*2,self._m*2,dtype=self._dtype).to(self._device) 86 | 87 | self._fit_to_constraints() 88 | 89 | self._eps=tolerance_change 90 | self._realmax=1e20 91 | self._theta=1 92 | 93 | # batch mode 94 | self.running_avg=None 95 | self.running_avg_sq=None 96 | self.alphabar=1.0 97 | 98 | def _numel(self): 99 | if self._numel_cache is None: 100 | self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0) 101 | return self._numel_cache 102 | 103 | def _gather_flat_grad(self): 104 | views = [] 105 | for p in self._params: 106 | if p.grad is None: 107 | view = p.data.new(p.data.numel()).zero_() 108 | elif p.grad.data.is_sparse: 109 | view = p.grad.data.to_dense().contiguous().view(-1) 110 | else: 111 | view = 
p.grad.data.contiguous().view(-1) 112 | views.append(view) 113 | return torch.cat(views, 0) 114 | 115 | def _add_grad(self, step_size, update): 116 | offset = 0 117 | for p in self._params: 118 | numel = p.numel() 119 | # view as to avoid deprecated pointwise semantics 120 | p.data.add_(update[offset:offset + numel].view_as(p.data), alpha=step_size) 121 | offset += numel 122 | assert offset == self._numel() 123 | 124 | #copy the parameter values out, create a list of vectors 125 | def _copy_params_out(self): 126 | return [p.detach().flatten().clone(memory_format=torch.contiguous_format) for p in self._params] 127 | 128 | #copy the parameter values back, dividing the list appropriately 129 | def _copy_params_in(self,new_params): 130 | with torch.no_grad(): 131 | for p, pdata in zip(self._params, new_params): 132 | p.copy_(pdata.view_as(p)) 133 | 134 | # restrict parameters to constraints 135 | def _fit_to_constraints(self): 136 | params=[] 137 | for p in self._params: 138 | # make a vector 139 | p = p.detach().flatten() 140 | params.append(p) 141 | x=torch.cat(params,0) 142 | for i in range(x.numel()): 143 | if (x[i]self._u[i]): 146 | x[i]=self._u[i] 147 | offset = 0 148 | with torch.no_grad(): 149 | for p in self._params: 150 | numel = p.numel() 151 | p.copy_(x[offset:offset + numel].view_as(p)) 152 | offset += numel 153 | assert offset == self._numel() 154 | 155 | def _get_optimality(self,g): 156 | # get the inf-norm of the projected gradient 157 | # pp. 17, (6.1) 158 | # x: nx1 parameters 159 | # g: nx1 gradient 160 | # l: nx1 lower bound 161 | # u: nx1 upper bound 162 | x=torch.cat(self._copy_params_out(),0) 163 | projected_g=x-g 164 | for i in range(x.numel()): 165 | if projected_g[i]self._u[i]: 168 | projected_g[i]=self._u[i] 169 | projected_g=projected_g-x 170 | return max(abs(projected_g)) 171 | 172 | def _get_breakpoints(self,x,g): 173 | # compute breakpoints for Cauchy point 174 | # pp 5-6, (4.1), (4.2), pp. 
8, CP initialize \mathcal{F} 175 | # x: nx1 parameters 176 | # g: nx1 gradient 177 | # l: nx1 lower bound 178 | # u: nx1 upper bound 179 | # out: 180 | # t: nx1 breakpoint vector 181 | # d: nx1 search direction vector 182 | # F: nx1 indices that sort t from low to high 183 | t=torch.zeros(self._n,1,dtype=self._dtype,device=self._device) 184 | d=-g 185 | for i in range(self._n): 186 | if (g[i]<0.0): 187 | t[i]=(x[i]-self._u[i])/g[i] 188 | elif (g[i]>0.0): 189 | t[i]=(x[i]-self._l[i])/g[i] 190 | else: 191 | t[i]=self._realmax 192 | 193 | if (t[i]0, scaling 208 | # W: nx2m 209 | # M: 2mx2m 210 | # out: 211 | # xc: nx1 the generalized Cauchy point 212 | # c: 2mx1 initialization vector for subspace minimization 213 | 214 | x=torch.cat(self._copy_params_out(),0) 215 | tt,d,F=self._get_breakpoints(x,g) 216 | xc=x.clone() 217 | c=torch.zeros(2*self._m,1,dtype=self._dtype,device=self._device) 218 | p=torch.mm(self._W.transpose(0,1),d) 219 | fp=-torch.mm(d.transpose(0,1),d) 220 | fpp=-self._theta*fp-torch.mm(p.transpose(0,1),torch.mm(self._M,p)) 221 | fp=fp.squeeze() 222 | fpp=fpp.squeeze() 223 | fpp0=-self._theta*fp 224 | if (fpp != 0.0): 225 | dt_min=-fp/fpp 226 | else: 227 | dt_min=-fp/self._eps 228 | t_old=0 229 | # find lowest index i where F[i] is positive (minimum t) 230 | for j in range(self._n): 231 | i=j 232 | if F[i]>=0.0: 233 | break 234 | b=F[i] 235 | t=tt[b] 236 | dt=t-t_old 237 | 238 | while (idt): 239 | if d[b]>0.0: 240 | xc[b]=self._u[b] 241 | elif d[b]<0.0: 242 | xc[b]=self._l[b] 243 | 244 | zb=xc[b]-x[b] 245 | c=c+dt*p 246 | gb=g[b] 247 | Wbt=self._W[b,:] 248 | Wbt=Wbt.unsqueeze(-1).transpose(0,1) 249 | fp=fp+dt*fpp+gb*gb+self._theta*gb*zb-gb*torch.mm(Wbt,torch.mm(self._M,c)) 250 | fpp=fpp-self._theta*gb*gb-2.0*gb*torch.mm(Wbt,torch.mm(self._M,p))-gb*gb*torch.mm(Wbt,torch.mm(self._M,Wbt.transpose(0,1))) 251 | fp=fp.squeeze() 252 | fpp=fpp.squeeze() 253 | fpp=max(self._eps*fpp0,fpp) 254 | p=p+gb*Wbt.transpose(0,1) 255 | d[b]=0.0 256 | if (fpp != 0.0): 257 | dt_min=-fp/fpp 258 | else: 259 | dt_min=-fp/self._eps 260 | t_old=t 261 | i=i+1 262 | if i0, scaling 287 | # W: nx2m 288 | # M: 2mx2m 289 | # out: 290 | # xbar: nx1 minimizer 291 | # line_search_flag: bool 292 | 293 | line_search_flag=True 294 | free_vars_index=list() 295 | for i in range(self._n): 296 | if (xc[i] != self._u[i]) and (xc[i] != self._l[i]): 297 | free_vars_index.append(i) 298 | 299 | n_free_vars=len(free_vars_index) 300 | if n_free_vars==0: 301 | xbar=xc.clone() 302 | line_search_flag=False 303 | return xbar,line_search_flag 304 | 305 | WtZ=torch.zeros((2*self._m,n_free_vars),dtype=self._dtype,device=self._device) 306 | # each column of WtZ (2*m values) = row of i-th free variable in W (2*m values) 307 | for i in range(n_free_vars): 308 | WtZ[:,i]=self._W[free_vars_index[i],:] 309 | 310 | x=torch.cat(self._copy_params_out(),0) 311 | rr=g+self._theta*(xc-x) - torch.mm(self._W,torch.mm(self._M,c)).squeeze() 312 | r=torch.zeros(n_free_vars,1,dtype=self._dtype,device=self._device) 313 | for i in range(n_free_vars): 314 | r[i]=rr[free_vars_index[i]] 315 | 316 | invtheta=1.0/self._theta 317 | v=torch.mm(self._M,torch.mm(WtZ,r)) 318 | N=invtheta*torch.mm(WtZ,WtZ.transpose(0,1)) 319 | N=torch.eye(2*self._m).to(self._device)-torch.mm(self._M,N) 320 | v,_,_,_=torch.linalg.lstsq(N,v,rcond=None) 321 | du=-invtheta*r-invtheta*invtheta*torch.mm(WtZ.transpose(0,1),v) 322 | 323 | alpha_star=self._find_alpha(xc,du,free_vars_index) 324 | d_star=alpha_star*du 325 | xbar=xc.clone() 326 | for i in range(n_free_vars): 327 | 
idx=free_vars_index[i] 328 | xbar[idx]=xbar[idx]+d_star[i] 329 | 330 | return xbar,line_search_flag 331 | 332 | def _find_alpha(self, xc, du, free_vars_index): 333 | # pp. 11, (5.8) 334 | # l: nx1 lower bound 335 | # u: nx1 upper bound 336 | # xc: nx1 generalized Cauchy point 337 | # du: n_free_varsx1 338 | # free_vars_index: n_free_varsx1 indices of free variables 339 | # out: 340 | # alpha_star: positive scaling parameter 341 | 342 | n_free_vars=len(free_vars_index) 343 | alpha_star=1.0 344 | for i in range(n_free_vars): 345 | idx=free_vars_index[i] 346 | if du[i]>0.0: 347 | alpha_star=min(alpha_star,(self._u[idx]-xc[idx])/du[i]) 348 | elif du[i]<0.0: 349 | alpha_star=min(alpha_star,(self._l[idx]-xc[idx])/du[i]) 350 | 351 | return alpha_star 352 | 353 | 354 | def _linesearch_backtrack(self, closure, f_old, gk, pk, alphabar): 355 | """Line search (backtracking) 356 | 357 | Arguments: 358 | closure (callable): A closure that reevaluates the model 359 | and returns the loss. 360 | f_old: original cost 361 | gk: gradient vector 362 | pk: step direction vector 363 | alphabar: max step size 364 | """ 365 | c1=1e-4 366 | citer=35 367 | alphak=alphabar 368 | 369 | x0list=self._copy_params_out() 370 | xk=[x.clone() for x in x0list] 371 | self._add_grad(alphak,pk) 372 | f_new=float(closure()) 373 | s=gk 374 | prodterm=c1*s.dot(pk) 375 | ci=0 376 | while (cif_old+alphak*prodterm)): 377 | alphak=0.5*alphak 378 | self._copy_params_in(xk) 379 | self._add_grad(alphak,pk) 380 | f_new=float(closure()) 381 | ci=ci+1 382 | 383 | self._copy_params_in(xk) 384 | return alphak 385 | 386 | 387 | def _strong_wolfe(self, closure, f0, g0, p): 388 | # line search to satisfy strong Wolfe conditions 389 | # Alg 3.5, pp. 60, Numerical optimization Nocedal & Wright 390 | # cost: cost function R^n -> 1 391 | # gradient: gradient function R^n -> R^n 392 | # x0: nx1 initial parameters 393 | # f0: 1 intial cost 394 | # g0: nx1 initial gradient 395 | # p: nx1 intial search direction 396 | # out: 397 | # alpha: step length 398 | 399 | c1=1e-4 400 | c2=0.9 401 | alpha_max=2.5 402 | alpha_im1=0 403 | alpha_i=1 404 | f_im1=f0 405 | dphi0=torch.dot(g0,p) 406 | 407 | # make a copy of original params 408 | x0list=self._copy_params_out() 409 | x0=[x.clone() for x in x0list] 410 | 411 | i=0 412 | max_iters=20 413 | while 1: 414 | # x=x0+alpha_i*p 415 | self._copy_params_in(x0) 416 | self._add_grad(alpha_i,p) 417 | f_i=float(closure()) 418 | if (f_i>f0+c1*dphi0) or ((i>1) and (f_i>f_im1)): 419 | alpha=self._alpha_zoom(closure,x0,f0,g0,p,alpha_im1,alpha_i) 420 | break 421 | g_i=self._gather_flat_grad() 422 | dphi=torch.dot(g_i,p) 423 | if (abs(dphi)<=-c2*dphi0): 424 | alpha=alpha_i 425 | break 426 | if (dphi>=0.0): 427 | alpha=self._alpha_zoom(closure,x0,f0,g0,p,alpha_i,alpha_im1) 428 | break 429 | alpha_im1=alpha_i 430 | f_im1=f_i 431 | alpha_i=alpha_i+0.8*(alpha_max-alpha_i) 432 | if (i>max_iters): 433 | alpha=alpha_i 434 | break 435 | i=i+1 436 | 437 | # restore original params 438 | self._copy_params_in(x0) 439 | return alpha 440 | 441 | 442 | def _alpha_zoom(self, closure, x0, f0, g0, p, alpha_lo, alpha_hi): 443 | # Alg 3.6, pp. 
61, Numerical optimization Nocedal & Wright 444 | # cost: cost function R^n -> 1 445 | # gradient: gradient function R^n -> R^n 446 | # x0: list() initial parameters 447 | # f0: 1 intial cost 448 | # g0: nx1 initial gradient 449 | # p: nx1 intial search direction 450 | # alpha_lo: low limit for alpha 451 | # alpha_hi: high limit for alpha 452 | # out: 453 | # alpha: zoomed step length 454 | c1=1e-4 455 | c2=0.9 456 | i=0 457 | max_iters=20 458 | dphi0=torch.dot(g0,p) 459 | while 1: 460 | alpha_i=0.5*(alpha_lo+alpha_hi) 461 | alpha=alpha_i 462 | # x=x0+alpha_i*p 463 | self._copy_params_in(x0) 464 | self._add_grad(alpha_i,p) 465 | f_i=float(closure()) 466 | g_i=self._gather_flat_grad() 467 | # x_lo=x0+alpha_lo*p 468 | self._copy_params_in(x0) 469 | self._add_grad(alpha_lo,p) 470 | f_lo=float(closure()) 471 | if ((f_i>f0+c1*alpha_i*dphi0) or (f_i>=f_lo)): 472 | alpha_hi=alpha_i 473 | else: 474 | dphi=torch.dot(g_i,p) 475 | if ((abs(dphi)<=-c2*dphi0)): 476 | alpha=alpha_i 477 | break 478 | if (dphi*(alpha_hi-alpha_lo)>=0.0): 479 | alpha_hi=alpha_lo 480 | alpha_lo=alpha_i 481 | i=i+1 482 | if (i>max_iters): 483 | alpha=alpha_i 484 | break 485 | 486 | return alpha 487 | 488 | 489 | 490 | 491 | def step(self, closure): 492 | """Performs a single optimization step. 493 | 494 | Arguments: 495 | closure (callable): A closure that reevaluates the model 496 | and returns the loss. 497 | """ 498 | assert len(self.param_groups) == 1 499 | 500 | group = self.param_groups[0] 501 | max_iter = group['max_iter'] 502 | tolerance_grad = group['tolerance_grad'] 503 | tolerance_change = group['tolerance_change'] 504 | history_size = group['history_size'] 505 | 506 | batch_mode = group['batch_mode'] 507 | cost_use_gradient = group['cost_use_gradient'] 508 | 509 | 510 | # NOTE: LBFGS has only global state, but we register it as state for 511 | # the first param, because this helps with casting in load_state_dict 512 | state = self.state[self._params[0]] 513 | state.setdefault('func_evals', 0) 514 | state.setdefault('n_iter', 0) 515 | 516 | 517 | # evaluate initial f(x) and df/dx 518 | orig_loss = closure() 519 | f= float(orig_loss) 520 | current_evals = 1 521 | state['func_evals'] += 1 522 | 523 | g=self._gather_flat_grad() 524 | abs_grad_sum = g.abs().sum() 525 | 526 | if torch.isnan(abs_grad_sum) or abs_grad_sum <= tolerance_grad: 527 | return orig_loss 528 | 529 | n_iter=0 530 | 531 | if batch_mode and state['n_iter']==0: 532 | self.running_avg=torch.zeros_like(g.data) 533 | self.running_avg_sq=torch.zeros_like(g.data) 534 | 535 | while (self._get_optimality(g)>tolerance_change) and n_iter1) 565 | if batch_changed: 566 | tmp_grad_1=g_old.clone(memory_format=torch.contiguous_format) 567 | tmp_grad_1.add_(self.running_avg,alpha=-1.0) # grad-oldmean 568 | self.running_avg.add_(tmp_grad_1,alpha=1.0/state['n_iter']) 569 | tmp_grad_2=g_old.clone(memory_format=torch.contiguous_format) 570 | tmp_grad_2.add_(self.running_avg,alpha=-1.0) # grad-newmean 571 | self.running_avg_sq.addcmul_(tmp_grad_2,tmp_grad_1,value=1) # # +(grad-newmean)(grad-oldmean) 572 | self.alphabar=1.0/(1.0+self.running_avg_sq.sum()/((state['n_iter']-1)*g_old.norm().item())) 573 | 574 | 575 | if (curv f_old + alphak*prodterm)): 155 | alphak=0.5*alphak 156 | self._copy_params_in(xk) 157 | self._add_grad(alphak, pk) 158 | f_new=float(closure()) 159 | if be_verbose: 160 | print('LN %d alpha=%f fnew=%f fold=%f'%(ci,alphak,f_new,f_old)) 161 | ci=ci+1 162 | 163 | # if the cost is not sufficiently decreased, also try -ve steps 164 | if (f_old-f_new < 
torch.abs(prodterm)): 165 | alphak1=-alphabar 166 | self._copy_params_in(xk) 167 | self._add_grad(alphak1, pk) 168 | f_new1=float(closure()) 169 | if be_verbose: 170 | print('NLN fnew=%f'%f_new1) 171 | while (ci f_old + alphak1*prodterm)): 172 | alphak1=0.5*alphak1 173 | self._copy_params_in(xk) 174 | self._add_grad(alphak1, pk) 175 | f_new1=float(closure()) 176 | if be_verbose: 177 | print('NLN %d alpha=%f fnew=%f fold=%f'%(ci,alphak1,f_new1,f_old)) 178 | ci=ci+1 179 | 180 | if f_new1phi_0+alphai*gphi_0) or (ci>1 and phi_alphai>=phi_alphai1) : 261 | # ai=alphai1, bi=alphai bracket 262 | if be_verbose: 263 | print("bracket "+str(alphai1)+","+str(alphai)) 264 | alphak=self._linesearch_zoom(closure,xk,pk,alphai1,alphai,phi_0,gphi_0,sigma,rho,t1,t2,t3,step) 265 | if be_verbose: 266 | print("Linesearch: condition 1 met") 267 | break 268 | 269 | # evaluate grad(phi(alpha(i))) */ 270 | # note that self._params already is xk+alphai. pk, so only add the missing term 271 | # xp <- xk+(alphai+step). pk 272 | self._add_grad(step, pk) #FF param = param - t * grad 273 | p01=float(closure()) 274 | # xp <- xk+(alphai-step). pk 275 | self._add_grad(-2.0*step, pk) #FF param = param - t * grad 276 | p02=float(closure()) 277 | gphi_i=(p01-p02)/(2.0*step); 278 | 279 | if (abs(gphi_i)<=-sigma*gphi_0): 280 | alphak=alphai 281 | if be_verbose: 282 | print("Linesearch: condition 2 met") 283 | break 284 | 285 | if gphi_i>=0.0 : 286 | # ai=alphai, bi=alphai1 bracket 287 | if be_verbose: 288 | print("bracket "+str(alphai)+","+str(alphai1)) 289 | alphak=self._linesearch_zoom(closure,xk,pk,alphai,alphai1,phi_0,gphi_0,sigma,rho,t1,t2,t3,step) 290 | if be_verbose: 291 | print("Linesearch: condition 3 met") 292 | break 293 | # else preserve old values 294 | if (mu<=2.0*alphai-alphai1): 295 | alphai1=alphai 296 | alphai=mu 297 | else: 298 | # choose by interpolation in [2*alphai-alphai1,min(mu,alphai+t1*(alphai-alphai1)] 299 | p01=2.0*alphai-alphai1; 300 | p02=min(mu,alphai+t1*(alphai-alphai1)) 301 | alphai=self._cubic_interpolate(closure,xk,pk,p01,p02,step) 302 | 303 | 304 | phi_alphai1=phi_alphai; 305 | # update function evals 306 | closure_evals +=3 307 | ci=ci+1 308 | 309 | 310 | 311 | 312 | # recover original params 313 | self._copy_params_in(xk) 314 | # update state 315 | state['func_evals'] += closure_evals 316 | return alphak 317 | 318 | 319 | def _cubic_interpolate(self,closure,xk,pk,a,b,step): 320 | """ Cubic interpolation within interval [a,b] or [b,a] (a>b is possible) 321 | 322 | Arguments: 323 | closure (callable): A closure that reevaluates the model 324 | and returns the loss. 325 | xk: copy of parameter values 326 | pk: gradient vector 327 | a/b: interval for interpolation 328 | step: step size for differencing 329 | """ 330 | 331 | 332 | self._copy_params_in(xk) 333 | 334 | # state parameter 335 | state = self.state[self._params[0]] 336 | # count function evals 337 | closure_evals=0 338 | 339 | # xp <- xk+a. pk 340 | self._add_grad(a, pk) #FF param = param + t * grad 341 | f0=float(closure()) 342 | # xp <- xk+(a+step). pk 343 | self._add_grad(step, pk) #FF param = param + t * grad 344 | p01=float(closure()) 345 | # xp <- xk+(a-step). pk 346 | self._add_grad(-2.0*step, pk) #FF param = param - t * grad 347 | p02=float(closure()) 348 | f0d=(p01-p02)/(2.0*step) 349 | 350 | # xp <- xk+b. pk 351 | self._add_grad(-a+step+b, pk) #FF param = param + t * grad 352 | f1=float(closure()) 353 | # xp <- xk+(b+step). 
pk 354 | self._add_grad(step, pk) #FF param = param + t * grad 355 | p01=float(closure()) 356 | # xp <- xk+(b-step). pk 357 | self._add_grad(-2.0*step, pk) #FF param = param - t * grad 358 | p02=float(closure()) 359 | f1d=(p01-p02)/(2.0*step) 360 | 361 | closure_evals=6 362 | 363 | aa=3.0*(f0-f1)/(b-a)+f1d-f0d 364 | p01=aa*aa-f0d*f1d 365 | if (p01>0.0): 366 | cc=math.sqrt(p01) 367 | #print('f0='+str(f0d)+' f1='+str(f1d)+' cc='+str(cc)) 368 | if (f1d-f0d+2.0*cc)==0.0: 369 | return (a+b)*0.5 370 | z0=b-(f1d+cc-aa)*(b-a)/(f1d-f0d+2.0*cc) 371 | aa=max(a,b) 372 | cc=min(a,b) 373 | if z0>aa or z0phi_0+rho*alphaj*gphi_0) or phi_j>=phi_aj : 456 | bj=alphaj # aj is unchanged 457 | else: 458 | # evaluate grad(alphaj) 459 | # xp <- xk+(alphaj+step). pk 460 | self._add_grad(-aj+alphaj+step, pk) #FF param = param + t * grad 461 | p01=float(closure()) 462 | # xp <- xk+(alphaj-step). pk 463 | self._add_grad(-2.0*step, pk) #FF param = param + t * grad 464 | p02=float(closure()) 465 | gphi_j=(p01-p02)/(2.0*step) 466 | 467 | 468 | closure_evals +=2 469 | 470 | # termination due to roundoff/other errors pp. 38, Fletcher 471 | if (aj-alphaj)*gphi_j <= step: 472 | alphak=alphaj 473 | found_step=True 474 | break 475 | 476 | if abs(gphi_j)<=-sigma*gphi_0 : 477 | alphak=alphaj 478 | found_step=True 479 | break 480 | 481 | if gphi_j*(bj-aj)>=0.0: 482 | bj=aj 483 | # else bj is unchanged 484 | aj=alphaj 485 | 486 | 487 | ci=ci+1 488 | 489 | if not found_step: 490 | alphak=alphaj 491 | 492 | # update state 493 | state['func_evals'] += closure_evals 494 | 495 | return alphak 496 | 497 | 498 | def step(self, closure): 499 | """Performs a single optimization step. 500 | 501 | Arguments: 502 | closure (callable): A closure that reevaluates the model 503 | and returns the loss. 504 | """ 505 | assert len(self.param_groups) == 1 506 | 507 | group = self.param_groups[0] 508 | lr = group['lr'] 509 | max_iter = group['max_iter'] 510 | max_eval = group['max_eval'] 511 | tolerance_grad = group['tolerance_grad'] 512 | tolerance_change = group['tolerance_change'] 513 | line_search_fn = group['line_search_fn'] 514 | history_size = group['history_size'] 515 | 516 | batch_mode = group['batch_mode'] 517 | cost_use_gradient = group['cost_use_gradient'] 518 | 519 | 520 | # NOTE: LBFGS has only global state, but we register it as state for 521 | # the first param, because this helps with casting in load_state_dict 522 | state = self.state[self._params[0]] 523 | state.setdefault('func_evals', 0) 524 | state.setdefault('n_iter', 0) 525 | 526 | 527 | # evaluate initial f(x) and df/dx 528 | orig_loss = closure() 529 | loss = float(orig_loss) 530 | current_evals = 1 531 | state['func_evals'] += 1 532 | 533 | flat_grad = self._gather_flat_grad() 534 | abs_grad_sum = flat_grad.abs().sum() 535 | 536 | if torch.isnan(abs_grad_sum) or abs_grad_sum <= tolerance_grad: 537 | return orig_loss 538 | 539 | # tensors cached in state (for tracing) 540 | d = state.get('d') 541 | t = state.get('t') 542 | old_dirs = state.get('old_dirs') 543 | old_stps = state.get('old_stps') 544 | H_diag = state.get('H_diag') 545 | prev_flat_grad = state.get('prev_flat_grad') 546 | prev_loss = state.get('prev_loss') 547 | 548 | n_iter = 0 549 | 550 | if batch_mode: 551 | alphabar=lr 552 | lm0=1e-6 553 | 554 | # optimize for a max of max_iter iterations 555 | grad_nrm=flat_grad.norm().item() 556 | while n_iter < max_iter and not math.isnan(grad_nrm): 557 | # keep track of nb of iterations 558 | n_iter += 1 559 | state['n_iter'] += 1 560 | 561 | 
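            # Overview of one inner iteration: (i) update the limited-memory (s, y) pairs and
            # build the search direction d with the standard two-loop recursion, scaled by
            # H_diag = y^T s / y^T y; (ii) in batch_mode, damp y by adding lm0*s and keep an
            # online mean/variance of the gradient across minibatches, which determines
            # alphabar, the largest step length offered to the backtracking line search;
            # (iii) move the parameters by t*d, where t comes from the cubic line search
            # (full-batch) or the backtracking line search (minibatch). The memory update is
            # skipped right after the minibatch changes, since (s, y) would then mix gradients
            # from different batches.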
############################################################ 562 | # compute gradient descent direction 563 | ############################################################ 564 | if state['n_iter'] == 1: 565 | d = flat_grad.neg() 566 | old_dirs = [] 567 | old_stps = [] 568 | H_diag = 1 569 | if batch_mode: 570 | running_avg=torch.zeros_like(flat_grad.data) 571 | running_avg_sq=torch.zeros_like(flat_grad.data) 572 | else: 573 | if batch_mode: 574 | running_avg=state.get('running_avg') 575 | running_avg_sq=state.get('running_avg_sq') 576 | if running_avg is None: 577 | running_avg=torch.zeros_like(flat_grad.data) 578 | running_avg_sq=torch.zeros_like(flat_grad.data) 579 | 580 | # do lbfgs update (update memory) 581 | # what happens if current and prev grad are equal, ||y||->0 ?? 582 | y = flat_grad.sub(prev_flat_grad) 583 | 584 | s = d.mul(t) 585 | 586 | if batch_mode: # y = y+ lm0 * s, to have a trust region 587 | y.add_(s,alpha=lm0) 588 | 589 | ys = y.dot(s) # y^T*s 590 | sn = s.norm().item() # ||s|| 591 | # FIXME batch_changed does not work for full batch mode (data might be the same) 592 | batch_changed= batch_mode and (n_iter==1 and state['n_iter']>1) 593 | if batch_changed: # batch has changed 594 | # online estimate of mean,variance of gradient (inter-batch, not intra-batch) 595 | # newmean <- oldmean + (grad - oldmean)/niter 596 | # moment <- oldmoment + (grad-oldmean)(grad-newmean) 597 | # variance = moment/(niter-1) 598 | 599 | g_old=flat_grad.clone(memory_format=torch.contiguous_format) 600 | g_old.add_(running_avg,alpha=-1.0) # grad-oldmean 601 | running_avg.add_(g_old,alpha=1.0/state['n_iter']) # newmean 602 | g_new=flat_grad.clone(memory_format=torch.contiguous_format) 603 | g_new.add_(running_avg,alpha=-1.0) # grad-newmean 604 | running_avg_sq.addcmul_(g_new,g_old,value=1) # +(grad-newmean)(grad-oldmean) 605 | alphabar=1/(1+running_avg_sq.sum()/((state['n_iter']-1)*(grad_nrm))) 606 | if be_verbose: 607 | print('iter %d |mean| %f |var| %f ||grad|| %f step %f y^Ts %f alphabar=%f'%(state['n_iter'],running_avg.sum(),running_avg_sq.sum()/(state['n_iter']-1),grad_nrm,t,ys,alphabar)) 608 | 609 | 610 | if ys > 1e-10*sn*sn and not batch_changed : 611 | # updating memory (only when we have y within a single batch) 612 | if len(old_dirs) == history_size: 613 | # shift history by one (limited-memory) 614 | old_dirs.pop(0) 615 | old_stps.pop(0) 616 | 617 | # store new direction/step 618 | old_dirs.append(y) 619 | old_stps.append(s) 620 | 621 | # update scale of initial Hessian approximation 622 | H_diag = ys / y.dot(y) # (y*y) 623 | 624 | if math.isnan(H_diag): 625 | print('Warning H_diag nan') 626 | 627 | # compute the approximate (L-BFGS) inverse Hessian 628 | # multiplied by the gradient 629 | num_old = len(old_dirs) 630 | 631 | if 'ro' not in state: 632 | state['ro'] = [None] * history_size 633 | state['al'] = [None] * history_size 634 | ro = state['ro'] 635 | al = state['al'] 636 | 637 | for i in range(num_old): 638 | ro[i] = 1. 
/ old_dirs[i].dot(old_stps[i]) 639 | 640 | # iteration in L-BFGS loop collapsed to use just one buffer 641 | q = flat_grad.neg() 642 | for i in range(num_old - 1, -1, -1): 643 | al[i] = old_stps[i].dot(q) * ro[i] 644 | q.add_(old_dirs[i],alpha=-al[i]) 645 | 646 | # multiply by initial Hessian 647 | # r/d is the final direction 648 | d = r = torch.mul(q, H_diag) 649 | for i in range(num_old): 650 | be_i = old_dirs[i].dot(r) * ro[i] 651 | r.add_(old_stps[i],alpha=al[i] - be_i) 652 | 653 | if prev_flat_grad is None: 654 | prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format) 655 | 656 | else: 657 | prev_flat_grad.copy_(flat_grad) 658 | 659 | prev_loss = loss 660 | 661 | ############################################################ 662 | # compute step length 663 | ############################################################ 664 | # reset initial guess for step size 665 | if state['n_iter'] == 1: 666 | t = min(1., 1. / abs_grad_sum) * lr 667 | else: 668 | t = lr 669 | 670 | # directional derivative 671 | gtd = flat_grad.dot(d) # g * d 672 | 673 | if math.isnan(gtd.item()): 674 | print('Warning grad norm infinite') 675 | print('iter %d'%state['n_iter']) 676 | print('||grad||=%f'%grad_nrm) 677 | print('||d||=%f'%d.norm().item()) 678 | # optional line search: user function 679 | ls_func_evals = 0 680 | if line_search_fn: 681 | # perform line search, using user function 682 | ##raise RuntimeError("line search function is not supported yet") 683 | #FF################################# 684 | # Note: we disable gradient calculation during line search 685 | # because it is not needed 686 | if not cost_use_gradient: 687 | torch.set_grad_enabled(False) 688 | if not batch_mode: 689 | t=self._linesearch_cubic(closure,d,1e-6) 690 | else: 691 | t=self._linesearch_backtrack(closure,d,flat_grad,alphabar) 692 | if not cost_use_gradient: 693 | torch.set_grad_enabled(True) 694 | 695 | if math.isnan(t): 696 | print('Warning: stepsize nan') 697 | t=lr 698 | self._add_grad(t, d) #FF param = param + t * d 699 | if be_verbose: 700 | print('step size=%f'%(t)) 701 | #FF################################# 702 | else: 703 | #FF Here, t = stepsize, d = -grad, in cache 704 | # no line search, simply move with fixed-step 705 | self._add_grad(t, d) #FF param = param + t * d 706 | if n_iter != max_iter: 707 | # re-evaluate function only if not in last iteration 708 | # the reason we do this: in a stochastic setting, 709 | # no use to re-evaluate that function here 710 | loss = float(closure()) 711 | flat_grad = self._gather_flat_grad() 712 | abs_grad_sum = flat_grad.abs().sum() 713 | if math.isnan(abs_grad_sum): 714 | print('Warning: gradient nan') 715 | break 716 | ls_func_evals = 1 717 | 718 | # update func eval 719 | current_evals += ls_func_evals 720 | state['func_evals'] += ls_func_evals 721 | 722 | ############################################################ 723 | # check conditions 724 | ############################################################ 725 | if n_iter == max_iter: 726 | break 727 | 728 | if current_evals >= max_eval: 729 | break 730 | 731 | if abs_grad_sum <= tolerance_grad: 732 | break 733 | 734 | if gtd > -tolerance_change: 735 | break 736 | 737 | if d.mul(t).abs_().sum() <= tolerance_change: 738 | break 739 | 740 | if abs(loss - prev_loss) < tolerance_change: 741 | break 742 | 743 | state['d'] = d 744 | state['t'] = t 745 | state['old_dirs'] = old_dirs 746 | state['old_stps'] = old_stps 747 | state['H_diag'] = H_diag 748 | state['prev_flat_grad'] = prev_flat_grad 749 | state['prev_loss'] = 
prev_loss 750 | 751 | if batch_mode: 752 | if 'running_avg' not in locals() or running_avg is None: 753 | running_avg=torch.zeros_like(flat_grad.data) 754 | running_avg_sq=torch.zeros_like(flat_grad.data) 755 | state['running_avg']=running_avg 756 | state['running_avg_sq']=running_avg_sq 757 | 758 | 759 | return orig_loss 760 | -------------------------------------------------------------------------------- /loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlesc-dirac/pytorch/283ea6b93785b87ebe409f7fe887401a7ccfb313/loss.png -------------------------------------------------------------------------------- /resnet9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlesc-dirac/pytorch/283ea6b93785b87ebe409f7fe887401a7ccfb313/resnet9.png --------------------------------------------------------------------------------
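
The bound-constrained optimizer can also be used on its own, outside a neural network. The following self-contained sketch minimizes the two-parameter Rosenbrock function inside the box [-1, 1] x [-1, 1]; it is adapted from the example in the ```LBFGSB``` docstring, with the iteration count and ```batch_mode=False``` chosen for illustration:

```
import torch
from lbfgsb import LBFGSB

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# two parameters, constrained to the box [-1, 1] x [-1, 1]
x = torch.rand(2, requires_grad=True, dtype=torch.float64, device=device)
x_l = -torch.ones(2, dtype=torch.float64, device=device)
x_u = torch.ones(2, dtype=torch.float64, device=device)

optimizer = LBFGSB([x], lower_bound=x_l, upper_bound=x_u,
                   history_size=7, max_iter=4, batch_mode=False)

def cost_function():
    # Rosenbrock: f(x) = (1 - x0)^2 + 100*(x1 - x0^2)^2
    return torch.pow(1.0 - x[0], 2.0) + 100.0*torch.pow(x[1] - x[0]*x[0], 2.0)

for ci in range(10):
    def closure():
        if torch.is_grad_enabled():
            optimizer.zero_grad()
        loss = cost_function()
        if loss.requires_grad:
            loss.backward()
        return loss
    optimizer.step(closure)

print(x.detach().cpu().numpy())  # approaches the constrained minimizer
```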