├── .gitignore
├── LICENSE
├── README.md
├── cls
│   ├── densenet.py
│   ├── models.py
│   ├── plot.py
│   └── train.py
├── denoising
│   ├── create.py
│   ├── main.py
│   ├── main.tv.py
│   ├── models.py
│   ├── plot.py
│   └── run-exps.sh
├── profile
│   ├── optnet-forward.py
│   └── optnet-single.py
├── sudoku
│   ├── create.py
│   ├── data
│   │   ├── 2
│   │   │   ├── features.pt
│   │   │   └── labels.pt
│   │   └── 3
│   │       ├── features.pt
│   │       └── labels.pt
│   ├── models.py
│   ├── plot.py
│   ├── prof-sparse.py
│   ├── train.py
│   └── true-Qpenalty-errors.py
├── tests
│   ├── optnet-back.py
│   └── optnet-np.py
└── util
    └── init.plot.py

/.gitignore:
--------------------------------------------------------------------------------
data
exps
work
__pycache__
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner.
      For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution.
      You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE.
      You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# OptNet: Differentiable Optimization as a Layer in Neural Networks

This repository is by [Brandon Amos](http://bamos.github.io)
and
[J. Zico Kolter](http://zicokolter.com)
and contains the [PyTorch](https://pytorch.org) source code to
reproduce the experiments in our ICML 2017 paper
[OptNet: Differentiable Optimization as a Layer in Neural Networks](https://arxiv.org/abs/1703.00443).

If you find this repository helpful in your publications,
please consider citing our paper.

```
@InProceedings{amos2017optnet,
  title = {{O}pt{N}et: Differentiable Optimization as a Layer in Neural Networks},
  author = {Brandon Amos and J. Zico Kolter},
  booktitle = {Proceedings of the 34th International Conference on Machine Learning},
  pages = {136--145},
  year = {2017},
  volume = {70},
  series = {Proceedings of Machine Learning Research},
  publisher = {PMLR},
}
```

# Informal Introduction

[Mathematical optimization](https://en.wikipedia.org/wiki/Mathematical_optimization)
is a well-studied language for expressing solutions to many real-life problems
that come up in machine learning, as well as in many other fields such as
mechanics, economics, electrical engineering, operations research, control
engineering, geophysics, and molecular modeling.
As we build our machine learning systems to interact with real
data from these fields, we often **cannot** (but sometimes can)
simply "learn away" the optimization sub-problems by adding more
layers to our network. Well-defined optimization problems may be added
if you have a thorough understanding of your feature space, but
oftentimes we **don't** have this understanding and resort to
automatic feature learning for our tasks.

Until this repository, **no** modern deep learning library had provided
a way of adding a learnable optimization layer (other than simply unrolling
an optimization procedure, which is inefficient and inexact) to
our model formulation that we can quickly try to see if it's a good way
of expressing our data.

See our paper
[OptNet: Differentiable Optimization as a Layer in Neural Networks](https://arxiv.org/abs/1703.00443)
and the code at
[locuslab/optnet](https://github.com/locuslab/optnet)
if you are interested in learning more about our initial exploration
in this space of automatically learning quadratic program layers
for signal denoising and Sudoku.

## Setup and Dependencies

+ Python/numpy/[PyTorch](https://pytorch.org)
+ [qpth](https://github.com/locuslab/qpth):
  *Our fast QP solver for PyTorch released in conjunction with this paper.*
+ [bamos/block](https://github.com/bamos/block):
  *Our intelligent block matrix library for numpy, PyTorch, and beyond.*
+ Optional: [bamos/setGPU](https://github.com/bamos/setGPU):
  A small library to set `CUDA_VISIBLE_DEVICES` on multi-GPU systems.
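
All of the OptNet models in this repository are built on the `QPFunction`
layer from [qpth](https://github.com/locuslab/qpth), which treats the argmin
of a batch of quadratic programs as a differentiable function of the QP's
coefficients. The snippet below is a minimal sketch of the call signature the
models in `cls/models.py` and `denoising/models.py` use; the sizes and data
here are illustrative and are not from our experiments.

```python
import torch
from torch.autograd import Variable
from qpth.qp import QPFunction

nBatch, nz, nineq = 16, 10, 20  # illustrative sizes

# Each batch element solves
#   minimize   0.5 z' Q z + p' z
#   subject to G z <= h
Q = Variable(torch.eye(nz)).unsqueeze(0).expand(nBatch, nz, nz)
p = Variable(torch.randn(nBatch, nz))  # e.g., the previous layer's output
G = Variable(torch.Tensor(nineq, nz).uniform_(-1, 1))
G = G.unsqueeze(0).expand(nBatch, nineq, nz)
h = Variable(torch.ones(nBatch, nineq))  # z = 0 is strictly feasible
e = Variable(torch.Tensor())  # empty placeholder: no equality constraints

zhat = QPFunction()(Q, p, G, h, e, e)  # (nBatch, nz), differentiable
```

The models below follow this pattern and make some of the QP's coefficients
learnable parameters.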

# Denoising Experiments

```
denoising
├── create.py - Script to create the denoising dataset.
├── plot.py - Plot the results from any experiment.
├── main.py - Run the FC baseline and OptNet denoising experiments. (See arguments.)
├── main.tv.py - Run the TV baseline denoising experiment.
└── run-exps.sh - Run all experiments. (May need to uncomment some lines.)
```

# Sudoku Experiments

+ The dataset we used in our experiments is available in `sudoku/data`.

```
sudoku
├── create.py - Script to create the dataset.
├── plot.py - Plot the results from any experiment.
├── train.py - Run the FC baseline and OptNet Sudoku experiments. (See arguments.)
└── models.py - Models used for Sudoku.
```

# Classification Experiments

```
cls
├── train.py - Run the FC baseline and OptNet classification experiments. (See arguments.)
├── plot.py - Plot the results from any experiment.
└── models.py - Models used for classification.
```

# Acknowledgments

The rapid development of this work would not have been possible without
the immense amount of help from the [PyTorch](https://pytorch.org) team,
particularly [Soumith Chintala](http://soumith.ch/) and
[Adam Paszke](https://github.com/apaszke).

# Licensing

Unless otherwise stated, the source code is copyright
Carnegie Mellon University and licensed under the
[Apache 2.0 License](./LICENSE).
--------------------------------------------------------------------------------
/cls/densenet.py:
--------------------------------------------------------------------------------
import torch

import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F
from torch.autograd import Variable

import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import torchvision.models as models

import sys
import math

class Bottleneck(nn.Module):
    def __init__(self, nChannels, growthRate):
        super(Bottleneck, self).__init__()
        interChannels = 4*growthRate
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(interChannels)
        self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3,
                               padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        out = torch.cat((x, out), 1)
        return out

class SingleLayer(nn.Module):
    def __init__(self, nChannels, growthRate):
        super(SingleLayer, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3,
                               padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = torch.cat((x, out), 1)
        return out

class Transition(nn.Module):
    def __init__(self, nChannels, nOutChannels):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1,
                               bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = F.avg_pool2d(out, 2)
        return out
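
# A worked example of the block arithmetic in DenseNet below, for the
# configuration cls/train.py passes in (growthRate=12, depth=100,
# reduction=0.5, bottleneck=True): nDenseBlocks = (100-4)//3 = 32, halved to
# 16 because each Bottleneck contains two convolutions. The channel count
# then evolves as 2*12 = 24 after conv1, 24 + 16*12 = 216 after dense1,
# floor(216*0.5) = 108 after trans1, and so on through dense3.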
class DenseNet(nn.Module):
    def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
        super(DenseNet, self).__init__()

        nDenseBlocks = (depth-4) // 3
        if bottleneck:
            nDenseBlocks //= 2

        nChannels = 2*growthRate
        self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1,
                               bias=False)
        self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks*growthRate
        nOutChannels = int(math.floor(nChannels*reduction))
        self.trans1 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks*growthRate
        nOutChannels = int(math.floor(nChannels*reduction))
        self.trans2 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks*growthRate

        self.bn1 = nn.BatchNorm2d(nChannels)
        self.fc = nn.Linear(nChannels, nClasses)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
        layers = []
        for i in range(int(nDenseBlocks)):
            if bottleneck:
                layers.append(Bottleneck(nChannels, growthRate))
            else:
                layers.append(SingleLayer(nChannels, growthRate))
            nChannels += growthRate
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.dense3(out)
        out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8))
        out = F.log_softmax(self.fc(out))
        return out
--------------------------------------------------------------------------------
/cls/models.py:
--------------------------------------------------------------------------------
import torch

import torch.nn as nn
from torch.autograd import Function, Variable
from torch.nn.parameter import Parameter
import torch.nn.functional as F

from qpth.qp import QPFunction, QPSolvers

class Lenet(nn.Module):
    def __init__(self, nHidden, nCls=10, proj='softmax'):
        super(Lenet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
        self.fc1 = nn.Linear(50*4*4, nHidden)
        self.fc2 = nn.Linear(nHidden, nCls)

        self.proj = proj
        self.nCls = nCls

        if proj == 'simproj':
            self.Q = Variable(0.5*torch.eye(nCls).double().cuda())
            self.G = Variable(-torch.eye(nCls).double().cuda())
            self.h = Variable(-1e-5*torch.ones(nCls).double().cuda())
            self.A = Variable((torch.ones(1, nCls)).double().cuda())
            self.b = Variable(torch.Tensor([1.]).double().cuda())
            def projF(x):
                nBatch = x.size(0)
                Q = self.Q.unsqueeze(0).expand(nBatch, nCls, nCls)
                G = self.G.unsqueeze(0).expand(nBatch, nCls, nCls)
                h = self.h.unsqueeze(0).expand(nBatch, nCls)
                A = self.A.unsqueeze(0).expand(nBatch, 1, nCls)
                b = self.b.unsqueeze(0).expand(nBatch, 1)
                x = QPFunction()(Q, -x.double(), G, h, A, b).float()
                x = x.log()
                return x
            self.projF = projF
        else:
            self.projF = F.log_softmax

    def forward(self, x):
        nBatch = x.size(0)

        x = F.max_pool2d(self.conv1(x), 2)
        x = F.max_pool2d(self.conv2(x), 2)
        x = x.view(nBatch, -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return self.projF(x)
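
# The 'simproj' projection in Lenet above replaces the final softmax with a
# QP layer that solves, for logits x,
#     minimize   0.5 z' Q z - x' z      (with Q = 0.5*I)
#     subject to z >= 1e-5,  sum(z) = 1
# i.e., a Euclidean projection of (a scaled copy of) the logits onto the
# probability simplex. The 1e-5 lower bound keeps every entry strictly
# positive so that the log() afterwards is finite.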
class LenetOptNet(nn.Module):
    def __init__(self, nHidden=50, nineq=200, neq=0, eps=1e-4):
        super(LenetOptNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)

        self.qp_o = nn.Linear(50*4*4, nHidden)
        self.qp_z0 = nn.Linear(50*4*4, nHidden)
        self.qp_s0 = nn.Linear(50*4*4, nineq)

        assert(neq==0)
        self.M = Variable(torch.tril(torch.ones(nHidden, nHidden)).cuda())
        self.L = Parameter(torch.tril(torch.rand(nHidden, nHidden).cuda()))
        self.G = Parameter(torch.Tensor(nineq,nHidden).uniform_(-1,1).cuda())
        # self.z0 = Parameter(torch.zeros(nHidden).cuda())
        # self.s0 = Parameter(torch.ones(nineq).cuda())

        self.nHidden = nHidden
        self.nineq = nineq
        self.neq = neq
        self.eps = eps

    def forward(self, x):
        nBatch = x.size(0)

        x = F.max_pool2d(self.conv1(x), 2)
        x = F.max_pool2d(self.conv2(x), 2)
        x = x.view(nBatch, -1)

        L = self.M*self.L
        Q = L.mm(L.t()) + self.eps*Variable(torch.eye(self.nHidden)).cuda()
        Q = Q.unsqueeze(0).expand(nBatch, self.nHidden, self.nHidden)
        G = self.G.unsqueeze(0).expand(nBatch, self.nineq, self.nHidden)
        z0 = self.qp_z0(x)
        s0 = self.qp_s0(x)
        h = z0.mm(self.G.t())+s0
        e = Variable(torch.Tensor())
        inputs = self.qp_o(x)
        x = QPFunction()(Q, inputs, G, h, e, e)
        x = x[:,:10]

        return F.log_softmax(x)

class FC(nn.Module):
    def __init__(self, nHidden, bn):
        super().__init__()
        self.bn = bn

        self.fc1 = nn.Linear(784, nHidden)
        if bn:
            self.bn1 = nn.BatchNorm1d(nHidden)
            self.bn2 = nn.BatchNorm1d(10)
        self.fc2 = nn.Linear(nHidden, 10)
        self.fc3 = nn.Linear(10, 10)

    def forward(self, x):
        nBatch = x.size(0)

        # FC-ReLU-(BN)-FC-ReLU-(BN)-FC-Softmax
        x = x.view(nBatch, -1)
        x = F.relu(self.fc1(x))
        if self.bn:
            x = self.bn1(x)
        x = F.relu(self.fc2(x))
        if self.bn:
            x = self.bn2(x)
        x = self.fc3(x)
        return F.log_softmax(x)

class OptNet(nn.Module):
    def __init__(self, nFeatures, nHidden, nCls, bn, nineq=200, neq=0, eps=1e-4):
        super().__init__()

        self.nFeatures = nFeatures
        self.nHidden = nHidden
        self.bn = bn
        self.nCls = nCls

        if bn:
            self.bn1 = nn.BatchNorm1d(nHidden)
            self.bn2 = nn.BatchNorm1d(nCls)

        self.fc1 = nn.Linear(nFeatures, nHidden)
        self.fc2 = nn.Linear(nHidden, nCls)

        # self.qp_z0 = nn.Linear(nCls, nCls)
        # self.qp_s0 = nn.Linear(nCls, nineq)

        assert(neq==0)
        self.M = Variable(torch.tril(torch.ones(nCls, nCls)).cuda())
        self.L = Parameter(torch.tril(torch.rand(nCls, nCls).cuda()))
        self.G = Parameter(torch.Tensor(nineq,nCls).uniform_(-1,1).cuda())
        self.z0 = Parameter(torch.zeros(nCls).cuda())
        self.s0 = Parameter(torch.ones(nineq).cuda())

        self.nineq = nineq
        self.neq = neq
        self.eps = eps

    def forward(self, x):
        nBatch = x.size(0)

        # FC-ReLU-(BN)-FC-ReLU-(BN)-QP-Softmax
        x = x.view(nBatch, -1)
        x = F.relu(self.fc1(x))
        if self.bn:
            x = self.bn1(x)
        x = F.relu(self.fc2(x))
        if self.bn:
            x = self.bn2(x)

        L = self.M*self.L
        Q = L.mm(L.t()) + self.eps*Variable(torch.eye(self.nCls)).cuda()
        Q = Q.unsqueeze(0).expand(nBatch, self.nCls, self.nCls)
        G = self.G.unsqueeze(0).expand(nBatch, self.nineq, self.nCls)
        # z0 = self.qp_z0(x)
        # s0 = self.qp_s0(x)
        z0 = self.z0.unsqueeze(0).expand(nBatch, self.nCls)
        s0 = self.s0.unsqueeze(0).expand(nBatch, self.nineq)
        h = z0.mm(self.G.t())+s0
        e = Variable(torch.Tensor())
        inputs = x
        x = QPFunction(verbose=-1)(
            Q.double(), inputs.double(), G.double(), h.double(), e, e)
        x = x.float()
        # x = x[:,:10].float()

        return F.log_softmax(x)
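
# A note on the QP parameterization used by LenetOptNet and OptNet above:
#   * Q = L L' + eps*I, with L kept lower triangular by the self.M mask, is
#     symmetric positive definite by construction, so each QP is strictly
#     convex and its solution is unique and differentiable.
#   * Setting h = G z0 + s0 means z0 satisfies G z0 <= h with slack s0; since
#     s0 starts positive, the inequality constraints are guaranteed to have a
#     feasible point at initialization, whatever the other parameters are.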
class OptNetEq(nn.Module):
    def __init__(self, nFeatures, nHidden, nCls, neq, Qpenalty=0.1, eps=1e-4):
        super().__init__()

        self.nFeatures = nFeatures
        self.nHidden = nHidden
        self.nCls = nCls

        self.fc1 = nn.Linear(nFeatures, nHidden)
        self.fc2 = nn.Linear(nHidden, nCls)

        self.Q = Variable(Qpenalty*torch.eye(nHidden).double().cuda())
        self.G = Variable(-torch.eye(nHidden).double().cuda())
        self.h = Variable(torch.zeros(nHidden).double().cuda())
        self.A = Parameter(torch.rand(neq,nHidden).double().cuda())
        self.b = Variable(torch.ones(self.A.size(0)).double().cuda())

        self.neq = neq

    def forward(self, x):
        nBatch = x.size(0)

        # FC-ReLU-QP-FC-Softmax
        x = x.view(nBatch, -1)
        x = F.relu(self.fc1(x))

        Q = self.Q.unsqueeze(0).expand(nBatch, self.Q.size(0), self.Q.size(1))
        p = -x.view(nBatch,-1)
        G = self.G.unsqueeze(0).expand(nBatch, self.G.size(0), self.G.size(1))
        h = self.h.unsqueeze(0).expand(nBatch, self.h.size(0))
        A = self.A.unsqueeze(0).expand(nBatch, self.A.size(0), self.A.size(1))
        b = self.b.unsqueeze(0).expand(nBatch, self.b.size(0))

        x = QPFunction(verbose=False)(Q, p.double(), G, h, A, b).float()
        x = self.fc2(x)

        return F.log_softmax(x)
--------------------------------------------------------------------------------
/cls/plot.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

import argparse
import os
import numpy as np

import math

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('expDir', type=str)
    args = parser.parse_args()

    trainP = os.path.join(args.expDir, 'train.csv')
    trainData = np.loadtxt(trainP, delimiter=',').reshape(-1, 3)
    testP = os.path.join(args.expDir, 'test.csv')
    testData = np.loadtxt(testP, delimiter=',').reshape(-1, 3)

    trainI, trainLoss, trainErr = np.split(trainData, [1,2], axis=1)
    trainI, trainLoss, trainErr = [x.ravel() for x in
                                   (trainI, trainLoss, trainErr)]

    N = len(trainI) // math.ceil(trainI[-1])
    trainI_, trainLoss_, trainErr_ = rolling(N, trainI, trainLoss, trainErr)

    testI, testLoss, testErr = np.split(testData, [1,2], axis=1)

    fig, ax = plt.subplots(1, 1, figsize=(6, 5))
    plt.plot(trainI, trainLoss, label='Train')
    # plt.plot(trainI_, trainLoss_, label='Train')
    plt.plot(testI, testLoss, label='Test')
    plt.xlabel('Epoch')
    plt.ylabel('Cross-Entropy Loss')
    # ax.set_ylim([1e-2, 1e0])
    plt.legend()
    ax.set_yscale('log')
    loss_fname = os.path.join(args.expDir, 'loss.png')
    plt.tight_layout()
    plt.savefig(loss_fname)
    print('Created {}'.format(loss_fname))

    fig, ax = plt.subplots(1, 1, figsize=(6, 5))
    # plt.plot(trainI, trainErr, label='Train')
    plt.plot(trainI_, trainErr_, label='Train')
    plt.plot(testI, testErr, label='Test')
    plt.xlabel('Epoch')
    plt.ylabel('Error')
    ax.set_yscale('log')
    ax.set_ylim(ymin=1)
    # ax.set_ylim([0.5,1.2])
    plt.legend()
    err_fname = os.path.join(args.expDir, 'error.png')
    plt.tight_layout()
    plt.savefig(err_fname)
    print('Created {}'.format(err_fname))

    loss_err_fname = os.path.join(args.expDir, 'loss-error.png')
    os.system('convert +append "{}" "{}" "{}"'.format(loss_fname, err_fname, loss_err_fname))
    print('Created {}'.format(loss_err_fname))

def rolling(N, i, loss, err):
    # N-point moving average via convolution with a uniform kernel.
    i_ = i[N-1:]
    K = np.full(N, 1./N)
    loss_ = np.convolve(loss, K, 'valid')
    err_ = np.convolve(err, K, 'valid')
    return i_, loss_, err_

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/cls/train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

import json

import argparse

try: import setGPU
except ImportError: pass

import torch

import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F
from torch.autograd import Function, Variable

import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader

import os
import sys
import math

import shutil

import setproctitle

import densenet
import models
# import make_graph

import sys
from IPython.core import ultratb
sys.excepthook = ultratb.FormattedTB(mode='Verbose',
                                     color_scheme='Linux', call_pdb=1)

def get_loaders(args):
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'mnist':
        trainLoader = torch.utils.data.DataLoader(
            dset.MNIST('data/mnist', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
            batch_size=args.batchSz, shuffle=True, **kwargs)
        testLoader = torch.utils.data.DataLoader(
            dset.MNIST('data/mnist', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
            batch_size=args.batchSz, shuffle=False, **kwargs)
    elif args.dataset == 'cifar-10':
        normMean = [0.49139968, 0.48215827, 0.44653124]
        normStd = [0.24703233, 0.24348505, 0.26158768]
        normTransform = transforms.Normalize(normMean, normStd)

        trainTransform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normTransform
        ])
        testTransform = transforms.Compose([
            transforms.ToTensor(),
            normTransform
        ])

        trainLoader = DataLoader(
            dset.CIFAR10(root='data/cifar', train=True, download=True,
                         transform=trainTransform),
            batch_size=args.batchSz, shuffle=True, **kwargs)
        testLoader = DataLoader(
            dset.CIFAR10(root='data/cifar', train=False, download=True,
                         transform=testTransform),
            batch_size=args.batchSz, shuffle=False, **kwargs)
    else:
        assert(False)

    return trainLoader, testLoader

def get_net(args):
    if args.model == 'densenet':
        net = densenet.DenseNet(growthRate=12, depth=100, reduction=0.5,
                                bottleneck=True, nClasses=10)
    elif args.model == 'lenet':
        net = models.Lenet(args.nHidden, 10, args.proj)
    elif args.model == 'lenet-optnet':
        net = models.LenetOptNet(args.nHidden, args.nineq)
    elif args.model == 'fc':
        net = models.FC(args.nHidden, args.bn)
    elif args.model == 'optnet':
        net = models.OptNet(28*28, args.nHidden, 10, args.bn, args.nineq)
    elif args.model == 'optnet-eq':
        net = models.OptNetEq(28*28, args.nHidden, 10, args.neq)
    else:
        assert(False)

    return net

def get_optimizer(args, params):
    if args.dataset == 'mnist':
        if args.model == 'optnet-eq':
            # Use a larger learning rate for the equality-constraint matrix A.
            params = list(params)
            A_param = params.pop(0)
            assert(A_param.size() == (args.neq, args.nHidden))
            optimizer = optim.Adam([
                {'params': params, 'lr': 1e-3},
                {'params': [A_param], 'lr': 1e-1}
            ])
        else:
            optimizer = optim.Adam(params)
    elif args.dataset in ('cifar-10', 'cifar-100'):
        if args.opt == 'sgd':
            optimizer = optim.SGD(params, lr=1e-1, momentum=0.9, weight_decay=args.weightDecay)
        elif args.opt == 'adam':
            optimizer = optim.Adam(params, weight_decay=args.weightDecay)
    else:
        assert(False)

    return optimizer

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSz', type=int, default=64)
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--save', type=str)
    parser.add_argument('--work', type=str, default='work')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--nEpoch', type=int, default=1000)
    parser.add_argument('--weightDecay', type=float, default=1e-4)
    parser.add_argument('--opt', type=str, default='sgd',
                        choices=('sgd', 'adam'))
    parser.add_argument('dataset', type=str,
                        choices=['mnist', 'cifar-10', 'cifar-100', 'svhn'])
    subparsers = parser.add_subparsers(dest='model')
    lenetP = subparsers.add_parser('lenet')
    lenetP.add_argument('--nHidden', type=int, default=50)
    lenetP.add_argument('--proj', type=str, choices=('softmax', 'simproj'))
    lenetOptnetP = subparsers.add_parser('lenet-optnet')
    lenetOptnetP.add_argument('--nHidden', type=int, default=50)
    lenetOptnetP.add_argument('--nineq', type=int, default=100)
    lenetOptnetP.add_argument('--eps', type=float, default=1e-4)
    densenetP = subparsers.add_parser('densenet')
    fcP = subparsers.add_parser('fc')
    fcP.add_argument('--nHidden', type=int, default=500)
    fcP.add_argument('--bn', action='store_true')
    optnetP = subparsers.add_parser('optnet')
    optnetP.add_argument('--nHidden', type=int, default=500)
    optnetP.add_argument('--eps', type=float, default=1e-4)
    optnetP.add_argument('--nineq', type=int, default=10)
    optnetP.add_argument('--bn', action='store_true')
    optnetEqP = subparsers.add_parser('optnet-eq')
    optnetEqP.add_argument('--nHidden', type=int, default=100)
    optnetEqP.add_argument('--neq', type=int, default=50)
    args = parser.parse_args()

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.save is None:
        t = '{}.{}'.format(args.dataset, args.model)
        if args.model == 'lenet':
            t += '.nHidden:{}.proj:{}'.format(args.nHidden, args.proj)
        elif args.model == 'fc':
            t += '.nHidden:{}'.format(args.nHidden)
            if args.bn:
                t += '.bn'
        elif args.model == 'optnet':
            t += '.nHidden:{}.nineq:{}.eps:{}'.format(args.nHidden, args.nineq, args.eps)
            if args.bn:
                t += '.bn'
        elif args.model == 'optnet-eq':
            t += '.nHidden:{}.neq:{}'.format(args.nHidden, args.neq)
        elif args.model == 'lenet-optnet':
            t += '.nHidden:{}.nineq:{}.eps:{}'.format(args.nHidden, args.nineq, args.eps)
        setproctitle.setproctitle('bamos.'+t)
        args.save = os.path.join(args.work, t)

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)

    trainLoader, testLoader = get_loaders(args)
    net = get_net(args)
    optimizer = get_optimizer(args, net.parameters())

    args.nparams = sum([p.data.nelement() for p in net.parameters()])
    with open(os.path.join(args.save, 'meta.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=2)

    print(' + Number of params: {}'.format(args.nparams))
    if args.cuda:
        net = net.cuda()

    trainF = open(os.path.join(args.save, 'train.csv'), 'w')
    testF = open(os.path.join(args.save, 'test.csv'), 'w')

    for epoch in range(1, args.nEpoch + 1):
        adjust_opt(args, optimizer, epoch)
        train(args, epoch, net, trainLoader, optimizer, trainF)
        test(args, epoch, net, testLoader, optimizer, testF)
        try:
            torch.save(net, os.path.join(args.save, 'latest.pth'))
        except:
            pass
        os.system('./plot.py "{}" &'.format(args.save))

    trainF.close()
    testF.close()

def train(args, epoch, net, trainLoader, optimizer, trainF):
    net.train()
    nProcessed = 0
    nTrain = len(trainLoader.dataset)
    for batch_idx, (data, target) in enumerate(trainLoader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = net(data)
        loss = F.nll_loss(output, target)
        # make_graph.save('/tmp/t.dot', loss.creator); assert(False)
        loss.backward()
        optimizer.step()
        nProcessed += len(data)
        pred = output.data.max(1)[1] # get the index of the max log-probability
        incorrect = pred.ne(target.data).cpu().sum()
        err = 100.*incorrect/len(data)
        partialEpoch = epoch + batch_idx / len(trainLoader) - 1
        print('Train Epoch: {:.2f} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tError: {:.6f}'.format(
            partialEpoch, nProcessed, nTrain, 100. * batch_idx / len(trainLoader),
            loss.data[0], err))

        trainF.write('{},{},{}\n'.format(partialEpoch, loss.data[0], err))
        trainF.flush()

def test(args, epoch, net, testLoader, optimizer, testF):
    net.eval()
    test_loss = 0
    incorrect = 0
    for data, target in testLoader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = net(data)
        test_loss += F.nll_loss(output, target).data[0]
        pred = output.data.max(1)[1] # get the index of the max log-probability
        incorrect += pred.ne(target.data).cpu().sum()

    test_loss /= len(testLoader) # loss function already averages over batch size
    nTotal = len(testLoader.dataset)
    err = 100.*incorrect/nTotal
    print('\nTest set: Average loss: {:.4f}, Error: {}/{} ({:.0f}%)\n'.format(
        test_loss, incorrect, nTotal, err))

    testF.write('{},{},{}\n'.format(epoch, test_loss, err))
    testF.flush()

def adjust_opt(args, optimizer, epoch):
    if args.model == 'densenet':
        if args.opt == 'sgd':
            if epoch == 150: update_lr(optimizer, 1e-2)
            elif epoch == 225: update_lr(optimizer, 1e-3)
            else: return

def update_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

if __name__=='__main__':
    main()
--------------------------------------------------------------------------------
/denoising/create.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

import argparse
import numpy as np
import numpy.random as npr
import torch

import os, sys
import shutil

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--minBps', type=int, default=1)
    parser.add_argument('--maxBps', type=int, default=10)
    parser.add_argument('--seqLen', type=int, default=100)
    parser.add_argument('--minHeight', type=int, default=10)
    parser.add_argument('--maxHeight', type=int, default=100)
    parser.add_argument('--noise', type=float, default=10)
    parser.add_argument('--nSamples', type=int, default=10000)
    parser.add_argument('--save', type=str, default='data/synthetic')
    args = parser.parse_args()

    npr.seed(0)

    save = args.save
    if os.path.isdir(save):
        shutil.rmtree(save)
    os.makedirs(save)

    X, Y = [], []
    for i in range(args.nSamples):
        Xi, Yi = sample(args)
        X.append(Xi); Y.append(Yi)
        if i == 0:
            fig, ax = plt.subplots(1, 1)
            plt.plot(Xi, label='Corrupted')
            plt.plot(Yi, label='Original')
            plt.legend()
            f = os.path.join(args.save, "example.png")
            fig.savefig(f)
            print("Created {}".format(f))

    X = np.array(X)
    Y = np.array(Y)

    for loc,arr in (('features.pt', X), ('labels.pt', Y)):
        fname = os.path.join(args.save, loc)
        with open(fname, 'wb') as f:
            torch.save(torch.Tensor(arr), f)
        print("Created {}".format(fname))
def sample(args):
    # Draw a piecewise-constant signal Y with nBps segments of random
    # heights, then corrupt it with Gaussian noise to get X.
    nBps = npr.randint(args.minBps, args.maxBps)
    bpLocs = [0] + sorted(npr.choice(args.seqLen-2, nBps-1, replace=False)+1) + [args.seqLen]
    bpDiffs = np.diff(bpLocs)
    heights = npr.randint(args.minHeight, args.maxHeight, nBps)
    Y = []
    for d, h in zip(bpDiffs, heights):
        Y += [h]*d
    Y = np.array(Y, dtype=np.float)

    X = Y + npr.normal(0, args.noise, (args.seqLen))
    return X, Y

if __name__=='__main__':
    main()
--------------------------------------------------------------------------------
/denoising/main.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

import argparse
import csv
import os
import shutil
from tqdm import tqdm

try: import setGPU
except ImportError: pass

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

import numpy as np
import numpy.random as npr

import sys

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')

import setproctitle

import models

import sys
from IPython.core import ultratb
sys.excepthook = ultratb.FormattedTB(mode='Verbose',
                                     color_scheme='Linux', call_pdb=1)

def print_header(msg):
    print('===>', msg)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--batchSz', type=int, default=150)
    parser.add_argument('--testBatchSz', type=int, default=100)
    parser.add_argument('--nEpoch', type=int, default=100)
    parser.add_argument('--testPct', type=float, default=0.1)
    parser.add_argument('--work', type=str, default='work')
    parser.add_argument('--save', type=str)
    subparsers = parser.add_subparsers(dest='model')
    subparsers.required = True
    reluP = subparsers.add_parser('relu')
    reluP.add_argument('--nHidden', type=int, default=50)
    reluP.add_argument('--bn', action='store_true')
    optnetP = subparsers.add_parser('optnet')
    # optnetP.add_argument('--nHidden', type=int, default=50)
    # optnetP.add_argument('--nineq', type=int, default=100)
    optnetP.add_argument('--eps', type=float, default=1e-4)
    optnetP.add_argument('--tvInit', action='store_true')
    optnetP.add_argument('--learnD', action='store_true')
    optnetP.add_argument('--Dpenalty', type=float, default=1e-1)
    args = parser.parse_args()

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    # args.save = args.save or 'work/{}.{}'.format(args.dataset, args.model)
    if args.save is None:
        t = os.path.join(args.work, args.model)
        if args.model == 'optnet':
            t += '.eps={}'.format(args.eps)
            if args.tvInit:
                t += '.tvInit'
            if args.learnD:
                t += '.learnD.{}'.format(args.Dpenalty)
        elif args.model == 'relu':
            t += '.nHidden:{}'.format(args.nHidden)
            if args.bn:
                t += '.bn'
        args.save = t
    setproctitle.setproctitle('bamos.' + args.save)

    with open('data/synthetic/features.pt', 'rb') as f:
        X = torch.load(f)
    with open('data/synthetic/labels.pt', 'rb') as f:
        Y = torch.load(f)

    N, nFeatures = X.size()

    nTrain = int(N*(1.-args.testPct))
    nTest = N-nTrain

    trainX = X[:nTrain]
    trainY = Y[:nTrain]
    testX = X[nTrain:]
    testY = Y[nTrain:]

    assert(nTrain % args.batchSz == 0)
    assert(nTest % args.testBatchSz == 0)

    save = args.save
    if os.path.isdir(save):
        shutil.rmtree(save)
    os.makedirs(save)

    npr.seed(1)

    print_header('Building model')
    if args.model == 'relu':
        # nHidden = 2*nFeatures-1
        nHidden = args.nHidden
        model = models.ReluNet(nFeatures, nHidden, args.bn)
    elif args.model == 'optnet':
        if args.learnD:
            model = models.OptNet_LearnD(nFeatures, args)
        else:
            model = models.OptNet(nFeatures, args)

    if args.cuda:
        model = model.cuda()

    fields = ['epoch', 'loss']
    trainF = open(os.path.join(save, 'train.csv'), 'w')
    trainW = csv.writer(trainF)
    trainW.writerow(fields)
    trainF.flush()
    testF = open(os.path.join(save, 'test.csv'), 'w')
    testW = csv.writer(testF)
    testW.writerow(fields)
    testF.flush()

    if args.model == 'optnet':
        if args.tvInit: lr = 1e-4
        elif args.learnD: lr = 1e-2
        else: lr = 1e-3
    else:
        lr = 1e-3
    optimizer = optim.Adam(model.parameters(), lr=lr)

    writeParams(args, model, 'init')
    test(args, 0, model, testF, testW, testX, testY)
    for epoch in range(1, args.nEpoch+1):
        # update_lr(optimizer, epoch)
        train(args, epoch, model, trainF, trainW, trainX, trainY, optimizer)
        test(args, epoch, model, testF, testW, testX, testY)
        torch.save(model, os.path.join(args.save, 'latest.pth'))
        writeParams(args, model, 'latest')
        os.system('./plot.py "{}" &'.format(args.save))

def writeParams(args, model, tag):
    if args.model == 'optnet' and args.learnD:
        D = model.D.data.cpu().numpy()
        np.savetxt(os.path.join(args.save, 'D.{}'.format(tag)), D)

def train(args, epoch, model, trainF, trainW, trainX, trainY, optimizer):
    batchSz = args.batchSz

    # Preallocate the batch buffers once (and move them to the GPU once).
    batch_data_t = torch.FloatTensor(batchSz, trainX.size(1))
    batch_targets_t = torch.FloatTensor(batchSz, trainY.size(1))
    if args.cuda:
        batch_data_t = batch_data_t.cuda()
        batch_targets_t = batch_targets_t.cuda()
    batch_data = Variable(batch_data_t, requires_grad=False)
    batch_targets = Variable(batch_targets_t, requires_grad=False)
    for i in range(0, trainX.size(0), batchSz):
        batch_data.data[:] = trainX[i:i+batchSz]
        batch_targets.data[:] = trainY[i:i+batchSz]
        # Fixed batch size for debugging:
        # batch_data.data[:] = trainX[:batchSz]
        # batch_targets.data[:] = trainY[:batchSz]

        optimizer.zero_grad()
        preds = model(batch_data)
        mseLoss = nn.MSELoss()(preds, batch_targets)
        if args.model == 'optnet' and args.learnD:
            loss = mseLoss + args.Dpenalty*(model.D.norm(1))
        else:
            loss = mseLoss
        loss.backward()
        optimizer.step()

        print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(
            epoch, i+batchSz, trainX.size(0),
            float(i+batchSz)/trainX.size(0)*100,
            mseLoss.data[0]))

        trainW.writerow((epoch-1+float(i+batchSz)/trainX.size(0), mseLoss.data[0]))
        trainF.flush()

def test(args, epoch, model, testF, testW, testX, testY):
    batchSz = args.testBatchSz

    test_loss = 0
    batch_data_t = torch.FloatTensor(batchSz, testX.size(1))
    batch_targets_t = torch.FloatTensor(batchSz, testY.size(1))
    if args.cuda:
        batch_data_t = batch_data_t.cuda()
        batch_targets_t = batch_targets_t.cuda()
    batch_data = Variable(batch_data_t, volatile=True)
    batch_targets = Variable(batch_targets_t, volatile=True)

    for i in range(0, testX.size(0), batchSz):
        print('Testing model: {}/{}'.format(i, testX.size(0)), end='\r')
        batch_data.data[:] = testX[i:i+batchSz]
        batch_targets.data[:] = testY[i:i+batchSz]
        output = model(batch_data)
        if i == 0:
            testOut = os.path.join(args.save, 'test-imgs')
            os.makedirs(testOut, exist_ok=True)
            for j in range(4):
                X = batch_data.data[j].cpu().numpy()
                Y = batch_targets.data[j].cpu().numpy()
                Yhat = output[j].data.cpu().numpy()

                fig, ax = plt.subplots(1, 1)
                plt.plot(X, label='Corrupted')
                plt.plot(Y, label='Original')
                plt.plot(Yhat, label='Predicted')
                plt.legend()
                f = os.path.join(testOut, '{}.png'.format(j))
                fig.savefig(f)
        test_loss += nn.MSELoss()(output, batch_targets)

    nBatches = testX.size(0)/batchSz
    test_loss = test_loss.data[0]/nBatches
    print('TEST SET RESULTS:' + ' ' * 20)
    print('Average loss: {:.4f}'.format(test_loss))

    testW.writerow((epoch, test_loss))
    testF.flush()

if __name__=='__main__':
    main()
--------------------------------------------------------------------------------
/denoising/main.tv.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

import argparse
import csv
import os
import shutil
from tqdm import tqdm

import cvxpy as cp

import torch

import numpy as np
import numpy.random as npr

import sys
from IPython.core import ultratb
sys.excepthook = ultratb.FormattedTB(mode='Verbose',
                                     color_scheme='Linux', call_pdb=1)

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')
mpl.rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
mpl.rc('text', usetex=True)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--nEpoch', type=int, default=50)
    parser.add_argument('--testPct', type=float, default=0.1)
    parser.add_argument('--workDir', type=str, default='work/tv')
    args = parser.parse_args()

    with open('data/synthetic/features.pt', 'rb') as f:
        X = torch.load(f).numpy()
    with open('data/synthetic/labels.pt', 'rb') as f:
        Y = torch.load(f).numpy()

    N, nFeatures = X.shape
    nTrain = int(N*(1.-args.testPct))
    nTest = N-nTrain

    trainX = X[:nTrain]
    trainY = Y[:nTrain]
    testX = X[nTrain:]
    testY = Y[nTrain:]

    workDir = args.workDir
    if os.path.isdir(workDir):
        shutil.rmtree(workDir)
    os.makedirs(workDir)

    npr.seed(1)

    X_ = cp.Parameter(nFeatures)
    Y_ = cp.Variable(nFeatures)
    lams = list(np.linspace(0,100,101))
    mses = []
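
    # For each candidate lam, getMse solves the total-variation denoising
    # problem
    #     minimize_Y  0.5*||X - Y||_2^2 + lam*TV(Y)
    # on every test sequence and returns the mean squared error against the
    # clean signal; sweeping lam below traces out the TV baseline curve.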
    def getMse(lam):
        prob = cp.Problem(cp.Minimize(0.5*cp.sum_squares(X_-Y_)+lam*cp.tv(Y_)))
        mses_lam = []

        # testOut = os.path.join(workDir, 'test-imgs', 'lam-{:07.2f}'.format(lam))
        # os.makedirs(testOut, exist_ok=True)

        for i in range(nTest):
            X_.value = testX[i]
            prob.solve(cp.SCS)
            assert('optimal' in prob.status)
            Yhat = np.array(Y_.value).ravel()
            mse = np.mean(np.square(testY[i] - Yhat))

            mses_lam.append(mse)

            # if i <= 4:
            #     fig, ax = plt.subplots(1, 1)
            #     plt.plot(testX[i], label='Corrupted')
            #     plt.plot(testY[i], label='Original')
            #     plt.plot(Yhat, label='Predicted')
            #     plt.legend()
            #     f = os.path.join(testOut, '{}.png'.format(i))
            #     fig.savefig(f)
            #     plt.close(fig)

        return np.mean(mses_lam)

    for lam in lams:
        mses.append(getMse(lam))
        print(lam, mses[-1])

    xMin, xMax = (1, 30)

    with open(os.path.join(workDir, 'mses.csv'), 'w') as f:
        for lam,mse in zip(lams,mses):
            f.write('{},{}\n'.format(lam,mse))

    fig, ax = plt.subplots(1, 1)
    plt.plot(lams, mses)
    plt.xlabel("$\lambda$")
    plt.ylabel("MSE")
    # plt.xlim(xmin=0)
    # ax.set_yscale('log')
    for ext in ['pdf', 'png']:
        f = os.path.join(workDir, "loss."+ext)
        fig.savefig(f)
        print("Created {}".format(f))

if __name__=='__main__':
    main()
--------------------------------------------------------------------------------
/denoising/models.py:
--------------------------------------------------------------------------------
import os
import numpy as np

import torch

import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter

import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from block import block

from qpth.qp import QPFunction

class ReluNet(nn.Module):
    def __init__(self, nFeatures, nHidden, bn=False):
        super().__init__()
        self.bn = bn

        self.fc1 = nn.Linear(nFeatures, nHidden)
        self.fc2 = nn.Linear(nHidden, nFeatures)
        if bn:
            self.bn1 = nn.BatchNorm1d(nHidden)

    def __call__(self, x):
        x = F.relu(self.fc1(x))
        if self.bn:
            x = self.bn1(x)
        x = self.fc2(x)
        return x

class OptNet(nn.Module):
    def __init__(self, nFeatures, args):
        super(OptNet, self).__init__()

        nHidden, neq, nineq = 2*nFeatures-1,0,2*nFeatures-2
        assert(neq==0)

        self.fc1 = nn.Linear(nFeatures, nHidden)
        self.M = Variable(torch.tril(torch.ones(nHidden, nHidden)).cuda())

        if args.tvInit:
            Q = 1e-8*torch.eye(nHidden)
            Q[:nFeatures,:nFeatures] = torch.eye(nFeatures)
            self.L = Parameter(torch.potrf(Q))

            D = torch.zeros(nFeatures-1, nFeatures)
            D[:nFeatures-1,:nFeatures-1] = torch.eye(nFeatures-1)
            D[:nFeatures-1,1:nFeatures] -= torch.eye(nFeatures-1)
            G_ = block((( D, -torch.eye(nFeatures-1)),
                        (-D, -torch.eye(nFeatures-1))))
            self.G = Parameter(G_)
            self.s0 = Parameter(torch.ones(2*nFeatures-2)+1e-6*torch.randn(2*nFeatures-2))
            G_pinv = (G_.t().mm(G_)+1e-5*torch.eye(nHidden)).inverse().mm(G_.t())
            self.z0 = Parameter(-G_pinv.mv(self.s0.data)+1e-6*torch.randn(nHidden))

            lam = 21.21
            W_fc1, b_fc1 = self.fc1.weight, self.fc1.bias
            W_fc1.data[:,:] = 1e-3*torch.randn((2*nFeatures-1, nFeatures))
            # W_fc1.data[:,:] = 0.0
            W_fc1.data[:nFeatures,:nFeatures] += -torch.eye(nFeatures)
            # b_fc1.data[:] = torch.zeros(2*nFeatures-1)
            b_fc1.data[:] = 0.0
            b_fc1.data[nFeatures:2*nFeatures-1] = lam
        else:
            self.L = Parameter(torch.tril(torch.rand(nHidden, nHidden)))
            self.G = Parameter(torch.Tensor(nineq,nHidden).uniform_(-1,1))
            self.z0 = Parameter(torch.zeros(nHidden))
            self.s0 = Parameter(torch.ones(nineq))

        self.nFeatures = nFeatures
        self.nHidden = nHidden
        self.neq = neq
        self.nineq = nineq
        self.args = args

    def cuda(self):
        # TODO: Is there a more automatic way?
        for x in [self.L, self.G, self.z0, self.s0]:
            x.data = x.data.cuda()

        return super().cuda()

    def forward(self, x):
        nBatch = x.size(0)

        x = self.fc1(x)

        L = self.M*self.L
        Q = L.mm(L.t()) + self.args.eps*Variable(torch.eye(self.nHidden)).cuda()
        Q = Q.unsqueeze(0).expand(nBatch, self.nHidden, self.nHidden)
        G = self.G.unsqueeze(0).expand(nBatch, self.nineq, self.nHidden)
        h = self.G.mv(self.z0)+self.s0
        h = h.unsqueeze(0).expand(nBatch, self.nineq)
        e = Variable(torch.Tensor())
        x = QPFunction()(Q, x, G, h, e, e)
        x = x[:,:self.nFeatures]

        return x
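
# Both denoising OptNet models encode TV denoising as a QP over the stacked
# variable [Y; t] (nHidden = 2*nFeatures-1 entries), where the auxiliary t
# bounds the absolute differences of Y:
#     minimize   0.5*||X - Y||^2 + lam*sum(t)
#     subject to D*Y <= t,  -D*Y <= t
# with D the (nFeatures-1) x nFeatures finite-difference operator, so that
# t_i = |Y_{i+1} - Y_i| at the optimum. OptNet above can be initialized to
# exactly this problem with --tvInit; OptNet_LearnD below keeps the same
# structure but learns D, and main.py adds an L1 penalty on D to the loss.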
class OptNet_LearnD(nn.Module):
    def __init__(self, nFeatures, args):
        super().__init__()

        nHidden, neq, nineq = 2*nFeatures-1,0,2*nFeatures-2
        assert(neq==0)

        # self.fc1 = nn.Linear(nFeatures, nHidden)
        self.M = Variable(torch.tril(torch.ones(nHidden, nHidden)).cuda())

        Q = 1e-8*torch.eye(nHidden)
        Q[:nFeatures,:nFeatures] = torch.eye(nFeatures)
        self.L = Variable(torch.potrf(Q))

        self.D = Parameter(0.3*torch.randn(nFeatures-1, nFeatures))
        # self.lam = Parameter(20.*torch.ones(1))
        self.h = Variable(torch.zeros(nineq))

        self.nFeatures = nFeatures
        self.nHidden = nHidden
        self.neq = neq
        self.nineq = nineq
        self.args = args

    def cuda(self):
        # TODO: Is there a more automatic way?
        for x in [self.L, self.D, self.h]:
            x.data = x.data.cuda()

        return super().cuda()
133 | for x in [self.L, self.D, self.h]: 134 | x.data = x.data.cuda() 135 | 136 | return super().cuda() 137 | 138 | def forward(self, x): 139 | nBatch = x.size(0) 140 | 141 | L = self.M*self.L 142 | Q = L.mm(L.t()) + self.args.eps*Variable(torch.eye(self.nHidden)).cuda() 143 | Q = Q.unsqueeze(0).expand(nBatch, self.nHidden, self.nHidden) 144 | nI = Variable(-torch.eye(self.nFeatures-1).type_as(Q.data)) 145 | G = torch.cat(( 146 | torch.cat(( self.D, nI), 1), 147 | torch.cat((-self.D, nI), 1) 148 | )) 149 | G = G.unsqueeze(0).expand(nBatch, self.nineq, self.nHidden) 150 | h = self.h.unsqueeze(0).expand(nBatch, self.nineq) 151 | e = Variable(torch.Tensor()) 152 | # p = torch.cat((-x, self.lam.unsqueeze(0).expand(nBatch, self.nFeatures-1)), 1) 153 | p = torch.cat((-x, Parameter(13.*torch.ones(nBatch, self.nFeatures-1).cuda())), 1) 154 | x = QPFunction()(Q.double(), p.double(), G.double(), h.double(), e, e).float() 155 | x = x[:,:self.nFeatures] 156 | 157 | return x 158 | -------------------------------------------------------------------------------- /denoising/plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | 5 | import matplotlib as mpl 6 | mpl.use('Agg') 7 | import matplotlib.pyplot as plt 8 | plt.style.use('bmh') 9 | import pandas as pd 10 | import numpy as np 11 | import math 12 | 13 | import os 14 | import sys 15 | import json 16 | import glob 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('workDir', type=str) 21 | args = parser.parse_args() 22 | 23 | trainF = os.path.join(args.workDir, 'train.csv') 24 | testF = os.path.join(args.workDir, 'test.csv') 25 | 26 | trainDf = pd.read_csv(trainF, sep=',') 27 | testDf = pd.read_csv(testF, sep=',') 28 | 29 | plotLoss(trainDf, testDf, args.workDir) 30 | 31 | initDf = os.path.join(args.workDir, 'D.init') 32 | if os.path.exists(initDf): 33 | initD = np.loadtxt(initDf) 34 | latestD = np.loadtxt(os.path.join(args.workDir, 'D.latest')) 35 | plotD(initD, latestD, args.workDir) 36 | 37 | def plotLoss(trainDf, testDf, workDir): 38 | # fig, ax = plt.subplots(1, 1, figsize=(5,2)) 39 | fig, ax = plt.subplots(1, 1) 40 | # fig.tight_layout() 41 | 42 | trainEpoch = trainDf['epoch'].values 43 | trainLoss = trainDf['loss'].values 44 | 45 | N = len(trainEpoch) // math.ceil(trainEpoch[-1]) 46 | trainEpoch_, trainLoss_ = rolling(N, trainEpoch, trainLoss) 47 | plt.plot(trainEpoch_, trainLoss_, label='Train') 48 | # plt.plot(trainEpoch, trainLoss, label='Train') 49 | if not testDf.empty: 50 | plt.plot(testDf['epoch'].values, testDf['loss'].values, label='Test') 51 | plt.xlabel("Epoch") 52 | plt.ylabel("MSE") 53 | plt.xlim(xmin=0) 54 | plt.grid(b=True, which='major', color='k', linestyle='-') 55 | plt.grid(b=True, which='minor', color='k', linestyle='--', alpha=0.2) 56 | plt.legend() 57 | ax.set_yscale('log') 58 | for ext in ['pdf', 'png']: 59 | f = os.path.join(workDir, "loss."+ext) 60 | fig.savefig(f) 61 | print("Created {}".format(f)) 62 | 63 | def plotD(initD, latestD, workDir): 64 | def p(D, fname): 65 | plt.clf() 66 | lim = max(np.abs(np.min(D)), np.abs(np.max(D))) 67 | clim = (-lim, lim) 68 | plt.imshow(D, cmap='bwr', interpolation='nearest', clim=clim) 69 | plt.colorbar() 70 | plt.savefig(os.path.join(workDir, fname)) 71 | 72 | p(initD, 'initD.png') 73 | p(latestD, 'latestD.png') 74 | 75 | latestDs = latestD**6 76 | latestDs = latestDs/np.sum(latestDs, axis=1)[:,None] 77 | I = 
np.argsort(latestDs.dot(np.arange(latestDs.shape[1]))) 78 | latestDs = latestD[I] 79 | initDs = initD[I] 80 | 81 | p(initDs, 'initD_sorted.png') 82 | p(latestDs, 'latestD_sorted.png') 83 | 84 | # Dcombined = np.concatenate((initDs, np.zeros((initD.shape[0], 10)), latestDs), axis=1) 85 | # p(Dcombined, 'Dcombined.png') 86 | 87 | def rolling(N, i, loss): 88 | i_ = i[N-1:] 89 | K = np.full(N, 1./N) 90 | loss_ = np.convolve(loss, K, 'valid') 91 | return i_, loss_ 92 | 93 | if __name__ == '__main__': 94 | main() 95 | -------------------------------------------------------------------------------- /denoising/run-exps.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd $(dirname $0) 4 | 5 | runLearnD() { 6 | Dpenalty=$1 7 | ./main.py --nEpoch 50 optnet --learnD --Dpenalty $Dpenalty & 8 | sleep 4s 9 | } 10 | 11 | runLearnD 0.01 12 | runLearnD 0.1 13 | # runLearnD 0.5 14 | # runLearnD 1.0 15 | # runLearnD 10.0 16 | 17 | # ./main.py --nEpoch 50 optnet --tvInit & 18 | 19 | # runRelu() { 20 | # NHIDDEN=$1 21 | # OTHER=$2 22 | # ./main.py --nEpoch 50 relu --nHidden $NHIDDEN $OTHER & 23 | # sleep 4 24 | # } 25 | 26 | # runRelu 100 27 | # runRelu 100 '--bn' 28 | # runRelu 200 29 | # runRelu 200 '--bn' 30 | # runRelu 500 31 | # runRelu 500 '--bn' 32 | # runRelu 1000 33 | # runRelu 1000 '--bn' 34 | -------------------------------------------------------------------------------- /profile/optnet-forward.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import sys 5 | 6 | import numpy as np 7 | import numpy.random as npr 8 | 9 | import adact 10 | import adact_forward_ip as aip 11 | 12 | import itertools 13 | import time 14 | 15 | import torch 16 | 17 | import sys 18 | from IPython.core import ultratb 19 | sys.excepthook = ultratb.FormattedTB(mode='Verbose', 20 | color_scheme='Linux', call_pdb=1) 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--nTrials', type=int, default=5) 25 | parser.add_argument('--nBatch', type=int, default=128) 26 | args = parser.parse_args() 27 | 28 | npr.seed(0) 29 | 30 | # print('==== CPU ===\n') 31 | # prof(args, False) 32 | 33 | print('\n\n==== GPU ===\n') 34 | prof(args, True) 35 | 36 | def prof(args, cuda): 37 | print('| nz | neq | nineq | single | batched |') 38 | print('|-------+-------+-------+--------+---------|') 39 | for nz,neq,nineq in itertools.product([10,100,200], [0,10,50], [10,50]): 40 | if nz >= neq and nz >= nineq: 41 | times = [] 42 | for i in range(args.nTrials): 43 | times.append(prof_instance(nz, neq, nineq, args.nBatch, cuda)) 44 | times = np.array(times) 45 | cp, pdipm = times.mean(axis=0) 46 | cp_sd, pdipm_sd = times.std(axis=0) 47 | print("| {:5d} | {:5d} | {:5d} | {:.3f} +/- {:.3f} | {:.3f} +/- {:.3f} |".format( 48 | nz, neq, nineq, cp, cp_sd, pdipm, pdipm_sd)) 49 | 50 | def prof_instance(nz, neq, nineq, nBatch, cuda): 51 | L = np.tril(npr.uniform(0,1, (nz,nz))) + np.eye(nz,nz) 52 | G = npr.randn(nineq,nz) 53 | A = npr.randn(neq,nz) 54 | z0 = npr.randn(nz) 55 | s0 = np.ones(nineq) 56 | p = npr.randn(nBatch,nz) 57 | 58 | p, L, G, A, z0, s0 = [torch.Tensor(x) for x in [p, L, G, A, z0, s0]] 59 | Q = torch.mm(L, L.t())+0.001*torch.eye(nz).type_as(L) 60 | if cuda: 61 | p, L, Q, G, A, z0, s0 = [x.cuda() for x in [p, L, Q, G, A, z0, s0]] 62 | b = torch.mv(A, z0) if neq > 0 else None 63 | h = torch.mv(G, z0)+s0 64 | 65 | af = adact.AdactFunction() 66 | 67 | single_results = [] 
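# First pass: solve the nBatch QPs one at a time, reusing a single
# KKT pre-factorization across all of the single-example solves.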
68 | start = time.time() 69 | U_Q, U_S, R = aip.pre_factor_kkt(Q, G, A) 70 | for i in range(nBatch): 71 | single_results.append(aip.forward_single(p[i], Q, G, A, b, h, U_Q, U_S, R)) 72 | single_time = time.time()-start 73 | 74 | start = time.time() 75 | Q_LU, S_LU, R = aip.pre_factor_kkt_batch(Q, G, A, nBatch) 76 | zhat_b, nu_b, lam_b = aip.forward_batch(p, Q, G, A, b, h, Q_LU, S_LU, R) 77 | batched_time = time.time()-start 78 | 79 | zhat_diff = (single_results[0][0] - zhat_b[0]).norm() 80 | lam_diff = (single_results[0][2] - lam_b[0]).norm() 81 | eps = 0.1 # Pretty relaxed. 82 | if zhat_diff > eps or lam_diff > eps: 83 | print('===========') 84 | print("Warning: Single and batched solutions might not match.") 85 | print(" + zhat_diff: {}".format(zhat_diff)) 86 | print(" + lam_diff: {}".format(lam_diff)) 87 | print(" + (nz, neq, nineq, nBatch) = ({}, {}, {}, {})".format( 88 | nz, neq, nineq, nBatch)) 89 | print('===========') 90 | 91 | return single_time, batched_time 92 | 93 | if __name__=='__main__': 94 | main() 95 | -------------------------------------------------------------------------------- /profile/optnet-single.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import sys 5 | 6 | import numpy as np 7 | import numpy.random as npr 8 | 9 | import adact 10 | import adact_forward_ip as aip 11 | 12 | import itertools 13 | import time 14 | 15 | import torch 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--nTrials', type=int, default=5) 20 | parser.add_argument('--nBatch', type=int, default=128) 21 | args = parser.parse_args() 22 | 23 | npr.seed(0) 24 | 25 | print('==== CPU ===\n') 26 | prof(args, False) 27 | 28 | print('\n\n==== GPU ===\n') 29 | prof(args, True) 30 | 31 | def prof(args, cuda): 32 | print('| nz | neq | nineq | cvxpy | pdipm |') 33 | print('|-------+-------+-------+-------+-------|') 34 | for nz,neq,nineq in itertools.product([10,100,200], [0,10,50], [10,50]): 35 | if nz >= neq and nz >= nineq: 36 | times = [] 37 | for i in range(args.nTrials): 38 | times.append(prof_instance(nz, neq, nineq, args.nBatch, cuda)) 39 | times = np.array(times) 40 | cp, pdipm = times.mean(axis=0) 41 | cp_sd, pdipm_sd = times.std(axis=0) 42 | print("| {:5d} | {:5d} | {:5d} | {:.3f} +/- {:.3f} | {:.3f} +/- {:.3f} |".format( 43 | nz, neq, nineq, cp, cp_sd, pdipm, pdipm_sd)) 44 | 45 | def prof_instance(nz, neq, nineq, nIter, cuda): 46 | L = np.tril(npr.uniform(0,1, (nz,nz))) + np.eye(nz,nz) 47 | G = npr.randn(nineq,nz) 48 | A = npr.randn(neq,nz) 49 | z0 = npr.randn(nz) 50 | s0 = np.ones(nineq) 51 | p = npr.randn(nz) 52 | 53 | p, L, G, A, z0, s0 = [torch.Tensor(x) for x in [p, L, G, A, z0, s0]] 54 | Q = torch.mm(L, L.t())+0.001*torch.eye(nz).type_as(L) 55 | if cuda: 56 | p, L, Q, G, A, z0, s0 = [x.cuda() for x in [p, L, Q, G, A, z0, s0]] 57 | 58 | af = adact.AdactFunction() 59 | 60 | start = time.time() 61 | # One-time cost for numpy conversion. 
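# (The conversion time is charged to the first (cvxpy) column so that
# both timings start from the same torch tensors.)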
62 | p_np, L_np, G_np, A_np, z0_np, s0_np = [adact.toNp(v) for v in [p, L, G, A, z0, s0]] 63 | cp = time.time()-start 64 | for i in range(nIter): 65 | start = time.time() 66 | zhat, nu, lam = af.forward_single_np(p_np, L_np, G_np, A_np, z0_np, s0_np) 67 | cp += time.time()-start 68 | 69 | b = torch.mv(A, z0) if neq > 0 else None 70 | h = torch.mv(G, z0)+s0 71 | L_Q, L_S, R = aip.pre_factor_kkt(Q, G, A, nineq, neq) 72 | pdipm = [] 73 | for i in range(nIter): 74 | start = time.time() 75 | zhat_ip, nu_ip, lam_ip = aip.forward_single(p, Q, G, A, b, h, L_Q, L_S, R) 76 | pdipm.append(time.time()-start) 77 | return cp, np.sum(pdipm) 78 | 79 | if __name__=='__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /sudoku/create.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Some portions from: https://www.ocf.berkeley.edu/~arel/sudoku/main.html 4 | 5 | import argparse 6 | import numpy as np 7 | import numpy.random as npr 8 | import torch 9 | 10 | from tqdm import tqdm 11 | 12 | import os, sys 13 | import shutil 14 | 15 | import random, copy 16 | 17 | import matplotlib as mpl 18 | mpl.use('Agg') 19 | import matplotlib.pyplot as plt 20 | plt.style.use('bmh') 21 | 22 | import sys 23 | from IPython.core import ultratb 24 | sys.excepthook = ultratb.FormattedTB(mode='Verbose', 25 | color_scheme='Linux', call_pdb=1) 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--boardSz', type=int, default=2) 30 | parser.add_argument('--nSamples', type=int, default=10000) 31 | parser.add_argument('--data', type=str, default='data') 32 | args = parser.parse_args() 33 | 34 | npr.seed(0) 35 | 36 | save = os.path.join(args.data, str(args.boardSz)) 37 | if os.path.isdir(save): 38 | shutil.rmtree(save) 39 | os.makedirs(save) 40 | 41 | X = [] 42 | Y = [] 43 | for i in tqdm(range(args.nSamples)): 44 | Xi, Yi = sample(args) 45 | X.append(Xi) 46 | Y.append(Yi) 47 | 48 | X = np.array(X) 49 | Y = np.array(Y) 50 | 51 | for loc,arr in (('features.pt', X), ('labels.pt', Y)): 52 | fname = os.path.join(save, loc) 53 | with open(fname, 'wb') as f: 54 | torch.save(torch.Tensor(arr), f) 55 | print("Created {}".format(fname)) 56 | 57 | def sample(args): 58 | solution = construct_puzzle_solution(args.boardSz) 59 | Nsq = args.boardSz*args.boardSz 60 | nKeep = npr.randint(0, Nsq) 61 | board, nKept = pluck(copy.deepcopy(solution), nKeep) 62 | solution = toOneHot(solution) 63 | board = toOneHot(board) 64 | return board, solution 65 | 66 | def toOneHot(X): 67 | X = np.array(X) 68 | Nsq = X.shape[0] 69 | Y = np.zeros((Nsq, Nsq, Nsq)) 70 | for i in range(1,Nsq+1): 71 | Y[:,:,i-1][X == i] = 1.0 72 | return Y 73 | 74 | def construct_puzzle_solution(N): 75 | """ 76 | Randomly arrange numbers in a grid while making all rows, columns and 77 | squares (sub-grids) contain the numbers 1 through Nsq. 78 | 79 | For example, "sample" (above) could be the output of this function. """ 80 | # Loop until we're able to fill all N^4 cells with numbers, while 81 | # satisfying the constraints above. 
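# For example, with N=2 one valid 4x4 solution is
#   [[1,2,3,4],
#    [3,4,1,2],
#    [2,1,4,3],
#    [4,3,2,1]],
# where every row, column, and 2x2 block contains each of 1..4 exactly once.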
82 | Nsq = N*N 83 | while True: 84 | try: 85 | puzzle = [[0]*Nsq for i in range(Nsq)] # start with blank puzzle 86 | rows = [set(range(1,Nsq+1)) for i in range(Nsq)] # set of available 87 | columns = [set(range(1,Nsq+1)) for i in range(Nsq)] # numbers for each 88 | squares = [set(range(1,Nsq+1)) for i in range(Nsq)] # row, column and square 89 | for i in range(Nsq): 90 | for j in range(Nsq): 91 | # pick a number for cell (i,j) from the set of remaining available numbers 92 | choices = rows[i].intersection(columns[j]).intersection( 93 | squares[(i//N)*N + j//N]) 94 | choice = random.choice(list(choices)) 95 | 96 | puzzle[i][j] = choice 97 | 98 | rows[i].discard(choice) 99 | columns[j].discard(choice) 100 | squares[(i//N)*N + j//N].discard(choice) 101 | 102 | # success! every cell is filled. 103 | return puzzle 104 | 105 | except IndexError: 106 | # if there is an IndexError, we have worked ourselves into a corner (we just start over) 107 | pass 108 | 109 | def pluck(puzzle, nKeep=0): 110 | """ 111 | Randomly pluck out cells (numbers) from the solved puzzle grid until at most nKeep 112 | remain, ensuring that any plucked number can still be deduced from the remaining cells. 113 | 114 | For deduction to be possible, each other cell in the plucked number's row, column, 115 | or square must not be able to contain that number. """ 116 | 117 | Nsq = len(puzzle) 118 | N = int(np.sqrt(Nsq)) 119 | 120 | 121 | def canBeA(puz, i, j, c): 122 | """ 123 | Answers the question: can the cell (i,j) in the puzzle "puz" contain the number 124 | in cell "c"? """ 125 | v = puz[c//Nsq][c%Nsq] 126 | if puz[i][j] == v: return True 127 | if puz[i][j] in range(1,Nsq+1): return False 128 | 129 | for m in range(Nsq): # test row, col, square 130 | # if not the cell itself, and the mth cell of the group contains the value v, then "no" 131 | if not (m==c//Nsq and j==c%Nsq) and puz[m][j] == v: return False 132 | if not (i==c//Nsq and m==c%Nsq) and puz[i][m] == v: return False 133 | if not ((i//N)*N + m//N==c//Nsq and (j//N)*N + m%N==c%Nsq) \ 134 | and puz[(i//N)*N + m//N][(j//N)*N + m%N] == v: 135 | return False 136 | 137 | return True 138 | 139 | 140 | """ 141 | Starts with a set of all N^4 cells, and tries to remove one (randomly) at a time, 142 | but not before checking that the cell can still be deduced from the remaining cells. """ 143 | cells = set(range(Nsq*Nsq)) 144 | cellsleft = cells.copy() 145 | while len(cells) > nKeep and len(cellsleft): 146 | cell = random.choice(list(cellsleft)) # choose a cell from ones we haven't tried 147 | cellsleft.discard(cell) # record that we are trying this cell 148 | 149 | # row, col and square record whether another cell in those groups could also take 150 | # on the value we are trying to pluck. (If another cell can, then we can't use the 151 | # group to deduce this value.) If all three groups are True, then we cannot pluck 152 | # this cell and must try another one. 153 | row = col = square = False 154 | 155 | for i in range(Nsq): 156 | if i != cell//Nsq: 157 | if canBeA(puzzle, i, cell%Nsq, cell): row = True 158 | if i != cell%Nsq: 159 | if canBeA(puzzle, cell//Nsq, i, cell): col = True 160 | if not (((cell//Nsq)//N)*N + i//N == cell//Nsq and ((cell%Nsq)//N)*N + i%N == cell%Nsq): 161 | if canBeA(puzzle, ((cell//Nsq)//N)*N + i//N, 162 | ((cell%Nsq)//N)*N + i%N, cell): square = True 163 | 164 | if row and col and square: 165 | continue # could not pluck this cell, try again. 166 | else: 167 | # this is a pluckable cell!
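# A cell index encodes (row, col) as row*Nsq + col, so cell//Nsq is the
# row and cell%Nsq is the column.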
168 | puzzle[cell//Nsq][cell%Nsq] = 0 # 0 denotes a blank cell 169 | cells.discard(cell) # remove from the set of visible cells (pluck it) 170 | # we don't need to reset "cellsleft" because if a cell was not pluckable 171 | # earlier, then it will still not be pluckable now (with less information 172 | # on the board). 173 | 174 | # This is the puzzle we found, in all its glory. 175 | return (puzzle, len(cells)) 176 | 177 | if __name__=='__main__': 178 | main() 179 | -------------------------------------------------------------------------------- /sudoku/data/2/features.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/optnet/af53140bfcace770f09d078bcd316f000fd0b854/sudoku/data/2/features.pt -------------------------------------------------------------------------------- /sudoku/data/2/labels.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/optnet/af53140bfcace770f09d078bcd316f000fd0b854/sudoku/data/2/labels.pt -------------------------------------------------------------------------------- /sudoku/data/3/features.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/optnet/af53140bfcace770f09d078bcd316f000fd0b854/sudoku/data/3/features.pt -------------------------------------------------------------------------------- /sudoku/data/3/labels.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/optnet/af53140bfcace770f09d078bcd316f000fd0b854/sudoku/data/3/labels.pt -------------------------------------------------------------------------------- /sudoku/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from itertools import product 5 | 6 | import scipy.sparse as spa 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import Module 11 | import torch.optim as optim 12 | 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from torch.nn.parameter import Parameter 16 | 17 | import torchvision.datasets as dset 18 | import torchvision.transforms as transforms 19 | from torch.utils.data import DataLoader 20 | 21 | import cvxpy as cp 22 | 23 | from block import block 24 | 25 | from qpth.qp import SpQPFunction, QPFunction 26 | 27 | try: 28 | from osqpth.osqpth import OSQP, DiffModes 29 | except: 30 | pass 31 | 32 | 33 | class FC(nn.Module): 34 | def __init__(self, nFeatures, nHidden, bn=False): 35 | super().__init__() 36 | self.bn = bn 37 | 38 | fcs = [] 39 | prevSz = nFeatures 40 | for sz in nHidden: 41 | fc = nn.Linear(prevSz, sz) 42 | prevSz = sz 43 | fcs.append(fc) 44 | for sz in list(reversed(nHidden))+[nFeatures]: 45 | fc = nn.Linear(prevSz, sz) 46 | prevSz = sz 47 | fcs.append(fc) 48 | self.fcs = nn.ModuleList(fcs) 49 | 50 | def __call__(self, x): 51 | nBatch = x.size(0) 52 | Nsq = x.size(1) 53 | in_x = x 54 | x = x.view(nBatch, -1) 55 | 56 | for fc in self.fcs: 57 | x = F.relu(fc(x)) 58 | 59 | x = x.view_as(in_x) 60 | ex = x.exp() 61 | exs = ex.sum(3).expand(nBatch, Nsq, Nsq, Nsq) 62 | x = ex/exs 63 | 64 | return x 65 | 66 | class Conv(nn.Module): 67 | def __init__(self, boardSz): 68 | super().__init__() 69 | 70 | self.boardSz = boardSz 71 | 72 | convs = [] 73 | Nsq = boardSz**2 74 | prevSz = Nsq 75 | szs = [512]*10 + [Nsq] 76 | for sz in szs: 77 | conv = 
nn.Conv2d(prevSz, sz, kernel_size=3, padding=1) 78 | convs.append(conv) 79 | prevSz = sz 80 | 81 | self.convs = nn.ModuleList(convs) 82 | 83 | def __call__(self, x): 84 | nBatch = x.size(0) 85 | Nsq = x.size(1) 86 | 87 | for i in range(len(self.convs)-1): 88 | x = F.relu(self.convs[i](x)) 89 | x = self.convs[-1](x) 90 | 91 | ex = x.exp() 92 | exs = ex.sum(3).expand(nBatch, Nsq, Nsq, Nsq) 93 | x = ex/exs 94 | 95 | return x 96 | 97 | def get_sudoku_matrix(n): 98 | X = np.array([[cp.Variable(n**2) for i in range(n**2)] for j in range(n**2)]) 99 | cons = ([x >= 0 for row in X for x in row] + 100 | [cp.sum(x) == 1 for row in X for x in row] + 101 | [sum(row) == np.ones(n**2) for row in X] + 102 | [sum([row[i] for row in X]) == np.ones(n**2) for i in range(n**2)] + 103 | [sum([sum(row[i:i+n]) for row in X[j:j+n]]) == np.ones(n**2) for i in range(0,n**2,n) for j in range(0, n**2, n)]) 104 | f = sum([cp.sum(x) for row in X for x in row]) 105 | prob = cp.Problem(cp.Minimize(f), cons) 106 | 107 | A = np.asarray(prob.get_problem_data(cp.ECOS)[0]["A"].todense()) 108 | A0 = [A[0]] 109 | rank = 1 110 | for i in range(1,A.shape[0]): 111 | if np.linalg.matrix_rank(A0+[A[i]], tol=1e-12) > rank: 112 | A0.append(A[i]) 113 | rank += 1 114 | 115 | return np.array(A0) 116 | 117 | 118 | class OptNetEq(nn.Module): 119 | def __init__(self, n, Qpenalty, qp_solver, trueInit=False): 120 | super().__init__() 121 | 122 | self.qp_solver = qp_solver 123 | 124 | nx = (n**2)**3 125 | self.Q = Variable(Qpenalty*torch.eye(nx).double().cuda()) 126 | self.Q_idx = spa.csc_matrix(self.Q.detach().cpu().numpy()).nonzero() 127 | 128 | self.G = Variable(-torch.eye(nx).double().cuda()) 129 | self.h = Variable(torch.zeros(nx).double().cuda()) 130 | t = get_sudoku_matrix(n) 131 | 132 | if trueInit: 133 | self.A = Parameter(torch.DoubleTensor(t).cuda()) 134 | else: 135 | self.A = Parameter(torch.rand(t.shape).double().cuda()) 136 | self.log_z0 = Parameter(torch.zeros(nx).double().cuda()) 137 | # self.b = Variable(torch.ones(self.A.size(0)).double().cuda()) 138 | 139 | if self.qp_solver == 'osqpth': 140 | t = torch.cat((self.A, self.G), dim=0) 141 | self.AG_idx = spa.csc_matrix(t.detach().cpu().numpy()).nonzero() 142 | 143 | # @profile 144 | def forward(self, puzzles): 145 | nBatch = puzzles.size(0) 146 | 147 | p = -puzzles.view(nBatch, -1) 148 | b = self.A.mv(self.log_z0.exp()) 149 | 150 | if self.qp_solver == 'qpth': 151 | y = QPFunction(verbose=-1)( 152 | self.Q, p.double(), self.G, self.h, self.A, b 153 | ).float().view_as(puzzles) 154 | elif self.qp_solver == 'osqpth': 155 | _l = torch.cat( 156 | (b, torch.full(self.h.shape, float('-inf'), 157 | device=self.h.device, dtype=self.h.dtype)), 158 | dim=0) 159 | _u = torch.cat((b, self.h), dim=0) 160 | Q_data = self.Q[self.Q_idx[0], self.Q_idx[1]] 161 | 162 | AG = torch.cat((self.A, self.G), dim=0) 163 | AG_data = AG[self.AG_idx[0], self.AG_idx[1]] 164 | y = OSQP(self.Q_idx, self.Q.shape, self.AG_idx, AG.shape, 165 | diff_mode=DiffModes.FULL)( 166 | Q_data, p.double(), AG_data, _l, _u).float().view_as(puzzles) 167 | else: 168 | assert False 169 | 170 | return y 171 | 172 | 173 | class SpOptNetEq(nn.Module): 174 | def __init__(self, n, Qpenalty, trueInit=False): 175 | super().__init__() 176 | nx = (n**2)**3 177 | self.nx = nx 178 | 179 | spTensor = torch.cuda.sparse.DoubleTensor 180 | iTensor = torch.cuda.LongTensor 181 | dTensor = torch.cuda.DoubleTensor 182 | 183 | self.Qi = iTensor([range(nx), range(nx)]) 184 | self.Qv = Variable(dTensor(nx).fill_(Qpenalty)) 185 | self.Qsz = 
torch.Size([nx, nx]) 186 | 187 | self.Gi = iTensor([range(nx), range(nx)]) 188 | self.Gv = Variable(dTensor(nx).fill_(-1.0)) 189 | self.Gsz = torch.Size([nx, nx]) 190 | self.h = Variable(torch.zeros(nx).double().cuda()) 191 | 192 | t = get_sudoku_matrix(n) 193 | neq = t.shape[0] 194 | if trueInit: 195 | I = t != 0 196 | self.Av = Parameter(dTensor(t[I])) 197 | Ai_np = np.nonzero(t) 198 | self.Ai = torch.stack((torch.LongTensor(Ai_np[0]), 199 | torch.LongTensor(Ai_np[1]))).cuda() 200 | self.Asz = torch.Size([neq, nx]) 201 | else: 202 | # TODO: This is very dense: 203 | self.Ai = torch.stack((iTensor(list(range(neq))).unsqueeze(1).repeat(1, nx).view(-1), 204 | iTensor(list(range(nx))).repeat(neq))) 205 | self.Av = Parameter(dTensor(neq*nx).uniform_()) 206 | self.Asz = torch.Size([neq, nx]) 207 | self.b = Variable(torch.ones(neq).double().cuda()) 208 | 209 | def forward(self, puzzles): 210 | nBatch = puzzles.size(0) 211 | 212 | p = -puzzles.view(nBatch,-1).double() 213 | 214 | return SpQPFunction( 215 | self.Qi, self.Qsz, self.Gi, self.Gsz, self.Ai, self.Asz, verbose=-1)( 216 | self.Qv.expand(nBatch, self.Qv.size(0)), 217 | p, 218 | self.Gv.expand(nBatch, self.Gv.size(0)), 219 | self.h.expand(nBatch, self.h.size(0)), 220 | self.Av.expand(nBatch, self.Av.size(0)), 221 | self.b.expand(nBatch, self.b.size(0)) 222 | ).float().view_as(puzzles) 223 | 224 | 225 | class OptNetIneq(nn.Module): 226 | def __init__(self, n, Qpenalty, nineq): 227 | super().__init__() 228 | nx = (n**2)**3 229 | self.Q = Variable(Qpenalty*torch.eye(nx).double().cuda()) 230 | self.G1 = Variable(-torch.eye(nx).double().cuda()) 231 | self.h1 = Variable(torch.zeros(nx).double().cuda()) 232 | # if trueInit: 233 | # self.A = Parameter(torch.DoubleTensor(get_sudoku_matrix(n)).cuda()) 234 | # else: 235 | # # t = get_sudoku_matrix(n) 236 | # # self.A = Parameter(torch.rand(t.shape).double().cuda()) 237 | # # import IPython, sys; IPython.embed(); sys.exit(-1) 238 | self.A = Parameter(torch.rand(50,nx).double().cuda()) 239 | self.G2 = Parameter(torch.Tensor(128, nx).uniform_(-1,1).double().cuda()) 240 | self.z2 = Parameter(torch.zeros(nx).double().cuda()) 241 | self.s2 = Parameter(torch.ones(128).double().cuda()) 242 | # self.b = Variable(torch.ones(self.A.size(0)).double().cuda()) 243 | 244 | def forward(self, puzzles): 245 | nBatch = puzzles.size(0) 246 | 247 | p = -puzzles.view(nBatch,-1) 248 | 249 | h2 = self.G2.mv(self.z2)+self.s2 250 | G = torch.cat((self.G1, self.G2), 0) 251 | h = torch.cat((self.h1, h2), 0) 252 | e = Variable(torch.Tensor()) 253 | 254 | return QPFunction(verbose=False)( 255 | self.Q, p.double(), G, h, e, e 256 | ).float().view_as(puzzles) 257 | 258 | class OptNetLatent(nn.Module): 259 | def __init__(self, n, Qpenalty, nLatent, nineq, trueInit=False): 260 | super().__init__() 261 | nx = (n**2)**3 262 | self.fc_in = nn.Linear(nx, nLatent) 263 | self.Q = Variable(Qpenalty*torch.eye(nLatent).cuda()) 264 | self.G = Parameter(torch.Tensor(nineq, nLatent).uniform_(-1,1).cuda()) 265 | self.z = Parameter(torch.zeros(nLatent).cuda()) 266 | self.s = Parameter(torch.ones(nineq).cuda()) 267 | self.fc_out = nn.Linear(nLatent, nx) 268 | 269 | def forward(self, puzzles): 270 | nBatch = puzzles.size(0) 271 | 272 | x = puzzles.view(nBatch,-1) 273 | x = self.fc_in(x) 274 | 275 | e = Variable(torch.Tensor()) 276 | 277 | h = self.G.mv(self.z)+self.s 278 | x = QPFunction(verbose=False)( 279 | self.Q, x, self.G, h, e, e, 280 | ) 281 | 282 | x = self.fc_out(x) 283 | x = x.view_as(puzzles) 284 | return x 285 | 286 | 287 | # if 
__name__=="__main__": 288 | # sudoku = SolveSudoku(2, 0.2) 289 | # puzzle = [[4, 0, 0, 0], [0,0,4,0], [0,2,0,0], [0,0,0,1]] 290 | # Y = Variable(torch.DoubleTensor(np.array([[np.array(np.eye(5,4,-1)[i,:]) for i in row] for row in puzzle])).cuda()) 291 | # solution = sudoku(Y.unsqueeze(0)) 292 | # print(solution.view(1,4,4,4)) 293 | -------------------------------------------------------------------------------- /sudoku/plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | 5 | import matplotlib as mpl 6 | mpl.use('Agg') 7 | import matplotlib.pyplot as plt 8 | plt.style.use('bmh') 9 | import pandas as pd 10 | import numpy as np 11 | import math 12 | 13 | import os 14 | import sys 15 | import json 16 | import glob 17 | 18 | def main(): 19 | # import sys 20 | # from IPython.core import ultratb 21 | # sys.excepthook = ultratb.FormattedTB(mode='Verbose', 22 | # color_scheme='Linux', call_pdb=1) 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('workDir', type=str) 26 | args = parser.parse_args() 27 | 28 | trainF = os.path.join(args.workDir, 'train.csv') 29 | testF = os.path.join(args.workDir, 'test.csv') 30 | 31 | trainDf = pd.read_csv(trainF, sep=',') 32 | testDf = pd.read_csv(testF, sep=',') 33 | 34 | plotLoss(trainDf, testDf, args.workDir) 35 | plotErr(trainDf, testDf, args.workDir) 36 | 37 | initDf = os.path.join(args.workDir, 'D.init') 38 | if os.path.exists(initDf): 39 | initD = np.loadtxt(initDf) 40 | latestD = np.loadtxt(os.path.join(args.workDir, 'D.latest')) 41 | plotD(initD, latestD, args.workDir) 42 | 43 | loss_fname = os.path.join(args.workDir, 'loss.png') 44 | err_fname = os.path.join(args.workDir, 'err.png') 45 | loss_err_fname = os.path.join(args.workDir, 'loss-error.png') 46 | os.system('convert +append "{}" "{}" "{}"'.format(loss_fname, err_fname, loss_err_fname)) 47 | print('Created {}'.format(loss_err_fname)) 48 | 49 | def plotLoss(trainDf, testDf, workDir): 50 | # fig, ax = plt.subplots(1, 1, figsize=(5,2)) 51 | fig, ax = plt.subplots(1, 1) 52 | # fig.tight_layout() 53 | 54 | trainEpoch = trainDf['epoch'].values 55 | trainLoss = trainDf['loss'].values 56 | 57 | N = np.argmax(trainEpoch==1.0) 58 | trainEpoch = trainEpoch[N-1:] 59 | trainLoss = np.convolve(trainLoss, np.full(N, 1./N), mode='valid') 60 | plt.plot(trainEpoch, trainLoss, label='Train') 61 | if not testDf.empty: 62 | plt.plot(testDf['epoch'].values, testDf['loss'].values, label='Test') 63 | plt.xlabel("Epoch") 64 | plt.ylabel("MSE") 65 | plt.xlim(xmin=0) 66 | plt.grid(b=True, which='major', color='k', linestyle='-') 67 | plt.grid(b=True, which='minor', color='k', linestyle='--', alpha=0.2) 68 | plt.legend() 69 | # ax.set_yscale('log') 70 | ax.set_ylim(0, None) 71 | for ext in ['pdf', 'png']: 72 | f = os.path.join(workDir, "loss."+ext) 73 | fig.savefig(f) 74 | print("Created {}".format(f)) 75 | 76 | def plotErr(trainDf, testDf, workDir): 77 | # fig, ax = plt.subplots(1, 1, figsize=(5,2)) 78 | fig, ax = plt.subplots(1, 1) 79 | # fig.tight_layout() 80 | 81 | trainEpoch = trainDf['epoch'].values 82 | trainLoss = trainDf['err'].values 83 | 84 | N = np.argmax(trainEpoch==1.0) 85 | trainEpoch = trainEpoch[N-1:] 86 | trainLoss = np.convolve(trainLoss, np.full(N, 1./N), mode='valid') 87 | plt.plot(trainEpoch, trainLoss, label='Train') 88 | if not testDf.empty: 89 | plt.plot(testDf['epoch'].values, testDf['err'].values, label='Test') 90 | plt.xlabel("Epoch") 91 | plt.ylabel("Error") 92 | plt.xlim(xmin=0) 93 | 
plt.grid(b=True, which='major', color='k', linestyle='-') 94 | plt.grid(b=True, which='minor', color='k', linestyle='--', alpha=0.2) 95 | plt.legend() 96 | # ax.set_yscale('log') 97 | ax.set_ylim(0, None) 98 | for ext in ['pdf', 'png']: 99 | f = os.path.join(workDir, "err."+ext) 100 | fig.savefig(f) 101 | print("Created {}".format(f)) 102 | 103 | def plotD(initD, latestD, workDir): 104 | def p(D, fname): 105 | plt.clf() 106 | lim = max(np.abs(np.min(D)), np.abs(np.max(D))) 107 | clim = (-lim, lim) 108 | plt.imshow(D, cmap='bwr', interpolation='nearest', clim=clim) 109 | plt.colorbar() 110 | plt.savefig(os.path.join(workDir, fname)) 111 | 112 | p(initD, 'initD.png') 113 | p(latestD, 'latestD.png') 114 | 115 | latestDs = latestD**6 116 | latestDs = latestDs/np.sum(latestDs, axis=1)[:,None] 117 | I = np.argsort(latestDs.dot(np.arange(latestDs.shape[1]))) 118 | latestDs = latestD[I] 119 | initDs = initD[I] 120 | 121 | p(initDs, 'initD_sorted.png') 122 | p(latestDs, 'latestD_sorted.png') 123 | 124 | # Dcombined = np.concatenate((initDs, np.zeros((initD.shape[0], 10)), latestDs), axis=1) 125 | # p(Dcombined, 'Dcombined.png') 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /sudoku/prof-sparse.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import csv 5 | import os 6 | import shutil 7 | from tqdm import tqdm 8 | import time 9 | 10 | try: import setGPU 11 | except ImportError: pass 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | from torch.autograd import Variable 17 | 18 | import numpy as np 19 | import numpy.random as npr 20 | 21 | import sys 22 | 23 | import matplotlib as mpl 24 | mpl.use('Agg') 25 | import matplotlib.pyplot as plt 26 | plt.style.use('bmh') 27 | 28 | import setproctitle 29 | 30 | import models 31 | 32 | import sys 33 | from IPython.core import ultratb 34 | sys.excepthook = ultratb.FormattedTB(mode='Verbose', 35 | color_scheme='Linux', call_pdb=1) 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('--no-cuda', action='store_true') 40 | parser.add_argument('--nTrials', type=int, default=5) 41 | # parser.add_argument('--boardSz', type=int, default=2) 42 | # parser.add_argument('--batchSz', type=int, default=150) 43 | parser.add_argument('--Qpenalty', type=float, default=0.1) 44 | args = parser.parse_args() 45 | 46 | args.cuda = not args.no_cuda and torch.cuda.is_available() 47 | setproctitle.setproctitle('bamos.sudoku.prof-sparse') 48 | 49 | print('=== nTrials: {}'.format(args.nTrials)) 50 | print('| {:8s} | {:8s} | {:21s} | {:21s} |'.format( 51 | 'boardSz', 'batchSz', 'dense forward (s)', 'sparse forward (s)')) 52 | for boardSz in [2,3]: 53 | with open('data/{}/features.pt'.format(boardSz), 'rb') as f: 54 | X = torch.load(f) 55 | with open('data/{}/labels.pt'.format(boardSz), 'rb') as f: 56 | Y = torch.load(f) 57 | N, nFeatures = X.size(0), int(np.prod(X.size()[1:])) 58 | 59 | for batchSz in [1, 64, 128]: 60 | dmodel = models.OptNetEq(boardSz, args.Qpenalty, qp_solver='qpth', trueInit=True) 61 | spmodel = models.SpOptNetEq(boardSz, args.Qpenalty, trueInit=True) 62 | if args.cuda: 63 | dmodel = dmodel.cuda() 64 | spmodel = spmodel.cuda() 65 | 66 | dtimes = [] 67 | sptimes = [] 68 | for i in range(args.nTrials): 69 | Xbatch = Variable(X[i*batchSz:(i+1)*batchSz]) 70 | Ybatch = Variable(Y[i*batchSz:(i+1)*batchSz]) 71 | if args.cuda: 72 | Xbatch =
Xbatch.cuda() 73 | Ybatch = Ybatch.cuda() 74 | 75 | # Make sure buffers are initialized. 76 | dmodel(Xbatch) 77 | spmodel(Xbatch) 78 | 79 | start = time.time() 80 | dmodel(Xbatch) 81 | dtimes.append(time.time()-start) 82 | 83 | start = time.time() 84 | spmodel(Xbatch) 85 | sptimes.append(time.time()-start) 86 | 87 | print('| {:8d} | {:8d} | {:.2e} +/- {:.2e} | {:.2e} +/- {:.2e} |'.format( 88 | boardSz, batchSz, np.mean(dtimes), np.std(dtimes), 89 | np.mean(sptimes), np.std(sptimes))) 90 | 91 | if __name__=='__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /sudoku/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import csv 5 | import os 6 | import shutil 7 | from tqdm import tqdm 8 | 9 | try: import setGPU 10 | except ImportError: pass 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | from torch.autograd import Variable 16 | 17 | import numpy as np 18 | import numpy.random as npr 19 | 20 | import sys 21 | import time 22 | 23 | import matplotlib as mpl 24 | mpl.use('Agg') 25 | import matplotlib.pyplot as plt 26 | plt.style.use('bmh') 27 | 28 | import setproctitle 29 | 30 | import models 31 | 32 | import sys 33 | from IPython.core import ultratb 34 | sys.excepthook = ultratb.FormattedTB(mode='Verbose', 35 | color_scheme='Linux', call_pdb=1) 36 | 37 | def print_header(msg): 38 | print('===>', msg) 39 | 40 | def main(): 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--no-cuda', action='store_true') 43 | parser.add_argument('--boardSz', type=int, default=2) 44 | parser.add_argument('--batchSz', type=int, default=150) 45 | parser.add_argument('--testBatchSz', type=int, default=200) 46 | parser.add_argument('--nEpoch', type=int, default=100) 47 | parser.add_argument('--testPct', type=float, default=0.1) 48 | parser.add_argument('--save', type=str) 49 | parser.add_argument('--work', type=str, default='work') 50 | parser.add_argument('--qp-solver', type=str, default='qpth', 51 | choices=['qpth', 'osqpth']) 52 | subparsers = parser.add_subparsers(dest='model') 53 | subparsers.required = True 54 | fcP = subparsers.add_parser('fc') 55 | fcP.add_argument('--nHidden', type=int, nargs='+', default=[100,100]) 56 | fcP.add_argument('--bn', action='store_true') 57 | convP = subparsers.add_parser('conv') 58 | convP.add_argument('--nHidden', type=int, default=50) 59 | convP.add_argument('--bn', action='store_true') 60 | spOptnetEqP = subparsers.add_parser('spOptnetEq') 61 | spOptnetEqP.add_argument('--Qpenalty', type=float, default=0.1) 62 | optnetEqP = subparsers.add_parser('optnetEq') 63 | optnetEqP.add_argument('--Qpenalty', type=float, default=0.1) 64 | optnetIneqP = subparsers.add_parser('optnetIneq') 65 | optnetIneqP.add_argument('--Qpenalty', type=float, default=0.1) 66 | optnetIneqP.add_argument('--nineq', type=int, default=100) 67 | optnetLatent = subparsers.add_parser('optnetLatent') 68 | optnetLatent.add_argument('--Qpenalty', type=float, default=0.1) 69 | optnetLatent.add_argument('--nLatent', type=int, default=100) 70 | optnetLatent.add_argument('--nineq', type=int, default=100) 71 | args = parser.parse_args() 72 | 73 | args.cuda = not args.no_cuda and torch.cuda.is_available() 74 | t = '{}.{}'.format(args.boardSz, args.model) 75 | if args.model == 'optnetEq' or args.model == 'spOptnetEq': 76 | t += '.Qpenalty={}'.format(args.Qpenalty) 77 | elif args.model == 'optnetIneq': 78 | t
+= '.Qpenalty={}'.format(args.Qpenalty) 79 | t += '.nineq={}'.format(args.nineq) 80 | elif args.model == 'optnetLatent': 81 | t += '.Qpenalty={}'.format(args.Qpenalty) 82 | t += '.nLatent={}'.format(args.nLatent) 83 | t += '.nineq={}'.format(args.nineq) 84 | elif args.model == 'fc': 85 | t += '.nHidden:{}'.format(','.join([str(x) for x in args.nHidden])) 86 | if args.bn: 87 | t += '.bn' 88 | if args.save is None: 89 | args.save = os.path.join(args.work, t) 90 | setproctitle.setproctitle('bamos.sudoku.' + t) 91 | 92 | with open('data/{}/features.pt'.format(args.boardSz), 'rb') as f: 93 | X = torch.load(f) 94 | with open('data/{}/labels.pt'.format(args.boardSz), 'rb') as f: 95 | Y = torch.load(f) 96 | 97 | N, nFeatures = X.size(0), int(np.prod(X.size()[1:])) 98 | 99 | nTrain = int(N*(1.-args.testPct)) 100 | nTest = N-nTrain 101 | 102 | trainX = X[:nTrain] 103 | trainY = Y[:nTrain] 104 | testX = X[nTrain:] 105 | testY = Y[nTrain:] 106 | 107 | assert(nTrain % args.batchSz == 0) 108 | assert(nTest % args.testBatchSz == 0) 109 | 110 | save = args.save 111 | if os.path.isdir(save): 112 | shutil.rmtree(save) 113 | os.makedirs(save) 114 | 115 | npr.seed(1) 116 | 117 | print_header('Building model') 118 | if args.model == 'fc': 119 | nHidden = args.nHidden 120 | model = models.FC(nFeatures, nHidden, args.bn) 121 | elif args.model == 'conv': 122 | model = models.Conv(args.boardSz) 123 | elif args.model == 'optnetEq': 124 | model = models.OptNetEq( 125 | n=args.boardSz, Qpenalty=args.Qpenalty, qp_solver=args.qp_solver, 126 | trueInit=False) 127 | elif args.model == 'spOptnetEq': 128 | model = models.SpOptNetEq(args.boardSz, args.Qpenalty, trueInit=False) 129 | elif args.model == 'optnetIneq': 130 | model = models.OptNetIneq(args.boardSz, args.Qpenalty, args.nineq) 131 | elif args.model == 'optnetLatent': 132 | model = models.OptNetLatent(args.boardSz, args.Qpenalty, args.nLatent, args.nineq) 133 | else: 134 | assert False 135 | 136 | if args.cuda: 137 | model = model.cuda() 138 | 139 | fields = ['epoch', 'loss', 'err'] 140 | trainF = open(os.path.join(save, 'train.csv'), 'w') 141 | trainW = csv.writer(trainF) 142 | trainW.writerow(fields) 143 | trainF.flush() 144 | fields = ['epoch', 'loss', 'err'] 145 | testF = open(os.path.join(save, 'test.csv'), 'w') 146 | testW = csv.writer(testF) 147 | testW.writerow(fields) 148 | testF.flush() 149 | 150 | 151 | if 'optnet' in args.model: 152 | # if args.tvInit: lr = 1e-4 153 | # elif args.learnD: lr = 1e-2 154 | # else: lr = 1e-3 155 | lr = 1e-1 156 | else: 157 | lr = 1e-3 158 | optimizer = optim.Adam(model.parameters(), lr=lr) 159 | 160 | # writeParams(args, model, 'init') 161 | # test(args, 0, model, testF, testW, testX, testY) 162 | for epoch in range(1, args.nEpoch+1): 163 | # update_lr(optimizer, epoch) 164 | train(args, epoch, model, trainF, trainW, trainX, trainY, optimizer) 165 | test(args, epoch, model, testF, testW, testX, testY) 166 | torch.save(model, os.path.join(args.save, 'latest.pth')) 167 | # writeParams(args, model, 'latest') 168 | os.system('./plot.py "{}" &'.format(args.save)) 169 | 170 | def writeParams(args, model, tag): 171 | if args.model == 'optnet': 172 | A = model.A.data.cpu().numpy() 173 | np.savetxt(os.path.join(args.save, 'A.{}'.format(tag)), A) 174 | 175 | # @profile 176 | def train(args, epoch, model, trainF, trainW, trainX, trainY, optimizer): 177 | batchSz = args.batchSz 178 | 179 | batch_data_t = torch.FloatTensor(batchSz, trainX.size(1), trainX.size(2), trainX.size(3)) 180 | batch_targets_t = torch.FloatTensor(batchSz, 
trainY.size(1), trainX.size(2), trainX.size(3)) 181 | if args.cuda: 182 | batch_data_t = batch_data_t.cuda() 183 | batch_targets_t = batch_targets_t.cuda() 184 | batch_data = Variable(batch_data_t, requires_grad=False) 185 | batch_targets = Variable(batch_targets_t, requires_grad=False) 186 | for i in range(0, trainX.size(0), batchSz): 187 | start = time.time() 188 | batch_data.data[:] = trainX[i:i+batchSz] 189 | batch_targets.data[:] = trainY[i:i+batchSz] 190 | # Fixed batch size for debugging: 191 | # batch_data.data[:] = trainX[:batchSz] 192 | # batch_targets.data[:] = trainY[:batchSz] 193 | 194 | optimizer.zero_grad() 195 | preds = model(batch_data) 196 | loss = nn.MSELoss()(preds, batch_targets) 197 | loss.backward() 198 | optimizer.step() 199 | 200 | err = computeErr(preds.data)/batchSz 201 | print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f} Err: {:.4f} Time: {:.2f}s'.format( 202 | epoch, i+batchSz, trainX.size(0), 203 | float(i+batchSz)/trainX.size(0)*100, 204 | loss.item(), err, time.time()-start)) 205 | 206 | trainW.writerow( 207 | (epoch-1+float(i+batchSz)/trainX.size(0), loss.item(), err)) 208 | trainF.flush() 209 | 210 | def test(args, epoch, model, testF, testW, testX, testY): 211 | batchSz = args.testBatchSz 212 | 213 | test_loss = 0 214 | batch_data_t = torch.FloatTensor(batchSz, testX.size(1), testX.size(2), testX.size(3)) 215 | batch_targets_t = torch.FloatTensor(batchSz, testY.size(1), testX.size(2), testX.size(3)) 216 | if args.cuda: 217 | batch_data_t = batch_data_t.cuda() 218 | batch_targets_t = batch_targets_t.cuda() 219 | batch_data = Variable(batch_data_t, volatile=True) 220 | batch_targets = Variable(batch_targets_t, volatile=True) 221 | 222 | nErr = 0 223 | for i in range(0, testX.size(0), batchSz): 224 | print('Testing model: {}/{}'.format(i, testX.size(0)), end='\r') 225 | batch_data.data[:] = testX[i:i+batchSz] 226 | batch_targets.data[:] = testY[i:i+batchSz] 227 | output = model(batch_data) 228 | test_loss += nn.MSELoss()(output, batch_targets) 229 | nErr += computeErr(output.data) 230 | 231 | nBatches = testX.size(0)/batchSz 232 | test_loss = test_loss.item()/nBatches 233 | test_err = nErr/testX.size(0) 234 | print('TEST SET RESULTS:' + ' ' * 20) 235 | print('Average loss: {:.4f}'.format(test_loss)) 236 | print('Err: {:.4f}'.format(test_err)) 237 | 238 | testW.writerow((epoch, test_loss, test_err)) 239 | testF.flush() 240 | 241 | def computeErr(pred): 242 | batchSz = pred.size(0) 243 | nsq = int(pred.size(1)) 244 | n = int(np.sqrt(nsq)) 245 | s = (nsq-1)*nsq//2 # 0 + 1 + ... + n^2-1 246 | I = torch.max(pred, 3)[1].squeeze().view(batchSz, nsq, nsq) 247 | 248 | def invalidGroups(x): 249 | valid = (x.min(1)[0] == 0) 250 | valid *= (x.max(1)[0] == nsq-1) 251 | valid *= (x.sum(1) == s) 252 | return ~valid 253 | 254 | boardCorrect = torch.ones(batchSz).type_as(pred) 255 | for j in range(nsq): 256 | # Check the jth row and column. 257 | boardCorrect[invalidGroups(I[:,j,:])] = 0 258 | boardCorrect[invalidGroups(I[:,:,j])] = 0 259 | 260 | # Check the jth block. 
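# Block j spans rows n*(j//n) .. n*(j//n)+n-1 and columns n*(j%n) .. n*(j%n)+n-1.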
261 | row, col = n*(j // n), n*(j % n) 262 | M = invalidGroups(I[:,row:row+n,col:col+n].contiguous().view(batchSz,-1)) 263 | boardCorrect[M] = 0 264 | 265 | if boardCorrect.sum() == 0: 266 | return batchSz 267 | 268 | return batchSz-boardCorrect.sum().item() 269 | 270 | if __name__=='__main__': 271 | main() 272 | -------------------------------------------------------------------------------- /sudoku/true-Qpenalty-errors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | from torch.autograd import Variable 5 | 6 | import numpy as np 7 | 8 | import matplotlib as mpl 9 | mpl.use('Agg') 10 | import matplotlib.pyplot as plt 11 | plt.style.use('bmh') 12 | 13 | import models 14 | from train import computeErr 15 | 16 | batchSz = 128 17 | 18 | boards = {} 19 | for boardSz in (2,3): 20 | with open('data/{}/features.pt'.format(boardSz), 'rb') as f: 21 | unsolvedBoards = Variable(torch.load(f).cuda()[:,:,:,:]) 22 | nBoards = unsolvedBoards.size(0) 23 | with open('data/{}/labels.pt'.format(boardSz), 'rb') as f: 24 | solvedBoards = Variable(torch.load(f).cuda()[:nBoards,:,:,:]) 25 | boards[boardSz] = (unsolvedBoards, solvedBoards) 26 | 27 | nBatches = nBoards//batchSz 28 | results = {} 29 | startIdx = 0 30 | 31 | ranges = { 32 | 2: np.linspace(0.1, 2.0, num=11), 33 | 3: np.linspace(0.1, 1.0, num=10) 34 | } 35 | 36 | for i in range(nBatches): 37 | nSeen = (i+1)*batchSz 38 | print('=== {} Boards Seen ==='.format(nSeen)) 39 | 40 | for boardSz in (2,3): 41 | unsolvedBoards, solvedBoards = boards[boardSz] 42 | 43 | print('--- Board Sz: {} ---'.format(boardSz)) 44 | print('| {:15s} | {:15s} | {:15s} |'.format('Qpenalty', '% Boards Wrong', '# Blanks Wrong')) 45 | 46 | for j,Qpenalty in enumerate(ranges[boardSz]): 47 | model = models.OptNetEq(boardSz, Qpenalty, qp_solver='qpth', trueInit=True).cuda() 48 | X_batch = unsolvedBoards[startIdx:startIdx+batchSz] 49 | Y_batch = solvedBoards[startIdx:startIdx+batchSz] 50 | preds = model(X_batch).data 51 | err = computeErr(preds) 52 | 53 | # nWrong is not an exact metric because a board might have multiple solutions. 54 | predBoards = torch.max(preds, 3)[1].squeeze().view(batchSz, -1) 55 | trueBoards = torch.max(Y_batch.data, 3)[1].squeeze().view(batchSz, -1) 56 | nWrong = ((predBoards-trueBoards).abs().cpu().numpy() > 1e-7).sum(axis=1) 57 | 58 | results_key = (boardSz, j) 59 | if results_key not in results: 60 | results_j = {'err': err, 'nWrong': nWrong} 61 | results[results_key] = results_j 62 | else: 63 | results_j = results[results_key] 64 | results_j['err'] += err 65 | results_j['nWrong'] = np.concatenate((results_j['nWrong'], nWrong)) 66 | 67 | err = results_j['err']/(batchSz*(i+1)) 68 | nWrong = np.mean(results_j['nWrong']) 69 | print('| {:15f} | {:15f} | {:15f} |'.format(Qpenalty, err, nWrong)) 70 | 71 | print('='*50) 72 | print('\n\n') 73 | 74 | startIdx += batchSz 75 | -------------------------------------------------------------------------------- /tests/optnet-back.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Run these tests with: nosetests -v -d optnet-back.py 4 | # This will run all functions even if one throws an assertion. 5 | # 6 | # For debugging: ./optnet-back.py 7 | # Easier to add print statements. 8 | # This will exit after the first assertion.
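# test_back builds a random QP, runs the single-example NumPy/block-solver
# backward pass on the first sample, and then runs the batched autograd
# backward pass, leaving both sets of gradients available for manual
# comparison.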
9 | 10 | import os 11 | import sys 12 | 13 | import torch 14 | 15 | import numpy as np 16 | import numpy.random as npr 17 | import numpy.testing as npt 18 | np.set_printoptions(precision=2) 19 | 20 | import numdifftools as nd 21 | import cvxpy as cp 22 | 23 | from torch.autograd import Function, Variable 24 | 25 | import adact 26 | import adact_forward_ip as aip 27 | 28 | from solver import BlockSolver as Solver 29 | 30 | from nose.tools import with_setup, assert_almost_equal 31 | 32 | import sys 33 | from IPython.core import ultratb 34 | sys.excepthook = ultratb.FormattedTB(mode='Verbose', 35 | color_scheme='Linux', call_pdb=1) 36 | 37 | ATOL=1e-2 38 | RTOL=1e-7 39 | verbose = True 40 | cuda = True 41 | 42 | def test_back(): 43 | npr.seed(1) 44 | nBatch, nz, neq, nineq = 1, 10, 1, 3 45 | # nz, neq, nineq = 3,3,3 46 | 47 | L = np.tril(np.random.randn(nz,nz)) + 2.*np.eye(nz,nz) 48 | Q = L.dot(L.T)+1e-4*np.eye(nz) 49 | G = 100.*npr.randn(nineq,nz) 50 | A = 100.*npr.randn(neq,nz) 51 | z0 = 1.*npr.randn(nz) 52 | s0 = 100.*np.ones(nineq) 53 | s0[:nineq//2] = 1e-6 54 | # print(np.linalg.norm(L)) 55 | # print(np.linalg.norm(G)) 56 | # print(np.linalg.norm(A)) 57 | # print(np.linalg.norm(z0)) 58 | # print(np.linalg.norm(s0)) 59 | 60 | p = npr.randn(nBatch,nz) 61 | # print(np.linalg.norm(p)) 62 | truez = npr.randn(nBatch,nz) 63 | 64 | af = adact.AdactFunction() 65 | zhat_0, nu_0, lam_0 = af.forward_single_np(p[0], L, G, A, z0, s0) 66 | dl_dzhat_0 = zhat_0-truez[0] 67 | S = Solver(L, A, G, z0, s0, 1e-8) 68 | S.reinit(lam_0, zhat_0) 69 | dp_0, dL_0, dG_0, dA_0, dz0_0, ds0_0 = af.backward_single_np_solver( 70 | S, zhat_0, nu_0, lam_0, dl_dzhat_0, L, G, A, z0, s0) 71 | # zhat_1, nu_1, lam_1 = af.forward_single_np(p[1], L, G, A, z0, s0) 72 | # dl_dzhat_1 = zhat_1-truez[1] 73 | # S.reinit(lam_1, zhat_1) 74 | # dp_1, dL_1, dG_1, dA_1, dz0_1, ds0_1 = af.backward_single_np_solver( 75 | # S, zhat_1, nu_1, lam_1, dl_dzhat_1, L, G, A, z0, s0) 76 | 77 | p, L, G, A, z0, s0, truez = [torch.DoubleTensor(x) for x in [p, L, G, A, z0, s0, truez]] 78 | Q = torch.mm(L, L.t())+0.001*torch.eye(nz).type_as(L) 79 | if cuda: 80 | p, L, Q, G, A, z0, s0, truez = [x.cuda() for x in [p, L, Q, G, A, z0, s0, truez]] 81 | p, L, G, A, z0, s0 = [Variable(x) for x in [p, L, G, A, z0, s0]] 82 | for x in [p, L, G, A, z0, s0]: x.requires_grad = True 83 | 84 | # Q_LU, S_LU, R = aip.pre_factor_kkt_batch(Q, G, A, nBatch) 85 | # b = torch.mv(A, z0) if neq > 0 else None 86 | # h = torch.mv(G, z0)+s0 87 | # zhat_b, nu_b, lam_b = aip.forward_batch(p, Q, G, A, b, h, Q_LU, S_LU, R) 88 | 89 | zhats = af(p, L, G, A, z0, s0) 90 | dl_dzhat = zhats.data - truez 91 | zhats.backward(dl_dzhat) 92 | dp, dL, dG, dA, dz0, ds0 = [x.grad.clone() for x in [p, L, G, A, z0, s0]] 93 | 94 | if __name__=='__main__': 95 | test_back() 96 | -------------------------------------------------------------------------------- /tests/optnet-np.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Run these tests with: nosetests -v -d optnet-np.py 4 | # This will run all functions even if one throws an assertion. 5 | # 6 | # For debugging: ./optnet-np.py 7 | # Easier to add print statements. 8 | # This will exit after the first assertion.
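# These tests compare the analytic QP gradients from the block solver's
# backward pass against finite-difference gradients from numdifftools,
# one problem parameter (p, L, G, A, z0, s0) at a time.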
9 | 10 | import os 11 | import sys 12 | 13 | import torch 14 | 15 | import numpy as np 16 | import numpy.random as npr 17 | import numpy.testing as npt 18 | np.set_printoptions(precision=2) 19 | 20 | import numdifftools as nd 21 | import cvxpy as cp 22 | 23 | import adact 24 | import adact_forward_ip as aip 25 | 26 | from solver import BlockSolver as Solver 27 | 28 | from nose.tools import with_setup, assert_almost_equal 29 | 30 | ATOL=1e-2 31 | RTOL=1e-7 32 | 33 | npr.seed(1) 34 | nz, neq, nineq = 5,0,4 35 | # nz, neq, nineq = 3,3,3 36 | 37 | L = np.tril(np.random.randn(nz,nz)) + 2.*np.eye(nz,nz) 38 | Q = L.dot(L.T)+1e-8*np.eye(nz) 39 | G = 1000.*npr.randn(nineq,nz) 40 | A = 10000.*npr.randn(neq,nz) 41 | z0 = 1.*npr.randn(nz) 42 | s0 = 100.*np.ones(nineq) 43 | 44 | p = npr.randn(nz) 45 | truez = npr.randn(nz) 46 | 47 | af = adact.AdactFunction() 48 | 49 | zhat, nu, lam = af.forward_single_np(p, L, G, A, z0, s0) 50 | dl_dzhat = zhat-truez 51 | 52 | # dp, dL, dG, dA, dz0, ds0 = af.backward_single_np(zhat, nu, lam, dl_dzhat, L, G, A, z0, s0) 53 | 54 | S = Solver(L, A, G, z0, s0, 1e-8) 55 | S.reinit(lam, zhat) 56 | dp, dL, dG, dA, dz0, ds0 = af.backward_single_np_solver(S, zhat, nu, lam, dl_dzhat, L, G, A, z0, s0) 57 | 58 | verbose = True 59 | 60 | 61 | def test_ip_forward(): 62 | p_t, Q_t, G_t, A_t, z0_t, s0_t = [torch.Tensor(x) for x in [p, Q, G, A, z0, s0]] 63 | b = torch.mv(A_t, z0_t) if neq > 0 else None 64 | h = torch.mv(G_t,z0_t)+s0_t 65 | L_Q, L_S, R = aip.pre_factor_kkt(Q_t, G_t, A_t) 66 | 67 | zhat_ip, nu_ip, lam_ip = aip.forward_single(p_t, Q_t, G_t, A_t, b, h, L_Q, L_S, R) 68 | # Unnecessary clones here because of a pytorch bug when calling numpy 69 | # on a tensor with a non-zero offset. 70 | npt.assert_allclose(zhat, zhat_ip.clone().numpy(), rtol=RTOL, atol=ATOL) 71 | if neq > 0: 72 | npt.assert_allclose(nu, nu_ip.clone().numpy(), rtol=RTOL, atol=ATOL) 73 | npt.assert_allclose(lam, lam_ip.clone().numpy(), rtol=RTOL, atol=ATOL) 74 | 75 | def test_dl_dz0(): 76 | def f(z0): 77 | zhat, nu, lam = af.forward_single_np(p, L, G, A, z0, s0) 78 | return 0.5*np.sum(np.square(zhat - truez)) 79 | 80 | df = nd.Gradient(f) 81 | dz0_fd = df(z0) 82 | if verbose: 83 | print('dz0_fd: ', dz0_fd) 84 | print('dz0: ', dz0) 85 | npt.assert_allclose(dz0_fd, dz0, rtol=RTOL, atol=ATOL) 86 | 87 | def test_dl_ds0(): 88 | def f(s0): 89 | zhat, nu, lam = af.forward_single_np(p, L, G, A, z0, s0) 90 | return 0.5*np.sum(np.square(zhat - truez)) 91 | 92 | df = nd.Gradient(f) 93 | ds0_fd = df(s0) 94 | if verbose: 95 | print('ds0_fd: ', ds0_fd) 96 | print('ds0: ', ds0) 97 | npt.assert_allclose(ds0_fd, ds0, rtol=RTOL, atol=ATOL) 98 | 99 | def test_dl_dp(): 100 | def f(p): 101 | zhat, nu, lam = af.forward_single_np(p, L, G, A, z0, s0) 102 | return 0.5*np.sum(np.square(zhat - truez)) 103 | 104 | df = nd.Gradient(f) 105 | dp_fd = df(p) 106 | if verbose: 107 | print('dp_fd: ', dp_fd) 108 | print('dp: ', dp) 109 | npt.assert_allclose(dp_fd, dp, rtol=RTOL, atol=ATOL) 110 | 111 | def test_dl_dp_batch(): 112 | def f(p): 113 | zhat, nu, lam = af.forward_single_np(p, L, G, A, z0, s0) 114 | return 0.5*np.sum(np.square(zhat - truez)) 115 | 116 | df = nd.Gradient(f) 117 | dp_fd = df(p) 118 | if verbose: 119 | print('dp_fd: ', dp_fd) 120 | print('dp: ', dp) 121 | npt.assert_allclose(dp_fd, dp, rtol=RTOL, atol=ATOL) 122 | 123 | def test_dl_dA(): 124 | def f(A): 125 | A = A.reshape(neq,nz) 126 | zhat, nu, lam = af.forward_single_np(p, L, G, A, z0, s0) 127 | return 0.5*np.sum(np.square(zhat - truez)) 128 | 129 | df = nd.Gradient(f) 
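# f above reshapes its flattened input back to (neq, nz), so the
# finite-difference gradient is taken over A.ravel() and reshaped below.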
130 | dA_fd = df(A.ravel()).reshape(neq, nz) 131 | if verbose: 132 | print('dA_fd[1,:]: ', dA_fd[1,:]) 133 | print('dA[1,:]: ', dA[1,:]) 134 | npt.assert_allclose(dA_fd, dA, rtol=RTOL, atol=ATOL) 135 | 136 | def test_dl_dG(): 137 | def f(G): 138 | G = G.reshape(nineq,nz) 139 | zhat, nu, lam = af.forward_single_np(p, L, G, A, z0, s0) 140 | return 0.5*np.sum(np.square(zhat - truez)) 141 | 142 | df = nd.Gradient(f) 143 | dG_fd = df(G.ravel()).reshape(nineq, nz) 144 | if verbose: 145 | print('dG_fd[1,:]: ', dG_fd[1,:]) 146 | print('dG[1,:]: ', dG[1,:]) 147 | npt.assert_allclose(dG_fd, dG, rtol=RTOL, atol=ATOL) 148 | 149 | def test_dl_dL(): 150 | def f(l0): 151 | L_ = np.copy(L) 152 | L_[:,0] = l0 153 | zhat, nu, lam = af.forward_single_np(p, L_, G, A, z0, s0) 154 | return 0.5*np.sum(np.square(zhat - truez)) 155 | 156 | df = nd.Gradient(f) 157 | dL_fd = df(L[:,0]) 158 | dl0 = np.array(dL[:,0]).ravel() 159 | if verbose: 160 | print('dL_fd: ', dL_fd) 161 | print('dL: ', dl0) 162 | npt.assert_allclose(dL_fd, dl0, rtol=RTOL, atol=ATOL) 163 | 164 | if __name__=='__main__': 165 | # test_ip_forward() 166 | test_dl_dp() 167 | # test_dl_dp_batch() 168 | # test_dl_dz0() 169 | # test_dl_ds0() 170 | # if neq > 0: 171 | # test_dl_dA() 172 | # test_dl_dG() 173 | # test_dl_dL() 174 | -------------------------------------------------------------------------------- /util/init.plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import numpy as np 4 | import numpy.random as npr 5 | 6 | import matplotlib as mpl 7 | mpl.use('Agg') 8 | import matplotlib.pyplot as plt 9 | plt.style.use('bmh') 10 | os.makedirs('data/2016-11-02', exist_ok=True) 11 | for i in range(5): 12 | nz, neq, nineq = 2,0,10 13 | G = npr.uniform(-1., 1., (nineq,nz)) 14 | z0 = np.zeros(nz) 15 | s0 = np.ones(nineq) 16 | 17 | l, u = -4, 4 18 | b = np.linspace(l, u, num=1000) 19 | C, D = np.meshgrid(b, b) 20 | Z = [] 21 | for c,d in zip(C.ravel(), D.ravel()): 22 | x = np.array([c,d]) 23 | z = np.all(G.dot(x) <= G.dot(z0)+s0).astype(np.float32) 24 | Z.append(z) 25 | Z = np.array(Z).reshape(C.shape) 26 | 27 | fig, ax = plt.subplots(1, 1, figsize=(8,8)) 28 | plt.axis([l, u, l, u]) 29 | CS = plt.contourf(C, D, Z, cmap=plt.cm.Blues) 30 | f = 'data/2016-11-02/init.{}.png'.format(i) 31 | plt.savefig(f) 32 | print('Created '+f) 33 | --------------------------------------------------------------------------------