├── .gitignore ├── CHANGES.txt ├── LICENSE.txt ├── MANIFEST ├── README.md ├── dist ├── proximal_gradient-0.1.0-py3-none-any.whl └── proximal_gradient-0.1.0.tar.gz ├── examples └── xor_linf1.py ├── proximal_gradient ├── __init__.py └── proximalGradient.py ├── setup.py └── tests ├── __init__.py └── testit.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.swp 3 | -------------------------------------------------------------------------------- /CHANGES.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KentonMurray/ProxGradPytorch/c534a49142ac9ec149ca67de24bb0487fde1607b/CHANGES.txt -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Kenton Murray 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST: -------------------------------------------------------------------------------- 1 | # file GENERATED by distutils, do NOT edit 2 | setup.py 3 | prox-grad-pytorch/__init__.py 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ProxGradPyTorch 2 | ProxGradPyTorch is a PyTorch implementation of many of the proximal gradient algorithms from [Parikh and Boyd (2014)](https://web.stanford.edu/~boyd/papers/prox_algs.html). In particular, many of these algorithms are useful for Auto-Sizing Neural Networks [(Murray and Chiang 2015)](https://www.aclweb.org/anthology/D15-1107). 3 | 4 | If you use this toolkit, we would appreciate it if you could cite: 5 | 6 | @inproceedings{murray2019autosizing, 7 | author={Murray, Kenton and Kinnison, Jeffery and Nguyen, Toan Q. and Scheirer, Walter and Chiang, David}, 8 | title={Auto-Sizing the Transformer Network: Improving Speed, Efficiency, and Performance for Low-Resource Machine Translation}, 9 | year=2019, 10 | booktitle={Proceedings of the Third Workshop on Neural Generation and Translation}, 11 | } 12 | 13 | ## Installation 14 | 15 | The only dependency is on pytorch >=0.4.1 16 | 17 | The simplest way to install is using PyPI. 
Simply type:
18 | 
19 | ```
20 | pip install proximal-gradient
21 | ```
22 | 
23 | At the top of any file in which you want to use ProxGradPytorch, add the following line:
24 | 
25 | ```
26 | import proximal_gradient.proximalGradient as pg
27 | ```
28 | 
29 | ### From Source
30 | 
31 | To build from source, simply clone this repository. Currently, there is a dependency on PyTorch >= 0.4.1. On Linux, it's easiest to add the package directory to your Python path:
32 | 
33 | ```
34 | export PYTHONPATH="[install_dir]/ProxGradPytorch/proximal_gradient:$PYTHONPATH"
35 | ```
36 | 
37 | At the top of any file in which you want to use ProxGradPytorch, add the following line:
38 | 
39 | ```
40 | import proximalGradient as pg
41 | ```
42 | 
43 | ## Running
44 | 
45 | Proximal gradient algorithms use a two-step process. First, normal backpropagation is run on your network:
46 | 
47 | ```
48 | # Zero gradients, perform a backward pass, and update the weights.
49 | optimizer.zero_grad()
50 | loss.backward()
51 | optimizer.step()
52 | ```
53 | 
54 | This is just a standard PyTorch update. Second, you run the proximal gradient step. Many of these algorithms have a closed-form solution and do not rely on stored gradients. For instance, to apply L2,1 regularization to a layer named model.linear1, you run the following code:
55 | 
56 | ```
57 | pg.l21(model.linear1.weight, model.linear1.bias, reg=0.005)
58 | ```
59 | 
60 | This applies a group regularizer over each row. Assuming each row holds all of the incoming weights of a single neuron and that neuron's non-linearity satisfies f(0) = 0, this will auto-size that layer. There are many other regularizers implemented as well that are not just for auto-sizing (for instance L_infinity, L_2, etc.).
61 | 
62 | ## Auto-Sizing
63 | 
64 | [Murray et al. (2019)](https://www.aclweb.org/anthology/D19-5634/) make use of these algorithms for auto-sizing. Auto-sizing is a method for deleting neurons from a network, subject to a few assumptions. At a basic level, if all the weights of a neuron are 0.0, it does not matter what the input to that neuron is -- everything will be 0.0. If the non-linearity maps 0 to 0 (i.e., f(0) = 0), as tanh and ReLU do, the output is 0.0 and it is as if the neuron does not exist. Auto-sizing relies on sparse group regularizers to drive these weights to 0. As sparse regularizers are often non-differentiable, the authors rely on the proximal gradient methods in this toolkit. For a more complete description of auto-sizing, see either that paper or [Murray and Chiang (2015)](https://www.aclweb.org/anthology/D15-1107).
65 | 
66 | As an example of auto-sizing, let's look at a simple XOR example built with a two-layer network (also available in the examples directory):
67 | 
68 | ```
69 | import torch
70 | from torch.autograd import Variable
71 | 
72 | class TwoLayerNet(torch.nn.Module):
73 |     def __init__(self, D_in, H, D_out):
74 |         super(TwoLayerNet, self).__init__()
75 |         self.linear1 = torch.nn.Linear(D_in, H)
76 |         self.linear2 = torch.nn.Linear(H, D_out)
77 | 
78 |     def forward(self, x):
79 |         h_relu = self.linear1(x).clamp(min=0)
80 |         y_pred = self.linear2(h_relu)
81 |         return y_pred
82 | 
83 | # D_in is input dimension; H is hidden dimension; D_out is output dimension.
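# Note: H = 100 is deliberately over-provisioned; the auto-sized version later in this README prunes most of these hidden units.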
84 | D_in, H, D_out = 2, 100, 1 85 | 86 | # Inputs and Outputs for xor 87 | inputs = list(map(lambda s: Variable(torch.Tensor([s])), [ 88 | [0, 0], 89 | [0, 1], 90 | [1, 0], 91 | [1, 1] 92 | ])) 93 | targets = list(map(lambda s: Variable(torch.Tensor([s])), [ 94 | [0], 95 | [1], 96 | [1], 97 | [0] 98 | ])) 99 | 100 | # Construct model 101 | model = TwoLayerNet(D_in, H, D_out) 102 | 103 | # Loss, Optimizer, and Proximal Gradient 104 | criterion = torch.nn.MSELoss(reduction='sum') 105 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01) 106 | for t in range(5000): 107 | for input, target in zip(inputs, targets): 108 | # Forward pass: Compute predicted y by passing x to the model 109 | y_pred = model(input) 110 | 111 | # Compute loss 112 | loss = criterion(y_pred, target) 113 | 114 | # Zero gradients, perform a backward pass, and update the weights. 115 | optimizer.zero_grad() 116 | loss.backward() 117 | optimizer.step() 118 | 119 | # Neurons Left (H) 120 | print("H (model.linear1.weight):", (model.linear1.weight.nonzero()[:,0]).unique().size(0)) 121 | 122 | print("Final results:") 123 | for input, target in zip(inputs, targets): 124 | output = model(input) 125 | print("Input:", input, "Target:", target, "Predicted:", output) 126 | ``` 127 | 128 | To auto-size this network, which will reduce the dimension of H, only requires two lines of code. First, we import this toolkit: 129 | 130 | ``` 131 | import proximalGradient as pg 132 | ``` 133 | 134 | Then, we simply apply the proximal gradient step after optimizer.step(): 135 | 136 | ``` 137 | pg.linf1(model.linear1.weight, model.linear1.bias, reg=0.1) 138 | ``` 139 | 140 | So, the final code is: 141 | 142 | 143 | ``` 144 | import torch 145 | from torch.autograd import Variable 146 | import proximalGradient as pg 147 | 148 | 149 | class TwoLayerNet(torch.nn.Module): 150 | def __init__(self, D_in, H, D_out): 151 | super(TwoLayerNet, self).__init__() 152 | self.linear1 = torch.nn.Linear(D_in, H) 153 | self.linear2 = torch.nn.Linear(H, D_out) 154 | 155 | def forward(self, x): 156 | h_relu = self.linear1(x).clamp(min=0) 157 | y_pred = self.linear2(h_relu) 158 | return y_pred 159 | 160 | 161 | # D_in is input dimension; H is hidden dimension; D_out is output dimension. 162 | D_in, H, D_out = 2, 100, 1 163 | 164 | # Inputs and Outputs for xor 165 | inputs = list(map(lambda s: Variable(torch.Tensor([s])), [ 166 | [0, 0], 167 | [0, 1], 168 | [1, 0], 169 | [1, 1] 170 | ])) 171 | targets = list(map(lambda s: Variable(torch.Tensor([s])), [ 172 | [0], 173 | [1], 174 | [1], 175 | [0] 176 | ])) 177 | 178 | 179 | # Construct model 180 | model = TwoLayerNet(D_in, H, D_out) 181 | 182 | # Neurons to Start (H) 183 | print("H initially (model.linear1.weight):", (model.linear1.weight.nonzero()[:,0]).unique().size(0)) 184 | 185 | # Loss, Optimizer, and Proximal Gradient 186 | criterion = torch.nn.MSELoss(reduction='sum') 187 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01) 188 | for t in range(5000): 189 | for input, target in zip(inputs, targets): 190 | # Forward pass: Compute predicted y by passing x to the model 191 | y_pred = model(input) 192 | 193 | # Compute loss 194 | loss = criterion(y_pred, target) 195 | 196 | # Zero gradients, perform a backward pass, and update the weights. 
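        # The pg.linf1 call below is the proximal step; it runs right after optimizer.step() and uses a closed-form update rather than the stored gradients.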
197 |         optimizer.zero_grad()
198 |         loss.backward()
199 |         optimizer.step()
200 | 
201 |         # Proximal Gradient Step
202 |         pg.linf1(model.linear1.weight, model.linear1.bias, reg=0.005)
203 | 
204 | # Neurons Left (H)
205 | print("H remaining (model.linear1.weight):", (model.linear1.weight.nonzero()[:,0]).unique().size(0))
206 | 
207 | print("Final results:")
208 | for input, target in zip(inputs, targets):
209 |     output = model(input)
210 |     print("Input:", input, "Target:", target, "Predicted:", output)
211 | ```
212 | 
213 | Exact results vary with the random initialization, but typically only around 15 of the 100 hidden neurons (H) remain.
214 | 
--------------------------------------------------------------------------------
/dist/proximal_gradient-0.1.0-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KentonMurray/ProxGradPytorch/c534a49142ac9ec149ca67de24bb0487fde1607b/dist/proximal_gradient-0.1.0-py3-none-any.whl
--------------------------------------------------------------------------------
/dist/proximal_gradient-0.1.0.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KentonMurray/ProxGradPytorch/c534a49142ac9ec149ca67de24bb0487fde1607b/dist/proximal_gradient-0.1.0.tar.gz
--------------------------------------------------------------------------------
/examples/xor_linf1.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from torch.autograd import Variable
 3 | import proximal_gradient.proximalGradient as pg
 4 | 
 5 | class TwoLayerNet(torch.nn.Module):
 6 |     def __init__(self, D_in, H, D_out):
 7 |         super(TwoLayerNet, self).__init__()
 8 |         self.linear1 = torch.nn.Linear(D_in, H)
 9 |         self.linear2 = torch.nn.Linear(H, D_out)
10 | 
11 |     def forward(self, x):
12 |         h_relu = self.linear1(x).clamp(min=0)
13 |         y_pred = self.linear2(h_relu)
14 |         return y_pred
15 | 
16 | # D_in is input dimension; H is hidden dimension; D_out is output dimension.
17 | D_in, H, D_out = 2, 100, 1
18 | 
19 | # Inputs and Outputs for xor
20 | inputs = list(map(lambda s: Variable(torch.Tensor([s])), [
21 |     [0, 0],
22 |     [0, 1],
23 |     [1, 0],
24 |     [1, 1]
25 | ]))
26 | targets = list(map(lambda s: Variable(torch.Tensor([s])), [
27 |     [0],
28 |     [1],
29 |     [1],
30 |     [0]
31 | ]))
32 | 
33 | # Construct model
34 | model = TwoLayerNet(D_in, H, D_out)
35 | 
36 | # Neurons to Start (H)
37 | print("H initially (model.linear1.weight):", (model.linear1.weight.nonzero()[:,0]).unique().size(0))
38 | 
39 | # Loss, Optimizer, and Proximal Gradient
40 | criterion = torch.nn.MSELoss(reduction='sum')
41 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
42 | for t in range(5000):
43 |     for input, target in zip(inputs, targets):
44 |         # Forward pass: Compute predicted y by passing x to the model
45 |         y_pred = model(input)
46 | 
47 |         # Compute loss
48 |         loss = criterion(y_pred, target)
49 | 
50 |         # Zero gradients, perform a backward pass, and update the weights.
51 | optimizer.zero_grad() 52 | loss.backward() 53 | optimizer.step() 54 | 55 | # Proximal Gradient Step 56 | pg.linf1(model.linear1.weight, model.linear1.bias, reg=0.005) 57 | 58 | # Neurons Left (H) 59 | print("H (model.linear1.weight):", (model.linear1.weight.nonzero()[:,0]).unique().size(0)) 60 | print("Final results:") 61 | for input, target in zip(inputs, targets): 62 | output = model(input) 63 | print("Input:", input, "Target:", target, "Predicted:", output) 64 | -------------------------------------------------------------------------------- /proximal_gradient/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KentonMurray/ProxGradPytorch/c534a49142ac9ec149ca67de24bb0487fde1607b/proximal_gradient/__init__.py -------------------------------------------------------------------------------- /proximal_gradient/proximalGradient.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn.functional as F 4 | 5 | def package(): 6 | print("Kenton") 7 | 8 | def l21(parameter, bias=None, reg=0.01, lr=0.1): 9 | """L21 Regularization""" 10 | 11 | if bias is not None: 12 | w_and_b = torch.cat((parameter, bias.unfold(0,1,1)),1) 13 | else: 14 | w_and_b = parameter 15 | L21 = reg # lambda: regularization strength 16 | Norm = (lr*L21/w_and_b.norm(2, dim=1)) # Key insight here: apply rowwise (by using dim 1) 17 | if Norm.is_cuda: 18 | ones = torch.ones(w_and_b.size(0), device=torch.device("cuda")) 19 | else: 20 | ones = torch.ones(w_and_b.size(0), device=torch.device("cpu")) 21 | l21T = 1.0 - torch.min(ones, Norm) 22 | update = (parameter*(l21T.unsqueeze(1))) 23 | parameter.data = update 24 | # Update bias 25 | if bias is not None: 26 | update_b = (bias*l21T) 27 | bias.data = update_b 28 | 29 | def l21_slow(parameter, reg=0.01, lr=0.1): 30 | """L21 Regularization (Slow implementation. 
Used for 31 | sanity checks.)""" 32 | 33 | w_and_b = parameter 34 | l21s = [] 35 | for row in w_and_b: 36 | L21 = reg 37 | l21 = lr * L21/row.norm(2) 38 | l21 = 1.0 - min(1.0, l21) 39 | l21s.append(l21) 40 | counter = 0 41 | for row in parameter: 42 | updated = row * l21s[counter] 43 | parameter.data[counter] = updated 44 | counter = counter + 1 45 | counter = 0 46 | 47 | def linf1(parameter, bias=None, reg=0.01, lr=0.1): 48 | """Linfity1 Regularization using Proximal Gradients""" 49 | 50 | Norm = reg*lr 51 | if bias is not None: 52 | w_and_b = torch.cat((parameter, bias.unfold(0,1,1)),1) 53 | else: 54 | w_and_b = parameter 55 | sorted_w_and_b, indices = torch.sort(torch.abs(w_and_b), descending=True) 56 | 57 | # CUDA or CPU 58 | devicetype="cuda" 59 | if w_and_b.is_cuda: 60 | devicetype="cuda" 61 | else: 62 | devicetype="cpu" 63 | 64 | 65 | #SLOW 66 | rows, cols = sorted_w_and_b.size() 67 | 68 | sorted_z = torch.cat((sorted_w_and_b, torch.zeros(rows,1, device=torch.device(devicetype))),1) 69 | subtracted = torch.clamp(sorted_w_and_b - sorted_z[:,1:],max=Norm) #Max=Norm important 70 | 71 | scale_indices = torch.cumsum(torch.ones(rows,cols, device=torch.device(devicetype)),1) 72 | scaled_subtracted = subtracted * scale_indices 73 | max_mass = torch.cumsum(scaled_subtracted,1) 74 | nonzero = torch.clamp(-1*(max_mass - Norm),0) 75 | 76 | oneN = 1.0/scale_indices 77 | 78 | # Algorithm described in paper, but these are all efficient GPU operation steps) 79 | # First we subtract every value from the cell next to it 80 | nonzero_ones = torch.clamp(nonzero * 1000000, max=1) #Hacky, but efficient 81 | shifted_ones = torch.cat((torch.ones(rows,1, device=torch.device(devicetype)),nonzero_ones[:,:(cols-1)]),1) 82 | over_one = -1*(nonzero_ones - shifted_ones) 83 | last_one = torch.cat((over_one,torch.zeros(rows,1, device=torch.device(devicetype))),1)[:,1:] 84 | max_remain = last_one * nonzero 85 | shift_max = torch.cat((torch.zeros(rows,1, device=torch.device(devicetype)),max_remain[:,:(cols-1)]),1) 86 | first_col_nonzero_ones = torch.cat((torch.ones(rows,1, device=torch.device(devicetype)),nonzero_ones[:,1:]),1) #Edge case for only first column 87 | tosub = first_col_nonzero_ones * subtracted + shift_max * oneN 88 | 89 | # We flip the tensor so that we can get a cumulative sum for the value to subtract, then flip back 90 | nastyflipS = torch.flip(torch.flip(tosub,[0,1]),[0]) 91 | aggsubS = torch.cumsum(nastyflipS,1) 92 | nastyflipagainS = torch.flip(torch.flip(aggsubS,[0,1]),[0]) 93 | 94 | # The proximal gradient step is equal to subtracting the sorted cumulative sum 95 | updated_weights = sorted_w_and_b - nastyflipagainS 96 | unsorted = torch.zeros(rows,cols, device=torch.device(devicetype)).scatter_(1,indices,updated_weights) 97 | final_w_and_b = torch.sign(w_and_b) * unsorted 98 | 99 | # Actually update parameters and bias 100 | if bias is not None: 101 | update = final_w_and_b[:,:cols-1] 102 | parameter.data = update 103 | update_b = final_w_and_b[:,-1] 104 | bias.data = update_b 105 | else: 106 | parameter.data = final_w_and_b 107 | 108 | 109 | 110 | 111 | def linf(parameter, bias=None, reg=0.01, lr=0.1): 112 | """L Infinity Regularization using proximal gradients over entire tensor""" 113 | 114 | if bias is not None: 115 | w_and_b = torch.squeeze(torch.cat((parameter, bias.unfold(0,1,1)),1), 0) 116 | else: 117 | w_and_b = torch.squeeze(parameter, 0) 118 | print("w_and_b:", w_and_b) 119 | sorted_w_and_b, indices = torch.sort(torch.abs(w_and_b), descending=True) 120 | print("sorted_w_and_b:", 
sorted_w_and_b)
121 | 
122 | 
123 | 
124 | def l2(parameter, bias=None, reg=0.01, lr=0.1):
125 |     """L2 Regularization over the entire parameter's values using proximal gradients"""
126 | 
127 |     if bias is not None:
128 |         w_and_b = torch.cat((parameter, bias.unfold(0,1,1)),1)
129 |     else:
130 |         w_and_b = parameter
131 |     L2 = reg # lambda: regularization strength
132 |     Norm = (lr*L2/w_and_b.norm(2))
133 |     if Norm.is_cuda:
134 |         ones_w = torch.ones(parameter.size(), device=torch.device("cuda"))
135 |     else:
136 |         ones_w = torch.ones(parameter.size(), device=torch.device("cpu"))
137 |     l2T = 1.0 - torch.min(ones_w, Norm)
138 |     update = (parameter*l2T)
139 |     parameter.data = update
140 |     # Update bias
141 |     if bias is not None:
142 |         if Norm.is_cuda:
143 |             ones_b = torch.ones(bias.size(), device=torch.device("cuda"))
144 |         else:
145 |             ones_b = torch.ones(bias.size(), device=torch.device("cpu"))
146 |         l2T = 1.0 - torch.min(ones_b, Norm) # shrink the bias by the same factor as the weights
147 |         update_b = (bias*l2T)
148 |         bias.data = update_b
149 | 
150 | def l1(parameter, bias=None, reg=0.01, lr=0.1):
151 |     """L1 Regularization using Proximal Gradients"""
152 |     Norm = reg*lr
153 | 
154 |     # Update W
155 |     if parameter.is_cuda:
156 |         Norms_w = Norm*torch.ones(parameter.size(), device=torch.device("cuda"))
157 |     else:
158 |         Norms_w = Norm*torch.ones(parameter.size(), device=torch.device("cpu"))
159 |     pos = torch.min(Norms_w, Norm*torch.clamp(parameter, min=0)) # get all positive values
160 |     neg = torch.min(Norms_w, -1.0*Norm*torch.clamp(parameter, max=0)) # get all negative values
161 |     update_w = parameter - pos + neg # l1 step is the magnitude of all positive and all negative
162 |     parameter.data = update_w
163 | 
164 |     if bias is not None:
165 |         if bias.is_cuda:
166 |             Norms_b = Norm*torch.ones(bias.size(), device=torch.device("cuda"))
167 |         else:
168 |             Norms_b = Norm*torch.ones(bias.size(), device=torch.device("cpu"))
169 |         pos = torch.min(Norms_b, Norm*torch.clamp(bias, min=0))
170 |         neg = torch.min(Norms_b, -1.0*Norm*torch.clamp(bias, max=0))
171 |         update_b = bias - pos + neg
172 |         bias.data = update_b
173 | 
174 | def elasticnet(parameter, bias=None, reg=0.01, lr=0.1, gamma=1.0):
175 |     """Elastic Net Regularization using Proximal Gradients.
176 |     This is a linear combination of an l1 and a quadratic penalty."""
177 |     if gamma < 0.0:
178 |         print("Warning, gamma should be positive. Otherwise you are not shrinking.")
179 |     #TODO: Is gamma of 1.0 a good value?
180 |     Norm = reg*lr*gamma
181 |     l1(parameter, bias, reg, lr)
182 |     update_w = (1.0/(1.0 + Norm))*parameter
183 |     parameter.data = update_w
184 |     if bias is not None:
185 |         update_b = (1.0/(1.0 + Norm))*bias
186 |         bias.data = update_b
187 | 
188 | def logbarrier(parameter, bias=None, reg=0.01, lr=0.1):
189 |     """Project onto logbarrier. Useful for minimization
190 |     of f(x) when x >= b.
191 | F(A) = -log(det(A))""" 192 | Norm = reg*lr 193 | 194 | # Update W 195 | squared = torch.mul(parameter, parameter) 196 | squared = squared + 4*Norm 197 | squareroot = torch.sqrt(squared) 198 | update_w = (parameter + squareroot)/2.0 199 | parameter.data = update_w 200 | 201 | if bias is not None: 202 | squared = torch.mul(bias, bias) 203 | squared = squared + 4*Norm 204 | squareroot = torch.sqrt(squared) 205 | update_b = (bias + squareroot)/2.0 206 | bias.data = update_b 207 | 208 | 209 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | from distutils.core import setup 3 | 4 | setup( 5 | name='proximal_gradient', 6 | version='0.1.0', 7 | author='Kenton Murray', 8 | author_email='kmurray4@nd.edu', 9 | packages=['proximal_gradient'], 10 | scripts=[], 11 | url='https://github.com/KentonMurray/ProxGradPytorch', 12 | license='LICENSE.txt', 13 | description='Proximal Gradient Methods for Pytorch', 14 | #long_description=open('README.md').read(), 15 | install_requires=[ 16 | "torch >= 0.4.0", 17 | ], 18 | setup_requires=['wheel'], 19 | ) 20 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KentonMurray/ProxGradPytorch/c534a49142ac9ec149ca67de24bb0487fde1607b/tests/__init__.py -------------------------------------------------------------------------------- /tests/testit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import proximal_gradient.proximalGradient as pg 3 | 4 | class OneLayerNet(torch.nn.Module): 5 | def __init__(self, D_in, H, D_out): 6 | """ 7 | The constructor creates one linear layer and assigns it a name. 8 | """ 9 | super(OneLayerNet, self).__init__() 10 | self.linear1 = torch.nn.Linear(D_in, D_out) 11 | self.linear1.name = "linear1" 12 | 13 | def forward(self, x): 14 | """ 15 | Simple forward step 16 | """ 17 | y_pred = self.linear1(x) 18 | # Uncomment for verbose debugging 19 | #print("linear1:", self.linear1) 20 | #for param in self.linear1.parameters(): 21 | # print("param:", param) 22 | # print("param.grad:", param.grad) 23 | ##print("linear1.grad:", self.linear1.grad) 24 | ##print("linear1.grad:", self.linear1.data) 25 | return y_pred 26 | 27 | def build_model(): 28 | # Values for the network size 29 | N, D_in, H, D_out = 4, 3, 4, 2 30 | #N, D_in, H, D_out = 4, 3, 10, 5 31 | 32 | # Create random Tensors to hold inputs and outputs 33 | x = torch.zeros(N, D_in) 34 | y = torch.ones(N, D_out) 35 | print("x.requires_grad") 36 | print(x.requires_grad) 37 | 38 | # Construct our model by instantiating the class defined above 39 | model = OneLayerNet(D_in, H, D_out) 40 | print("model:", model) 41 | 42 | criterion = torch.nn.MSELoss(size_average=False) 43 | #optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) 44 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-1) 45 | return (x,y,model,criterion,optimizer) 46 | 47 | 48 | def test_l1(network): 49 | x, y, model, criterion, optimizer = network 50 | for t in range(10): 51 | # Forward pass: Compute predicted y by passing x to the model 52 | y_pred = model(x) 53 | 54 | # Compute and print loss 55 | cross_entropy_loss = criterion(y_pred, y) 56 | loss = cross_entropy_loss 57 | 58 | # Zero gradients, perform a backward pass, and update the weights. 
59 | optimizer.zero_grad() 60 | loss.backward() 61 | # print("model.linear1.weight.grad:", model.linear1.weight.grad) 62 | # print("model.linear1.bias.grad:", model.linear1.bias.grad) 63 | # print("model.linear1.weight before:", model.linear1.weight) 64 | # print("model.linear1.bias before:", model.linear1.bias) 65 | optimizer.step() 66 | # print("model.linear1.weight after:", model.linear1.weight) 67 | # print("model.linear1.bias after:", model.linear1.bias) 68 | # print("model.linear1.weight.norm():", model.linear1.weight.norm()) 69 | 70 | #L1... 71 | print("weight before:", model.linear1.weight) 72 | pg.l1(model.linear1.weight, model.linear1.bias, reg=0.1) 73 | print("weight after:", model.linear1.weight) 74 | 75 | def test_l21(network): 76 | print("Testing l21") 77 | x, y, model, criterion, optimizer = network 78 | for t in range(10): 79 | # Forward pass: Compute predicted y by passing x to the model 80 | y_pred = model(x) 81 | 82 | # Compute and print loss 83 | cross_entropy_loss = criterion(y_pred, y) 84 | loss = cross_entropy_loss 85 | 86 | # Zero gradients, perform a backward pass, and update the weights. 87 | optimizer.zero_grad() 88 | loss.backward() 89 | optimizer.step() 90 | 91 | #print("weight before:", model.linear1.weight) 92 | #print("bias before:", model.linear1.bias) 93 | #pg.l21(model.linear1.weight, reg=0.1) # Use defaults 94 | pg.l21(model.linear1.weight, model.linear1.bias, reg=0.1) 95 | #pg.l21(model.linear1, reg=0.01) # Test different learning rates 96 | #pg.l21(model.linear1, reg=0.1) 97 | #pg.l21_slow(model.linear1.weight, reg=0.1) # Slow version to double check accuracy 98 | #print("weight after:", model.linear1.weight) 99 | #print("bias after:", model.linear1.bias) 100 | 101 | def test_l2(network): 102 | print("Testing l2") 103 | x, y, model, criterion, optimizer = network 104 | for t in range(10): 105 | # Forward pass: Compute predicted y by passing x to the model 106 | y_pred = model(x) 107 | 108 | # Compute and print loss 109 | cross_entropy_loss = criterion(y_pred, y) 110 | loss = cross_entropy_loss 111 | 112 | # Zero gradients, perform a backward pass, and update the weights. 113 | optimizer.zero_grad() 114 | loss.backward() 115 | optimizer.step() 116 | 117 | # Seeing about l2... 118 | #print("L2:", model.linear1.weight.norm(2)) 119 | #print("weight before:", model.linear1.weight) 120 | pg.l2(model.linear1.weight, model.linear1.bias, reg=0.1) 121 | #print("weight after:", model.linear1.weight) 122 | 123 | 124 | def test_linf1(network): 125 | x, y, model, criterion, optimizer = network 126 | for t in range(500): 127 | # Forward pass: Compute predicted y by passing x to the model 128 | y_pred = model(x) 129 | 130 | # Compute and print loss 131 | cross_entropy_loss = criterion(y_pred, y) 132 | loss = cross_entropy_loss 133 | 134 | # Zero gradients, perform a backward pass, and update the weights. 135 | optimizer.zero_grad() 136 | loss.backward() 137 | optimizer.step() 138 | 139 | #Linf1... 
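    # Apply the L-infinity,1 proximal step to linear1's weight and bias, printing them before and after to inspect the shrinkage.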
140 | print("weight before:", model.linear1.weight) 141 | print("bias before:", model.linear1.bias) 142 | pg.linf1(model.linear1.weight, model.linear1.bias, reg=0.1) 143 | print("weight after:", model.linear1.weight) 144 | print("bias after:", model.linear1.bias) 145 | 146 | 147 | def test_linf(network): 148 | x, y, model, criterion, optimizer = network 149 | for t in range(10): 150 | # Forward pass: Compute predicted y by passing x to the model 151 | y_pred = model(x) 152 | 153 | # Compute and print loss 154 | cross_entropy_loss = criterion(y_pred, y) 155 | loss = cross_entropy_loss 156 | 157 | # Zero gradients, perform a backward pass, and update the weights. 158 | optimizer.zero_grad() 159 | loss.backward() 160 | optimizer.step() 161 | 162 | #Linf 163 | print("weight before:", model.linear1.weight) 164 | pg.linf(model.linear1.weight, model.linear1.bias, reg=0.1) 165 | print("weight after:", model.linear1.weight) 166 | 167 | 168 | def test_elasticnet(network): 169 | print("Testing elasticnet") 170 | x, y, model, criterion, optimizer = network 171 | for t in range(10): 172 | # Forward pass: Compute predicted y by passing x to the model 173 | y_pred = model(x) 174 | 175 | # Compute and print loss 176 | cross_entropy_loss = criterion(y_pred, y) 177 | loss = cross_entropy_loss 178 | 179 | # Zero gradients, perform a backward pass, and update the weights. 180 | optimizer.zero_grad() 181 | loss.backward() 182 | optimizer.step() 183 | 184 | #Elastic Net 185 | #print("weight before:", model.linear1.weight) 186 | pg.elasticnet(model.linear1.weight, model.linear1.bias, reg=0.1) 187 | #print("weight after:", model.linear1.weight) 188 | 189 | def test_logbarrier(network): 190 | print("Testing logbarrier") 191 | x, y, model, criterion, optimizer = network 192 | for t in range(10): 193 | # Forward pass: Compute predicted y by passing x to the model 194 | y_pred = model(x) 195 | 196 | # Compute and print loss 197 | cross_entropy_loss = criterion(y_pred, y) 198 | loss = cross_entropy_loss 199 | 200 | # Zero gradients, perform a backward pass, and update the weights. 201 | optimizer.zero_grad() 202 | loss.backward() 203 | optimizer.step() 204 | #Log Barrier... 205 | #print("weight before:", model.linear1.weight) 206 | pg.logbarrier(model.linear1.weight, model.linear1.bias, reg=0.1) 207 | #print("weight after:", model.linear1.weight) 208 | 209 | 210 | def main(): 211 | network = build_model() 212 | test_l1(network) 213 | test_linf1(network) 214 | test_elasticnet(network) 215 | test_logbarrier(network) 216 | test_l2(network) 217 | test_l21(network) 218 | 219 | if __name__ == "__main__": 220 | main() 221 | --------------------------------------------------------------------------------