├── .gitignore
├── LICENSE
├── README.md
├── activation.png
├── cifar10_resnet.py
├── kan_pde.py
├── lbfgsb.py
├── lbfgsnew.py
├── loss.png
└── resnet9.png
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LBFGS optimizer
2 | This repository provides an improved LBFGS (and LBFGS-B) optimizer for PyTorch. Further details are given [in this paper](https://ieeexplore.ieee.org/document/8755567). Also see [this introduction](http://sagecal.sourceforge.net/pytorch/index.html).
3 |
4 | Examples of use:
5 |
6 | * Federated learning: see [these examples](https://github.com/SarodYatawatta/federated-pytorch-test).
7 |
8 | * Calibration and other inverse problems: see [radio interferometric calibration](https://github.com/SarodYatawatta/calibration-pytorch-test).
9 |
10 | * K-harmonic means clustering: see [LOFAR system health management](https://github.com/SarodYatawatta/LSHM).
11 |
12 | * Other problems: see [this example](https://ieeexplore.ieee.org/abstract/document/8588731).
13 |
14 | Files included are:
15 |
16 | ``` lbfgsnew.py ```: New LBFGS optimizer
17 |
18 | ``` lbfgsb.py ```: LBFGS-B optimizer (with bound constraints)
19 |
20 | ``` cifar10_resnet.py ```: CIFAR10 ResNet training example (see figures below)
21 |
22 | ``` kan_pde.py ```: Kolmogorov-Arnold network (KAN) PDE example using LBFGS-B
23 |
24 |
25 | ![Training loss and training time: LBFGS vs Adam](loss.png)
26 | The above figure shows the training loss and training time [using Colab](https://colab.research.google.com/notebooks/intro.ipynb) with one GPU. ResNet18 and ResNet101 models are used. Test accuracy after 20 epochs: 84% for LBFGS and 82% for Adam.
27 |
28 | Changing the activation from the commonly used ```ReLU``` to alternatives such as ```ELU``` gives faster convergence in LBFGS, as seen in the figure below.
29 |
30 | ![Convergence with ELU vs ReLU activations](activation.png)
31 |
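In ```cifar10_resnet.py``` this amounts to calling ```F.elu``` instead of ```F.relu``` inside the residual blocks. A minimal sketch of the idea (the tiny block below is illustrative only, not the full model used for the figure):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyBlock(nn.Module):
    """Illustrative residual block; F.elu is the only change vs. a ReLU block."""
    def __init__(self, planes=16):
        super().__init__()
        self.conv = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(planes)

    def forward(self, x):
        # swap F.elu back to F.relu to recover the standard block
        return F.elu(self.bn(self.conv(x)) + x)

out = TinyBlock()(torch.randn(1, 16, 32, 32))  # quick shape check
```
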
32 | Here is a comparison of both training error and test accuracy for ResNet9 using LBFGS and Adam.
33 |
34 | ![ResNet9 training error and test accuracy: LBFGS vs Adam](resnet9.png)
35 |
36 | Example usage in full batch mode:
37 |
38 | ```
39 | from lbfgsnew import LBFGSNew
40 | optimizer = LBFGSNew(model.parameters(), history_size=7, max_iter=100, line_search_fn=True, batch_mode=False)
41 | ```
42 |
43 | Example usage in minibatch mode:
44 |
45 | ```
46 | from lbfgsnew import LBFGSNew
47 | optimizer = LBFGSNew(model.parameters(), history_size=7, max_iter=2, line_search_fn=True, batch_mode=True)
48 | ```
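
In either mode, ```optimizer.step()``` must be passed a closure that re-evaluates the loss, just like PyTorch's built-in LBFGS. A minimal sketch, assuming ```model```, ```criterion```, ```inputs``` and ```labels``` are already defined:

```
import torch

def closure():
    if torch.is_grad_enabled():
        optimizer.zero_grad()   # clear old gradients before re-evaluating
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    if loss.requires_grad:
        loss.backward()         # the optimizer needs fresh gradients
    return loss

optimizer.step(closure)
```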
49 |
50 | Note: for certain problems, the gradient can also be part of the cost, for example in TV regularization. In such situations, give the option ```cost_use_gradient=True``` to ```LBFGSNew()```. However, this will increase the computational cost, so only use when needed.
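
A minimal sketch of enabling this option (all other arguments as in the minibatch example above):

```
from lbfgsnew import LBFGSNew
optimizer = LBFGSNew(model.parameters(), history_size=7, max_iter=2,
                     line_search_fn=True, batch_mode=True,
                     cost_use_gradient=True)  # keep gradients enabled during the line search
```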
51 |
--------------------------------------------------------------------------------
/activation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nlesc-dirac/pytorch/283ea6b93785b87ebe409f7fe887401a7ccfb313/activation.png
--------------------------------------------------------------------------------
/cifar10_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 |
5 | import math
6 | import time
7 |
8 | # (try to) use a GPU for computation?
9 | use_cuda=True
10 | if use_cuda and torch.cuda.is_available():
11 | mydevice=torch.device('cuda')
12 | else:
13 | mydevice=torch.device('cpu')
14 |
15 |
16 | # try replacing relu with elu
17 | torch.manual_seed(69)
18 | default_batch=128 # batch size; no. of minibatches per epoch = 50000/default_batch
19 | batches_for_report=10 # print statistics every this many minibatches
20 |
21 | transform=transforms.Compose(
22 | [transforms.ToTensor(),
23 | transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
24 |
25 |
26 | trainset=torchvision.datasets.CIFAR10(root='./torchdata', train=True,
27 | download=True, transform=transform)
28 |
29 | trainloader=torch.utils.data.DataLoader(trainset, batch_size=default_batch,
30 | shuffle=True, num_workers=2)
31 |
32 | testset=torchvision.datasets.CIFAR10(root='./torchdata', train=False,
33 | download=True, transform=transform)
34 |
35 | testloader=torch.utils.data.DataLoader(testset, batch_size=default_batch,
36 | shuffle=False, num_workers=0)
37 |
38 |
39 | classes=('plane', 'car', 'bird', 'cat',
40 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
41 |
42 |
43 |
44 | import matplotlib.pyplot as plt
45 | import numpy as np
46 |
47 | from torch.autograd import Variable
48 | import torch.nn as nn
49 | import torch.nn.functional as F
50 |
51 |
52 | '''ResNet in PyTorch.
53 | Reference:
54 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
55 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
56 |
57 | From: https://github.com/kuangliu/pytorch-cifar
58 | '''
59 | import torch
60 | import torch.nn as nn
61 | import torch.nn.functional as F
62 |
63 |
64 | class BasicBlock(nn.Module):
65 | expansion = 1
66 |
67 | def __init__(self, in_planes, planes, stride=1):
68 | super(BasicBlock, self).__init__()
69 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
70 | self.bn1 = nn.BatchNorm2d(planes)
71 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
72 | self.bn2 = nn.BatchNorm2d(planes)
73 |
74 | self.shortcut = nn.Sequential()
75 | if stride != 1 or in_planes != self.expansion*planes:
76 | self.shortcut = nn.Sequential(
77 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
78 | nn.BatchNorm2d(self.expansion*planes)
79 | )
80 |
81 | def forward(self, x):
82 | out = F.elu(self.bn1(self.conv1(x)))
83 | out = self.bn2(self.conv2(out))
84 | out += self.shortcut(x)
85 | out = F.elu(out)
86 | return out
87 |
88 |
89 | class Bottleneck(nn.Module):
90 | expansion = 4
91 |
92 | def __init__(self, in_planes, planes, stride=1):
93 | super(Bottleneck, self).__init__()
94 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
95 | self.bn1 = nn.BatchNorm2d(planes)
96 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
97 | self.bn2 = nn.BatchNorm2d(planes)
98 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
99 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
100 |
101 | self.shortcut = nn.Sequential()
102 | if stride != 1 or in_planes != self.expansion*planes:
103 | self.shortcut = nn.Sequential(
104 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
105 | nn.BatchNorm2d(self.expansion*planes)
106 | )
107 |
108 | def forward(self, x):
109 | out = F.elu(self.bn1(self.conv1(x)))
110 | out = F.elu(self.bn2(self.conv2(out)))
111 | out = self.bn3(self.conv3(out))
112 | out += self.shortcut(x)
113 | out = F.elu(out)
114 | return out
115 |
116 |
117 | class ResNet(nn.Module):
118 | def __init__(self, block, num_blocks, num_classes=10):
119 | super(ResNet, self).__init__()
120 | self.in_planes = 64
121 |
122 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
123 | self.bn1 = nn.BatchNorm2d(64)
124 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
125 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
126 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
127 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
128 | self.linear = nn.Linear(512*block.expansion, num_classes)
129 |
130 | def _make_layer(self, block, planes, num_blocks, stride):
131 | strides = [stride] + [1]*(num_blocks-1)
132 | layers = []
133 | for stride in strides:
134 | layers.append(block(self.in_planes, planes, stride))
135 | self.in_planes = planes * block.expansion
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x):
139 | out = F.elu(self.bn1(self.conv1(x)))
140 | out = self.layer1(out)
141 | out = self.layer2(out)
142 | out = self.layer3(out)
143 | out = self.layer4(out)
144 | out = F.avg_pool2d(out, 4)
145 | out = out.view(out.size(0), -1)
146 | out = self.linear(out)
147 | return out
148 |
149 | def ResNet9():
150 | return ResNet(BasicBlock, [1,1,1,1])
151 |
152 | def ResNet18():
153 | return ResNet(BasicBlock, [2,2,2,2])
154 |
155 | def ResNet34():
156 | return ResNet(BasicBlock, [3,4,6,3])
157 |
158 | def ResNet50():
159 | return ResNet(Bottleneck, [3,4,6,3])
160 |
161 | def ResNet101():
162 | return ResNet(Bottleneck, [3,4,23,3])
163 |
164 | def ResNet152():
165 | return ResNet(Bottleneck, [3,8,36,3])
166 |
167 |
168 | # enable this to use wide ResNet
169 | wide_resnet=False
170 | if not wide_resnet:
171 | net=ResNet18().to(mydevice)
172 | else:
173 | # use wide residual net https://arxiv.org/abs/1605.07146
174 | net=torchvision.models.resnet.wide_resnet50_2().to(mydevice)
175 |
176 |
177 | #####################################################
178 | def verification_error_check(net):
179 | correct=0
180 | total=0
181 | for data in testloader:
182 | images,labels=data
183 | outputs=net(Variable(images).to(mydevice))
184 | _,predicted=torch.max(outputs.data,1)
185 | correct += (predicted==labels.to(mydevice)).sum()
186 | total += labels.size(0)
187 |
188 | return 100*correct//total
189 | #####################################################
190 |
191 | lambda1=0.000001
192 | lambda2=0.001
193 |
194 | # loss function and optimizer
195 | import torch.optim as optim
196 | from lbfgsnew import LBFGSNew # custom optimizer
197 | criterion=nn.CrossEntropyLoss()
198 | #optimizer=optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
199 | #optimizer=optim.Adam(net.parameters(), lr=0.001)
200 | optimizer = LBFGSNew(net.parameters(), history_size=7, max_iter=2, line_search_fn=True,batch_mode=True)
201 |
202 |
203 | load_model=False
204 | # update from a saved model
205 | if load_model:
206 | checkpoint=torch.load('./res18.model',map_location=mydevice)
207 | net.load_state_dict(checkpoint['model_state_dict'])
208 | net.train() # initialize for training (BN,dropout)
209 |
210 | start_time=time.time()
211 | use_lbfgs=True
212 | # train network
213 | for epoch in range(20):
214 | running_loss=0.0
215 | for i,data in enumerate(trainloader,0):
216 | # get the inputs
217 | inputs,labels=data
218 | # wrap them in variable
219 | inputs,labels=Variable(inputs).to(mydevice),Variable(labels).to(mydevice)
220 |
221 | if not use_lbfgs:
222 | # zero gradients
223 | optimizer.zero_grad()
224 | # forward+backward optimize
225 | outputs=net(inputs)
226 | loss=criterion(outputs,labels)
227 | loss.backward()
228 | optimizer.step()
229 | else:
230 | if not wide_resnet:
231 | layer1=torch.cat([x.view(-1) for x in net.layer1.parameters()])
232 | layer2=torch.cat([x.view(-1) for x in net.layer2.parameters()])
233 | layer3=torch.cat([x.view(-1) for x in net.layer3.parameters()])
234 | layer4=torch.cat([x.view(-1) for x in net.layer4.parameters()])
235 |
236 | def closure():
237 | if torch.is_grad_enabled():
238 | optimizer.zero_grad()
239 | outputs=net(inputs)
240 | if not wide_resnet:
241 | l1_penalty=lambda1*(torch.norm(layer1,1)+torch.norm(layer2,1)+torch.norm(layer3,1)+torch.norm(layer4,1))
242 | l2_penalty=lambda2*(torch.norm(layer1,2)+torch.norm(layer2,2)+torch.norm(layer3,2)+torch.norm(layer4,2))
243 | loss=criterion(outputs,labels)+l1_penalty+l2_penalty
244 | else:
245 | l1_penalty=0
246 | l2_penalty=0
247 | loss=criterion(outputs,labels)
248 | if loss.requires_grad:
249 | loss.backward()
250 | #print('loss %f l1 %f l2 %f'%(loss,l1_penalty,l2_penalty))
251 | return loss
252 | optimizer.step(closure)
253 | # only for diagnostics
254 | outputs=net(inputs)
255 | loss=criterion(outputs,labels)
256 | running_loss +=loss.data.item()
257 |
258 | if math.isnan(loss.data.item()):
259 | print('loss became nan at %d'%i)
260 | break
261 |
262 | # print statistics
263 | if i%(batches_for_report) == (batches_for_report-1): # after every 'batches_for_report'
264 | print('%f: [%d, %5d] loss: %.5f accuracy: %.3f'%
265 | (time.time()-start_time,epoch+1,i+1,running_loss/batches_for_report,
266 | verification_error_check(net)))
267 | running_loss=0.0
268 |
269 | print('Finished Training')
270 |
271 |
272 | # save model (and other extra items)
273 | torch.save({
274 | 'model_state_dict':net.state_dict(),
275 | 'epoch':epoch,
276 | 'optimizer_state_dict':optimizer.state_dict(),
277 | 'running_loss':running_loss,
278 | },'./res.model')
279 |
280 |
281 | # whole dataset
282 | correct=0
283 | total=0
284 | for data in trainloader:
285 | images,labels=data
286 | outputs=net(Variable(images).to(mydevice)).cpu()
287 | _,predicted=torch.max(outputs.data,1)
288 | total += labels.size(0)
289 | correct += (predicted==labels).sum()
290 |
291 | print('Accuracy of the network on the %d train images: %d %%'%
292 | (total,100*correct//total))
293 |
294 | correct=0
295 | total=0
296 | for data in testloader:
297 | images,labels=data
298 | outputs=net(Variable(images).to(mydevice)).cpu()
299 | _,predicted=torch.max(outputs.data,1)
300 | total += labels.size(0)
301 | correct += (predicted==labels).sum()
302 |
303 | print('Accuracy of the network on the %d test images: %d %%'%
304 | (total,100*correct//total))
305 |
306 |
307 | class_correct=list(0. for i in range(10))
308 | class_total=list(0. for i in range(10))
309 | for data in testloader:
310 | images,labels=data
311 | outputs=net(Variable(images).to(mydevice)).cpu()
312 | _,predicted=torch.max(outputs.data,1)
313 | c=(predicted==labels).squeeze()
314 | for i in range(labels.size(0)): # count every sample in the batch
315 | label=labels[i]
316 | class_correct[label] += c[i]
317 | class_total[label] += 1
318 |
319 | for i in range(10):
320 | print('Accuracy of %5s : %2d %%' %
321 | (classes[i],100*float(class_correct[i])/float(class_total[i])))
322 |
--------------------------------------------------------------------------------
/kan_pde.py:
--------------------------------------------------------------------------------
1 | # This is an example of training a KAN model, original at
2 | # https://kindxiaoming.github.io/pykan/Examples/Example_6_PDE.html
3 | # using the LBFGS-B optimizer
4 |
5 | from kan import KAN
6 | from lbfgsb import LBFGSB
7 | from lbfgsnew import LBFGSNew
8 | import torch
9 | import matplotlib.pyplot as plt
10 | from torch import autograd
11 | from tqdm import tqdm
12 | import numpy as np
13 |
14 | use_cuda=True
15 | if use_cuda and torch.cuda.is_available():
16 | mydevice=torch.device('cuda')
17 | else:
18 | mydevice=torch.device('cpu')
19 |
20 |
21 | dim = 2
22 | np_i = 21 # number of interior points (along each dimension)
23 | np_b = 21 # number of boundary points (along each dimension)
24 | ranges = [-1, 1]
25 |
26 | model = KAN(width=[2,2,1], grid=5, k=3, grid_eps=1.0, device=mydevice)
27 |
28 | # get all parameters (all may not be trainable)
29 | n_params = sum([np.prod(p.size()) for p in model.parameters()])
30 | # lower/upper bounds for parameters
31 | x_l=(torch.ones(n_params)*(-100.0)).to(mydevice)
32 | x_u=(torch.ones(n_params)*(100.0)).to(mydevice)
33 |
34 | def batch_jacobian(func, x, create_graph=False):
35 | # x in shape (Batch, Length)
36 | def _func_sum(x):
37 | return func(x).sum(dim=0)
38 | return autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2)
39 |
40 | # define solution
41 | sol_fun = lambda x: torch.sin(torch.pi*x[:,[0]])*torch.sin(torch.pi*x[:,[1]])
42 | source_fun = lambda x: -2*torch.pi**2 * torch.sin(torch.pi*x[:,[0]])*torch.sin(torch.pi*x[:,[1]])
43 |
44 | # interior
45 | sampling_mode = 'random' # 'random' or 'mesh'
46 |
47 | x_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i).to(mydevice)
48 | y_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i).to(mydevice)
49 | X, Y = torch.meshgrid(x_mesh, y_mesh, indexing="ij")
50 | if sampling_mode == 'mesh':
51 | #mesh
52 | x_i = torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0)
53 | else:
54 | #random
55 | x_i = torch.rand((np_i**2,2))*2-1
56 | x_i=x_i.to(mydevice)
57 |
58 | # boundary, 4 sides
59 | helper = lambda X, Y: torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0)
60 | xb1 = helper(X[0], Y[0])
61 | xb2 = helper(X[-1], Y[0])
62 | xb3 = helper(X[:,0], Y[:,0])
63 | xb4 = helper(X[:,0], Y[:,-1])
64 | x_b = torch.cat([xb1, xb2, xb3, xb4], dim=0)
65 |
66 | steps = 20
67 | alpha = 0.1
68 | log = 1
69 |
70 | #torch.autograd.set_detect_anomaly(True)
71 | def train():
72 | # try running with batch_mode=True and batch_mode=False (both should work)
73 | optimizer = LBFGSB(model.parameters(), lower_bound=x_l, upper_bound=x_u, history_size=10, tolerance_grad=1e-32, tolerance_change=1e-32, batch_mode=True, cost_use_gradient=True)
74 | #optimizer = LBFGSNew(model.parameters(), history_size=10, tolerance_grad=1e-32, tolerance_change=1e-32, batch_mode=True, cost_use_gradient=True)
75 |
76 | pbar = tqdm(range(steps), desc='description')
77 |
78 | for _ in pbar:
79 | def closure():
80 | global pde_loss, bc_loss
81 | optimizer.zero_grad()
82 | # interior loss
83 | sol = sol_fun(x_i)
84 | sol_D1_fun = lambda x: batch_jacobian(model, x, create_graph=True)[:,0,:]
85 | sol_D1 = sol_D1_fun(x_i)
86 | sol_D2 = batch_jacobian(sol_D1_fun, x_i, create_graph=True)[:,:,:]
87 | lap = torch.sum(torch.diagonal(sol_D2, dim1=1, dim2=2), dim=1, keepdim=True)
88 | source = source_fun(x_i)
89 | pde_loss = torch.mean((lap - source)**2)
90 |
91 | # boundary loss
92 | bc_true = sol_fun(x_b)
93 | bc_pred = model(x_b)
94 | bc_loss = torch.mean((bc_pred-bc_true)**2)
95 |
96 | loss = alpha * pde_loss + bc_loss
97 | loss.backward()
98 | return loss
99 |
100 | if _ % 5 == 0 and _ < 50:
101 | model.update_grid_from_samples(x_i)
102 |
103 | optimizer.step(closure)
104 | sol = sol_fun(x_i)
105 | loss = alpha * pde_loss + bc_loss
106 | l2 = torch.mean((model(x_i) - sol)**2)
107 |
108 | if _ % log == 0:
109 | pbar.set_description("pde loss: %.2e | bc loss: %.2e | l2: %.2e " % (pde_loss.cpu().detach().numpy(), bc_loss.cpu().detach().numpy(), l2.cpu().detach().numpy()))
110 |
111 | train()
112 |
--------------------------------------------------------------------------------
/lbfgsb.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from functools import reduce
3 | from torch.optim.optimizer import Optimizer
4 |
5 | import math
6 |
7 | be_verbose=False
8 |
9 | class LBFGSB(Optimizer):
10 | """Implements L-BFGS-B algorithm.
11 | Primary reference:
12 | 1) MATLAB code https://github.com/bgranzow/L-BFGS-B by Brian Granzow
13 | Theory based on:
14 | 1) A Limited Memory Algorithm for Bound Constrained Optimization, Byrd et al. 1995
15 | 2) Numerical Optimization, Nocedal and Wright, 2006
16 |
17 | .. warning::
18 | This optimizer doesn't support per-parameter options and parameter
19 | groups (there can be only one).
20 |
21 | .. note::
22 | This is still WIP, the saving/restoring of state dict is not fully implemented.
23 |
24 | Arguments:
25 | lower_bound (shape equal to parameter vector): parameters > lower_bound
26 | upper_bound (shape equal to parameter vector): parameters < upper_bound
27 | max_iter (int): maximal number of iterations per optimization step
28 | (default: 10)
29 | tolerance_grad (float): termination tolerance on first order optimality
30 | (default: 1e-5).
31 | tolerance_change (float): termination tolerance on function
32 | value/parameter changes (default: 1e-20).
33 | history_size (int): update history size (default: 7).
34 | batch_mode: True for stochastic version (default: False)
35 | cost_use_gradient: set this to True when the cost function also needs the gradient, for example in TV (total variation) regularization. (default: False)
36 |
37 | Example:
38 | ------
39 | >>> x=torch.rand(2,requires_grad=True,dtype=torch.float64,device=mydevice)
40 | >>> x_l=torch.ones(2,device=mydevice)*(-1.0)
41 | >>> x_u=torch.ones(2,device=mydevice)
42 | >>> optimizer=LBFGSB([x],lower_bound=x_l, upper_bound=x_u, history_size=7, max_iter=4, batch_mode=True)
43 | >>> def cost_function():
44 | >>> f=torch.pow(1.0-x[0],2.0)+100.0*torch.pow(x[1]-x[0]*x[0],2.0)
45 | >>> return f
46 | >>> for ci in range(10):
47 | >>> def closure():
48 | >>> if torch.is_grad_enabled():
49 | >>> optimizer.zero_grad()
50 | >>> loss=cost_function()
51 | >>> if loss.requires_grad:
52 | >>> loss.backward()
53 | >>> return loss
54 | >>>
55 | >>> optimizer.step(closure)
56 | ------
57 | """
58 |
59 | def __init__(self, params, lower_bound, upper_bound, max_iter=10,
60 | tolerance_grad=1e-5, tolerance_change=1e-20, history_size=7,
61 | batch_mode=False, cost_use_gradient=False):
62 | defaults = dict(max_iter=max_iter,
63 | tolerance_grad=tolerance_grad, tolerance_change=tolerance_change,
64 | history_size=history_size,
65 | batch_mode=batch_mode,
66 | cost_use_gradient=cost_use_gradient)
67 | super(LBFGSB, self).__init__(params, defaults)
68 |
69 | if len(self.param_groups) != 1:
70 | raise ValueError("LBFGSB doesn't support per-parameter options "
71 | "(parameter groups)")
72 |
73 | self._params = self.param_groups[0]['params']
74 | self._numel_cache = None
75 | self._device = self._params[0].device
76 | self._dtype= self._params[0].dtype
77 | self._l=lower_bound.clone(memory_format=torch.contiguous_format).to(self._device)
78 | self._u=upper_bound.clone(memory_format=torch.contiguous_format).to(self._device)
79 | self._m=history_size
80 | self._n=self._numel()
81 | # local storage as matrices (instead of curvature pairs)
82 | self._W=torch.zeros(self._n,self._m*2,dtype=self._dtype).to(self._device)
83 | self._Y=torch.zeros(self._n,self._m,dtype=self._dtype).to(self._device)
84 | self._S=torch.zeros(self._n,self._m,dtype=self._dtype).to(self._device)
85 | self._M=torch.zeros(self._m*2,self._m*2,dtype=self._dtype).to(self._device)
86 |
87 | self._fit_to_constraints()
88 |
89 | self._eps=tolerance_change
90 | self._realmax=1e20
91 | self._theta=1
92 |
93 | # batch mode
94 | self.running_avg=None
95 | self.running_avg_sq=None
96 | self.alphabar=1.0
97 |
98 | def _numel(self):
99 | if self._numel_cache is None:
100 | self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
101 | return self._numel_cache
102 |
103 | def _gather_flat_grad(self):
104 | views = []
105 | for p in self._params:
106 | if p.grad is None:
107 | view = p.data.new(p.data.numel()).zero_()
108 | elif p.grad.data.is_sparse:
109 | view = p.grad.data.to_dense().contiguous().view(-1)
110 | else:
111 | view = p.grad.data.contiguous().view(-1)
112 | views.append(view)
113 | return torch.cat(views, 0)
114 |
115 | def _add_grad(self, step_size, update):
116 | offset = 0
117 | for p in self._params:
118 | numel = p.numel()
119 | # view as to avoid deprecated pointwise semantics
120 | p.data.add_(update[offset:offset + numel].view_as(p.data), alpha=step_size)
121 | offset += numel
122 | assert offset == self._numel()
123 |
124 | #copy the parameter values out, create a list of vectors
125 | def _copy_params_out(self):
126 | return [p.detach().flatten().clone(memory_format=torch.contiguous_format) for p in self._params]
127 |
128 | #copy the parameter values back, dividing the list appropriately
129 | def _copy_params_in(self,new_params):
130 | with torch.no_grad():
131 | for p, pdata in zip(self._params, new_params):
132 | p.copy_(pdata.view_as(p))
133 |
134 | # restrict parameters to constraints
135 | def _fit_to_constraints(self):
136 | params=[]
137 | for p in self._params:
138 | # make a vector
139 | p = p.detach().flatten()
140 | params.append(p)
141 | x=torch.cat(params,0)
142 | for i in range(x.numel()):
143 | if (x[i]<self._l[i]):
144 | x[i]=self._l[i]
145 | elif (x[i]>self._u[i]):
146 | x[i]=self._u[i]
147 | offset = 0
148 | with torch.no_grad():
149 | for p in self._params:
150 | numel = p.numel()
151 | p.copy_(x[offset:offset + numel].view_as(p))
152 | offset += numel
153 | assert offset == self._numel()
154 |
155 | def _get_optimality(self,g):
156 | # get the inf-norm of the projected gradient
157 | # pp. 17, (6.1)
158 | # x: nx1 parameters
159 | # g: nx1 gradient
160 | # l: nx1 lower bound
161 | # u: nx1 upper bound
162 | x=torch.cat(self._copy_params_out(),0)
163 | projected_g=x-g
164 | for i in range(x.numel()):
165 | if projected_g[i]<self._l[i]:
166 | projected_g[i]=self._l[i]
167 | elif projected_g[i]>self._u[i]:
168 | projected_g[i]=self._u[i]
169 | projected_g=projected_g-x
170 | return max(abs(projected_g))
171 |
172 | def _get_breakpoints(self,x,g):
173 | # compute breakpoints for Cauchy point
174 | # pp 5-6, (4.1), (4.2), pp. 8, CP initialize \mathcal{F}
175 | # x: nx1 parameters
176 | # g: nx1 gradient
177 | # l: nx1 lower bound
178 | # u: nx1 upper bound
179 | # out:
180 | # t: nx1 breakpoint vector
181 | # d: nx1 search direction vector
182 | # F: nx1 indices that sort t from low to high
183 | t=torch.zeros(self._n,1,dtype=self._dtype,device=self._device)
184 | d=-g
185 | for i in range(self._n):
186 | if (g[i]<0.0):
187 | t[i]=(x[i]-self._u[i])/g[i]
188 | elif (g[i]>0.0):
189 | t[i]=(x[i]-self._l[i])/g[i]
190 | else:
191 | t[i]=self._realmax
192 |
193 | if (t[i]0, scaling
208 | # W: nx2m
209 | # M: 2mx2m
210 | # out:
211 | # xc: nx1 the generalized Cauchy point
212 | # c: 2mx1 initialization vector for subspace minimization
213 |
214 | x=torch.cat(self._copy_params_out(),0)
215 | tt,d,F=self._get_breakpoints(x,g)
216 | xc=x.clone()
217 | c=torch.zeros(2*self._m,1,dtype=self._dtype,device=self._device)
218 | p=torch.mm(self._W.transpose(0,1),d)
219 | fp=-torch.mm(d.transpose(0,1),d)
220 | fpp=-self._theta*fp-torch.mm(p.transpose(0,1),torch.mm(self._M,p))
221 | fp=fp.squeeze()
222 | fpp=fpp.squeeze()
223 | fpp0=-self._theta*fp
224 | if (fpp != 0.0):
225 | dt_min=-fp/fpp
226 | else:
227 | dt_min=-fp/self._eps
228 | t_old=0
229 | # find lowest index i where F[i] is positive (minimum t)
230 | for j in range(self._n):
231 | i=j
232 | if F[i]>=0.0:
233 | break
234 | b=F[i]
235 | t=tt[b]
236 | dt=t-t_old
237 |
238 | while (i<self._n) and (dt_min>dt):
239 | if d[b]>0.0:
240 | xc[b]=self._u[b]
241 | elif d[b]<0.0:
242 | xc[b]=self._l[b]
243 |
244 | zb=xc[b]-x[b]
245 | c=c+dt*p
246 | gb=g[b]
247 | Wbt=self._W[b,:]
248 | Wbt=Wbt.unsqueeze(-1).transpose(0,1)
249 | fp=fp+dt*fpp+gb*gb+self._theta*gb*zb-gb*torch.mm(Wbt,torch.mm(self._M,c))
250 | fpp=fpp-self._theta*gb*gb-2.0*gb*torch.mm(Wbt,torch.mm(self._M,p))-gb*gb*torch.mm(Wbt,torch.mm(self._M,Wbt.transpose(0,1)))
251 | fp=fp.squeeze()
252 | fpp=fpp.squeeze()
253 | fpp=max(self._eps*fpp0,fpp)
254 | p=p+gb*Wbt.transpose(0,1)
255 | d[b]=0.0
256 | if (fpp != 0.0):
257 | dt_min=-fp/fpp
258 | else:
259 | dt_min=-fp/self._eps
260 | t_old=t
261 | i=i+1
262 | if i<self._n:
263 | b=F[i]
264 | t=tt[b]
265 | dt=t-t_old
266 | 
267 | # take the minimizing step along the last segment and fix the remaining coordinates
268 | dt_min=max(dt_min,0.0)
269 | t_old=t_old+dt_min
270 | for j in range(i,self._n):
271 | idx=F[j]
272 | xc[idx]=x[idx]+t_old*d[idx]
273 | c=c+dt_min*p
274 | 
275 | return xc,c
276 | 
277 | def _subspace_min(self,g,xc,c):
278 | # subspace minimization over the free variables
279 | # direct primal method, pp. 12
280 | # x: nx1 parameters
281 | # g: nx1 gradient
282 | # l: nx1 lower bound
283 | # u: nx1 upper bound
284 | # xc: nx1 generalized Cauchy point
285 | # c: 2mx1 initialization vector from the Cauchy point computation
286 | # theta: >0, scaling
287 | # W: nx2m
288 | # M: 2mx2m
289 | # out:
290 | # xbar: nx1 minimizer
291 | # line_search_flag: bool
292 |
293 | line_search_flag=True
294 | free_vars_index=list()
295 | for i in range(self._n):
296 | if (xc[i] != self._u[i]) and (xc[i] != self._l[i]):
297 | free_vars_index.append(i)
298 |
299 | n_free_vars=len(free_vars_index)
300 | if n_free_vars==0:
301 | xbar=xc.clone()
302 | line_search_flag=False
303 | return xbar,line_search_flag
304 |
305 | WtZ=torch.zeros((2*self._m,n_free_vars),dtype=self._dtype,device=self._device)
306 | # each column of WtZ (2*m values) = row of i-th free variable in W (2*m values)
307 | for i in range(n_free_vars):
308 | WtZ[:,i]=self._W[free_vars_index[i],:]
309 |
310 | x=torch.cat(self._copy_params_out(),0)
311 | rr=g+self._theta*(xc-x) - torch.mm(self._W,torch.mm(self._M,c)).squeeze()
312 | r=torch.zeros(n_free_vars,1,dtype=self._dtype,device=self._device)
313 | for i in range(n_free_vars):
314 | r[i]=rr[free_vars_index[i]]
315 |
316 | invtheta=1.0/self._theta
317 | v=torch.mm(self._M,torch.mm(WtZ,r))
318 | N=invtheta*torch.mm(WtZ,WtZ.transpose(0,1))
319 | N=torch.eye(2*self._m).to(self._device)-torch.mm(self._M,N)
320 | v,_,_,_=torch.linalg.lstsq(N,v,rcond=None)
321 | du=-invtheta*r-invtheta*invtheta*torch.mm(WtZ.transpose(0,1),v)
322 |
323 | alpha_star=self._find_alpha(xc,du,free_vars_index)
324 | d_star=alpha_star*du
325 | xbar=xc.clone()
326 | for i in range(n_free_vars):
327 | idx=free_vars_index[i]
328 | xbar[idx]=xbar[idx]+d_star[i]
329 |
330 | return xbar,line_search_flag
331 |
332 | def _find_alpha(self, xc, du, free_vars_index):
333 | # pp. 11, (5.8)
334 | # l: nx1 lower bound
335 | # u: nx1 upper bound
336 | # xc: nx1 generalized Cauchy point
337 | # du: n_free_varsx1
338 | # free_vars_index: n_free_varsx1 indices of free variables
339 | # out:
340 | # alpha_star: positive scaling parameter
341 |
342 | n_free_vars=len(free_vars_index)
343 | alpha_star=1.0
344 | for i in range(n_free_vars):
345 | idx=free_vars_index[i]
346 | if du[i]>0.0:
347 | alpha_star=min(alpha_star,(self._u[idx]-xc[idx])/du[i])
348 | elif du[i]<0.0:
349 | alpha_star=min(alpha_star,(self._l[idx]-xc[idx])/du[i])
350 |
351 | return alpha_star
352 |
353 |
354 | def _linesearch_backtrack(self, closure, f_old, gk, pk, alphabar):
355 | """Line search (backtracking)
356 |
357 | Arguments:
358 | closure (callable): A closure that reevaluates the model
359 | and returns the loss.
360 | f_old: original cost
361 | gk: gradient vector
362 | pk: step direction vector
363 | alphabar: max step size
364 | """
365 | c1=1e-4
366 | citer=35
367 | alphak=alphabar
368 |
369 | x0list=self._copy_params_out()
370 | xk=[x.clone() for x in x0list]
371 | self._add_grad(alphak,pk)
372 | f_new=float(closure())
373 | s=gk
374 | prodterm=c1*s.dot(pk)
375 | ci=0
376 | while (ci<citer and (f_new>f_old+alphak*prodterm)):
377 | alphak=0.5*alphak
378 | self._copy_params_in(xk)
379 | self._add_grad(alphak,pk)
380 | f_new=float(closure())
381 | ci=ci+1
382 |
383 | self._copy_params_in(xk)
384 | return alphak
385 |
386 |
387 | def _strong_wolfe(self, closure, f0, g0, p):
388 | # line search to satisfy strong Wolfe conditions
389 | # Alg 3.5, pp. 60, Numerical optimization Nocedal & Wright
390 | # cost: cost function R^n -> 1
391 | # gradient: gradient function R^n -> R^n
392 | # x0: nx1 initial parameters
393 | # f0: 1 intial cost
394 | # g0: nx1 initial gradient
395 | # p: nx1 intial search direction
396 | # out:
397 | # alpha: step length
398 |
399 | c1=1e-4
400 | c2=0.9
401 | alpha_max=2.5
402 | alpha_im1=0
403 | alpha_i=1
404 | f_im1=f0
405 | dphi0=torch.dot(g0,p)
406 |
407 | # make a copy of original params
408 | x0list=self._copy_params_out()
409 | x0=[x.clone() for x in x0list]
410 |
411 | i=0
412 | max_iters=20
413 | while 1:
414 | # x=x0+alpha_i*p
415 | self._copy_params_in(x0)
416 | self._add_grad(alpha_i,p)
417 | f_i=float(closure())
418 | if (f_i>f0+c1*dphi0) or ((i>1) and (f_i>f_im1)):
419 | alpha=self._alpha_zoom(closure,x0,f0,g0,p,alpha_im1,alpha_i)
420 | break
421 | g_i=self._gather_flat_grad()
422 | dphi=torch.dot(g_i,p)
423 | if (abs(dphi)<=-c2*dphi0):
424 | alpha=alpha_i
425 | break
426 | if (dphi>=0.0):
427 | alpha=self._alpha_zoom(closure,x0,f0,g0,p,alpha_i,alpha_im1)
428 | break
429 | alpha_im1=alpha_i
430 | f_im1=f_i
431 | alpha_i=alpha_i+0.8*(alpha_max-alpha_i)
432 | if (i>max_iters):
433 | alpha=alpha_i
434 | break
435 | i=i+1
436 |
437 | # restore original params
438 | self._copy_params_in(x0)
439 | return alpha
440 |
441 |
442 | def _alpha_zoom(self, closure, x0, f0, g0, p, alpha_lo, alpha_hi):
443 | # Alg 3.6, pp. 61, Numerical optimization Nocedal & Wright
444 | # cost: cost function R^n -> 1
445 | # gradient: gradient function R^n -> R^n
446 | # x0: list() initial parameters
447 | # f0: 1 intial cost
448 | # g0: nx1 initial gradient
449 | # p: nx1 intial search direction
450 | # alpha_lo: low limit for alpha
451 | # alpha_hi: high limit for alpha
452 | # out:
453 | # alpha: zoomed step length
454 | c1=1e-4
455 | c2=0.9
456 | i=0
457 | max_iters=20
458 | dphi0=torch.dot(g0,p)
459 | while 1:
460 | alpha_i=0.5*(alpha_lo+alpha_hi)
461 | alpha=alpha_i
462 | # x=x0+alpha_i*p
463 | self._copy_params_in(x0)
464 | self._add_grad(alpha_i,p)
465 | f_i=float(closure())
466 | g_i=self._gather_flat_grad()
467 | # x_lo=x0+alpha_lo*p
468 | self._copy_params_in(x0)
469 | self._add_grad(alpha_lo,p)
470 | f_lo=float(closure())
471 | if ((f_i>f0+c1*alpha_i*dphi0) or (f_i>=f_lo)):
472 | alpha_hi=alpha_i
473 | else:
474 | dphi=torch.dot(g_i,p)
475 | if ((abs(dphi)<=-c2*dphi0)):
476 | alpha=alpha_i
477 | break
478 | if (dphi*(alpha_hi-alpha_lo)>=0.0):
479 | alpha_hi=alpha_lo
480 | alpha_lo=alpha_i
481 | i=i+1
482 | if (i>max_iters):
483 | alpha=alpha_i
484 | break
485 |
486 | return alpha
487 |
488 |
489 |
490 |
491 | def step(self, closure):
492 | """Performs a single optimization step.
493 |
494 | Arguments:
495 | closure (callable): A closure that reevaluates the model
496 | and returns the loss.
497 | """
498 | assert len(self.param_groups) == 1
499 |
500 | group = self.param_groups[0]
501 | max_iter = group['max_iter']
502 | tolerance_grad = group['tolerance_grad']
503 | tolerance_change = group['tolerance_change']
504 | history_size = group['history_size']
505 |
506 | batch_mode = group['batch_mode']
507 | cost_use_gradient = group['cost_use_gradient']
508 |
509 |
510 | # NOTE: LBFGS has only global state, but we register it as state for
511 | # the first param, because this helps with casting in load_state_dict
512 | state = self.state[self._params[0]]
513 | state.setdefault('func_evals', 0)
514 | state.setdefault('n_iter', 0)
515 |
516 |
517 | # evaluate initial f(x) and df/dx
518 | orig_loss = closure()
519 | f= float(orig_loss)
520 | current_evals = 1
521 | state['func_evals'] += 1
522 |
523 | g=self._gather_flat_grad()
524 | abs_grad_sum = g.abs().sum()
525 |
526 | if torch.isnan(abs_grad_sum) or abs_grad_sum <= tolerance_grad:
527 | return orig_loss
528 |
529 | n_iter=0
530 |
531 | if batch_mode and state['n_iter']==0:
532 | self.running_avg=torch.zeros_like(g.data)
533 | self.running_avg_sq=torch.zeros_like(g.data)
534 |
535 | while (self._get_optimality(g)>tolerance_change) and n_iter<max_iter:
564 | batch_changed=batch_mode and (n_iter==1 and state['n_iter']>1)
565 | if batch_changed:
566 | tmp_grad_1=g_old.clone(memory_format=torch.contiguous_format)
567 | tmp_grad_1.add_(self.running_avg,alpha=-1.0) # grad-oldmean
568 | self.running_avg.add_(tmp_grad_1,alpha=1.0/state['n_iter'])
569 | tmp_grad_2=g_old.clone(memory_format=torch.contiguous_format)
570 | tmp_grad_2.add_(self.running_avg,alpha=-1.0) # grad-newmean
571 | self.running_avg_sq.addcmul_(tmp_grad_2,tmp_grad_1,value=1) # # +(grad-newmean)(grad-oldmean)
572 | self.alphabar=1.0/(1.0+self.running_avg_sq.sum()/((state['n_iter']-1)*g_old.norm().item()))
573 |
574 |
575 | if (curv<self._eps):
--------------------------------------------------------------------------------
/lbfgsnew.py:
--------------------------------------------------------------------------------
154 | while (ci<citer and (f_new > f_old + alphak*prodterm)):
155 | alphak=0.5*alphak
156 | self._copy_params_in(xk)
157 | self._add_grad(alphak, pk)
158 | f_new=float(closure())
159 | if be_verbose:
160 | print('LN %d alpha=%f fnew=%f fold=%f'%(ci,alphak,f_new,f_old))
161 | ci=ci+1
162 |
163 | # if the cost is not sufficiently decreased, also try -ve steps
164 | if (f_old-f_new < torch.abs(prodterm)):
165 | alphak1=-alphabar
166 | self._copy_params_in(xk)
167 | self._add_grad(alphak1, pk)
168 | f_new1=float(closure())
169 | if be_verbose:
170 | print('NLN fnew=%f'%f_new1)
171 | while (ci<citer and (f_new1 > f_old + alphak1*prodterm)):
172 | alphak1=0.5*alphak1
173 | self._copy_params_in(xk)
174 | self._add_grad(alphak1, pk)
175 | f_new1=float(closure())
176 | if be_verbose:
177 | print('NLN %d alpha=%f fnew=%f fold=%f'%(ci,alphak1,f_new1,f_old))
178 | ci=ci+1
179 |
180 | if f_new1<f_new:
260 | if (phi_alphai>phi_0+alphai*gphi_0) or (ci>1 and phi_alphai>=phi_alphai1) :
261 | # ai=alphai1, bi=alphai bracket
262 | if be_verbose:
263 | print("bracket "+str(alphai1)+","+str(alphai))
264 | alphak=self._linesearch_zoom(closure,xk,pk,alphai1,alphai,phi_0,gphi_0,sigma,rho,t1,t2,t3,step)
265 | if be_verbose:
266 | print("Linesearch: condition 1 met")
267 | break
268 |
269 | # evaluate grad(phi(alpha(i))) */
270 | # note that self._params already is xk+alphai. pk, so only add the missing term
271 | # xp <- xk+(alphai+step). pk
272 | self._add_grad(step, pk) #FF param = param - t * grad
273 | p01=float(closure())
274 | # xp <- xk+(alphai-step). pk
275 | self._add_grad(-2.0*step, pk) #FF param = param - t * grad
276 | p02=float(closure())
277 | gphi_i=(p01-p02)/(2.0*step);
278 |
279 | if (abs(gphi_i)<=-sigma*gphi_0):
280 | alphak=alphai
281 | if be_verbose:
282 | print("Linesearch: condition 2 met")
283 | break
284 |
285 | if gphi_i>=0.0 :
286 | # ai=alphai, bi=alphai1 bracket
287 | if be_verbose:
288 | print("bracket "+str(alphai)+","+str(alphai1))
289 | alphak=self._linesearch_zoom(closure,xk,pk,alphai,alphai1,phi_0,gphi_0,sigma,rho,t1,t2,t3,step)
290 | if be_verbose:
291 | print("Linesearch: condition 3 met")
292 | break
293 | # else preserve old values
294 | if (mu<=2.0*alphai-alphai1):
295 | alphai1=alphai
296 | alphai=mu
297 | else:
298 | # choose by interpolation in [2*alphai-alphai1,min(mu,alphai+t1*(alphai-alphai1)]
299 | p01=2.0*alphai-alphai1;
300 | p02=min(mu,alphai+t1*(alphai-alphai1))
301 | alphai=self._cubic_interpolate(closure,xk,pk,p01,p02,step)
302 |
303 |
304 | phi_alphai1=phi_alphai;
305 | # update function evals
306 | closure_evals +=3
307 | ci=ci+1
308 |
309 |
310 |
311 |
312 | # recover original params
313 | self._copy_params_in(xk)
314 | # update state
315 | state['func_evals'] += closure_evals
316 | return alphak
317 |
318 |
319 | def _cubic_interpolate(self,closure,xk,pk,a,b,step):
320 | """ Cubic interpolation within interval [a,b] or [b,a] (a>b is possible)
321 |
322 | Arguments:
323 | closure (callable): A closure that reevaluates the model
324 | and returns the loss.
325 | xk: copy of parameter values
326 | pk: gradient vector
327 | a/b: interval for interpolation
328 | step: step size for differencing
329 | """
330 |
331 |
332 | self._copy_params_in(xk)
333 |
334 | # state parameter
335 | state = self.state[self._params[0]]
336 | # count function evals
337 | closure_evals=0
338 |
339 | # xp <- xk+a. pk
340 | self._add_grad(a, pk) #FF param = param + t * grad
341 | f0=float(closure())
342 | # xp <- xk+(a+step). pk
343 | self._add_grad(step, pk) #FF param = param + t * grad
344 | p01=float(closure())
345 | # xp <- xk+(a-step). pk
346 | self._add_grad(-2.0*step, pk) #FF param = param - t * grad
347 | p02=float(closure())
348 | f0d=(p01-p02)/(2.0*step)
349 |
350 | # xp <- xk+b. pk
351 | self._add_grad(-a+step+b, pk) #FF param = param + t * grad
352 | f1=float(closure())
353 | # xp <- xk+(b+step). pk
354 | self._add_grad(step, pk) #FF param = param + t * grad
355 | p01=float(closure())
356 | # xp <- xk+(b-step). pk
357 | self._add_grad(-2.0*step, pk) #FF param = param - t * grad
358 | p02=float(closure())
359 | f1d=(p01-p02)/(2.0*step)
360 |
361 | closure_evals=6
362 |
363 | aa=3.0*(f0-f1)/(b-a)+f1d-f0d
364 | p01=aa*aa-f0d*f1d
365 | if (p01>0.0):
366 | cc=math.sqrt(p01)
367 | #print('f0='+str(f0d)+' f1='+str(f1d)+' cc='+str(cc))
368 | if (f1d-f0d+2.0*cc)==0.0:
369 | return (a+b)*0.5
370 | z0=b-(f1d+cc-aa)*(b-a)/(f1d-f0d+2.0*cc)
371 | aa=max(a,b)
372 | cc=min(a,b)
373 | if z0>aa or z0<cc:
455 | if (phi_j>phi_0+rho*alphaj*gphi_0) or phi_j>=phi_aj :
456 | bj=alphaj # aj is unchanged
457 | else:
458 | # evaluate grad(alphaj)
459 | # xp <- xk+(alphaj+step). pk
460 | self._add_grad(-aj+alphaj+step, pk) #FF param = param + t * grad
461 | p01=float(closure())
462 | # xp <- xk+(alphaj-step). pk
463 | self._add_grad(-2.0*step, pk) #FF param = param + t * grad
464 | p02=float(closure())
465 | gphi_j=(p01-p02)/(2.0*step)
466 |
467 |
468 | closure_evals +=2
469 |
470 | # termination due to roundoff/other errors pp. 38, Fletcher
471 | if (aj-alphaj)*gphi_j <= step:
472 | alphak=alphaj
473 | found_step=True
474 | break
475 |
476 | if abs(gphi_j)<=-sigma*gphi_0 :
477 | alphak=alphaj
478 | found_step=True
479 | break
480 |
481 | if gphi_j*(bj-aj)>=0.0:
482 | bj=aj
483 | # else bj is unchanged
484 | aj=alphaj
485 |
486 |
487 | ci=ci+1
488 |
489 | if not found_step:
490 | alphak=alphaj
491 |
492 | # update state
493 | state['func_evals'] += closure_evals
494 |
495 | return alphak
496 |
497 |
498 | def step(self, closure):
499 | """Performs a single optimization step.
500 |
501 | Arguments:
502 | closure (callable): A closure that reevaluates the model
503 | and returns the loss.
504 | """
505 | assert len(self.param_groups) == 1
506 |
507 | group = self.param_groups[0]
508 | lr = group['lr']
509 | max_iter = group['max_iter']
510 | max_eval = group['max_eval']
511 | tolerance_grad = group['tolerance_grad']
512 | tolerance_change = group['tolerance_change']
513 | line_search_fn = group['line_search_fn']
514 | history_size = group['history_size']
515 |
516 | batch_mode = group['batch_mode']
517 | cost_use_gradient = group['cost_use_gradient']
518 |
519 |
520 | # NOTE: LBFGS has only global state, but we register it as state for
521 | # the first param, because this helps with casting in load_state_dict
522 | state = self.state[self._params[0]]
523 | state.setdefault('func_evals', 0)
524 | state.setdefault('n_iter', 0)
525 |
526 |
527 | # evaluate initial f(x) and df/dx
528 | orig_loss = closure()
529 | loss = float(orig_loss)
530 | current_evals = 1
531 | state['func_evals'] += 1
532 |
533 | flat_grad = self._gather_flat_grad()
534 | abs_grad_sum = flat_grad.abs().sum()
535 |
536 | if torch.isnan(abs_grad_sum) or abs_grad_sum <= tolerance_grad:
537 | return orig_loss
538 |
539 | # tensors cached in state (for tracing)
540 | d = state.get('d')
541 | t = state.get('t')
542 | old_dirs = state.get('old_dirs')
543 | old_stps = state.get('old_stps')
544 | H_diag = state.get('H_diag')
545 | prev_flat_grad = state.get('prev_flat_grad')
546 | prev_loss = state.get('prev_loss')
547 |
548 | n_iter = 0
549 |
550 | if batch_mode:
551 | alphabar=lr
552 | lm0=1e-6
553 |
554 | # optimize for a max of max_iter iterations
555 | grad_nrm=flat_grad.norm().item()
556 | while n_iter < max_iter and not math.isnan(grad_nrm):
557 | # keep track of nb of iterations
558 | n_iter += 1
559 | state['n_iter'] += 1
560 |
561 | ############################################################
562 | # compute gradient descent direction
563 | ############################################################
564 | if state['n_iter'] == 1:
565 | d = flat_grad.neg()
566 | old_dirs = []
567 | old_stps = []
568 | H_diag = 1
569 | if batch_mode:
570 | running_avg=torch.zeros_like(flat_grad.data)
571 | running_avg_sq=torch.zeros_like(flat_grad.data)
572 | else:
573 | if batch_mode:
574 | running_avg=state.get('running_avg')
575 | running_avg_sq=state.get('running_avg_sq')
576 | if running_avg is None:
577 | running_avg=torch.zeros_like(flat_grad.data)
578 | running_avg_sq=torch.zeros_like(flat_grad.data)
579 |
580 | # do lbfgs update (update memory)
581 | # what happens if current and prev grad are equal, ||y||->0 ??
582 | y = flat_grad.sub(prev_flat_grad)
583 |
584 | s = d.mul(t)
585 |
586 | if batch_mode: # y = y+ lm0 * s, to have a trust region
587 | y.add_(s,alpha=lm0)
588 |
589 | ys = y.dot(s) # y^T*s
590 | sn = s.norm().item() # ||s||
591 | # FIXME batch_changed does not work for full batch mode (data might be the same)
592 | batch_changed= batch_mode and (n_iter==1 and state['n_iter']>1)
593 | if batch_changed: # batch has changed
594 | # online estimate of mean,variance of gradient (inter-batch, not intra-batch)
595 | # newmean <- oldmean + (grad - oldmean)/niter
596 | # moment <- oldmoment + (grad-oldmean)(grad-newmean)
597 | # variance = moment/(niter-1)
598 |
599 | g_old=flat_grad.clone(memory_format=torch.contiguous_format)
600 | g_old.add_(running_avg,alpha=-1.0) # grad-oldmean
601 | running_avg.add_(g_old,alpha=1.0/state['n_iter']) # newmean
602 | g_new=flat_grad.clone(memory_format=torch.contiguous_format)
603 | g_new.add_(running_avg,alpha=-1.0) # grad-newmean
604 | running_avg_sq.addcmul_(g_new,g_old,value=1) # +(grad-newmean)(grad-oldmean)
605 | alphabar=1/(1+running_avg_sq.sum()/((state['n_iter']-1)*(grad_nrm)))
606 | if be_verbose:
607 | print('iter %d |mean| %f |var| %f ||grad|| %f step %f y^Ts %f alphabar=%f'%(state['n_iter'],running_avg.sum(),running_avg_sq.sum()/(state['n_iter']-1),grad_nrm,t,ys,alphabar))
608 |
609 |
610 | if ys > 1e-10*sn*sn and not batch_changed :
611 | # updating memory (only when we have y within a single batch)
612 | if len(old_dirs) == history_size:
613 | # shift history by one (limited-memory)
614 | old_dirs.pop(0)
615 | old_stps.pop(0)
616 |
617 | # store new direction/step
618 | old_dirs.append(y)
619 | old_stps.append(s)
620 |
621 | # update scale of initial Hessian approximation
622 | H_diag = ys / y.dot(y) # (y*y)
623 |
624 | if math.isnan(H_diag):
625 | print('Warning H_diag nan')
626 |
627 | # compute the approximate (L-BFGS) inverse Hessian
628 | # multiplied by the gradient
629 | num_old = len(old_dirs)
630 |
631 | if 'ro' not in state:
632 | state['ro'] = [None] * history_size
633 | state['al'] = [None] * history_size
634 | ro = state['ro']
635 | al = state['al']
636 |
637 | for i in range(num_old):
638 | ro[i] = 1. / old_dirs[i].dot(old_stps[i])
639 |
640 | # iteration in L-BFGS loop collapsed to use just one buffer
641 | q = flat_grad.neg()
642 | for i in range(num_old - 1, -1, -1):
643 | al[i] = old_stps[i].dot(q) * ro[i]
644 | q.add_(old_dirs[i],alpha=-al[i])
645 |
646 | # multiply by initial Hessian
647 | # r/d is the final direction
648 | d = r = torch.mul(q, H_diag)
649 | for i in range(num_old):
650 | be_i = old_dirs[i].dot(r) * ro[i]
651 | r.add_(old_stps[i],alpha=al[i] - be_i)
652 |
653 | if prev_flat_grad is None:
654 | prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
655 |
656 | else:
657 | prev_flat_grad.copy_(flat_grad)
658 |
659 | prev_loss = loss
660 |
661 | ############################################################
662 | # compute step length
663 | ############################################################
664 | # reset initial guess for step size
665 | if state['n_iter'] == 1:
666 | t = min(1., 1. / abs_grad_sum) * lr
667 | else:
668 | t = lr
669 |
670 | # directional derivative
671 | gtd = flat_grad.dot(d) # g * d
672 |
673 | if math.isnan(gtd.item()):
674 | print('Warning grad norm infinite')
675 | print('iter %d'%state['n_iter'])
676 | print('||grad||=%f'%grad_nrm)
677 | print('||d||=%f'%d.norm().item())
678 | # optional line search: user function
679 | ls_func_evals = 0
680 | if line_search_fn:
681 | # perform line search, using user function
682 | ##raise RuntimeError("line search function is not supported yet")
683 | #FF#################################
684 | # Note: we disable gradient calculation during line search
685 | # because it is not needed
686 | if not cost_use_gradient:
687 | torch.set_grad_enabled(False)
688 | if not batch_mode:
689 | t=self._linesearch_cubic(closure,d,1e-6)
690 | else:
691 | t=self._linesearch_backtrack(closure,d,flat_grad,alphabar)
692 | if not cost_use_gradient:
693 | torch.set_grad_enabled(True)
694 |
695 | if math.isnan(t):
696 | print('Warning: stepsize nan')
697 | t=lr
698 | self._add_grad(t, d) #FF param = param + t * d
699 | if be_verbose:
700 | print('step size=%f'%(t))
701 | #FF#################################
702 | else:
703 | #FF Here, t = stepsize, d = -grad, in cache
704 | # no line search, simply move with fixed-step
705 | self._add_grad(t, d) #FF param = param + t * d
706 | if n_iter != max_iter:
707 | # re-evaluate function only if not in last iteration
708 | # the reason we do this: in a stochastic setting,
709 | # no use to re-evaluate that function here
710 | loss = float(closure())
711 | flat_grad = self._gather_flat_grad()
712 | abs_grad_sum = flat_grad.abs().sum()
713 | if math.isnan(abs_grad_sum):
714 | print('Warning: gradient nan')
715 | break
716 | ls_func_evals = 1
717 |
718 | # update func eval
719 | current_evals += ls_func_evals
720 | state['func_evals'] += ls_func_evals
721 |
722 | ############################################################
723 | # check conditions
724 | ############################################################
725 | if n_iter == max_iter:
726 | break
727 |
728 | if current_evals >= max_eval:
729 | break
730 |
731 | if abs_grad_sum <= tolerance_grad:
732 | break
733 |
734 | if gtd > -tolerance_change:
735 | break
736 |
737 | if d.mul(t).abs_().sum() <= tolerance_change:
738 | break
739 |
740 | if abs(loss - prev_loss) < tolerance_change:
741 | break
742 |
743 | state['d'] = d
744 | state['t'] = t
745 | state['old_dirs'] = old_dirs
746 | state['old_stps'] = old_stps
747 | state['H_diag'] = H_diag
748 | state['prev_flat_grad'] = prev_flat_grad
749 | state['prev_loss'] = prev_loss
750 |
751 | if batch_mode:
752 | if 'running_avg' not in locals() or running_avg is None:
753 | running_avg=torch.zeros_like(flat_grad.data)
754 | running_avg_sq=torch.zeros_like(flat_grad.data)
755 | state['running_avg']=running_avg
756 | state['running_avg_sq']=running_avg_sq
757 |
758 |
759 | return orig_loss
760 |
--------------------------------------------------------------------------------
/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nlesc-dirac/pytorch/283ea6b93785b87ebe409f7fe887401a7ccfb313/loss.png
--------------------------------------------------------------------------------
/resnet9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nlesc-dirac/pytorch/283ea6b93785b87ebe409f7fe887401a7ccfb313/resnet9.png
--------------------------------------------------------------------------------