├── .gitignore ├── pip_requirements.txt ├── conda_requirements.txt ├── get_data.sh ├── LICENSE ├── paper_plots.tex ├── models.py ├── paper_experiments.sh ├── proxprop_pytorch.py ├── README.md ├── util.py ├── proxprop_plots.py └── ProxProp.py /.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pip_requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib2tikz 2 | -------------------------------------------------------------------------------- /conda_requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | pyyaml 4 | mkl 5 | setuptools 6 | cmake 7 | cffi 8 | -------------------------------------------------------------------------------- /get_data.sh: -------------------------------------------------------------------------------- 1 | # download and extract the CIFAR-10 dataset 2 | wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 3 | 4 | mkdir data 5 | current_dir=$(pwd) 6 | mv cifar-10-python.tar.gz "$current_dir/data/" 7 | cd "$current_dir/data/" 8 | tar xzf cifar-10-python.tar.gz 9 | rm cifar-10-python.tar.gz 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017-2018 Thomas Frerix and Thomas Möllenhoff 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the 
following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /paper_plots.tex: -------------------------------------------------------------------------------- 1 | % TeX file to generate paper plots after the experiments have been run. 2 | 3 | \documentclass{article} 4 | \usepackage{graphicx} 5 | \usepackage[utf8]{inputenc} 6 | \usepackage{pgfplots} 7 | \begin{document} 8 | \begin{figure*}[t!] 9 | \centering 10 | \begin{tabular}{ccc} 11 | \resizebox{.31\linewidth}{!}{\input{paper_experiments/exact_vs_inexact_prox_mlp_plot.tex}} & 12 | \resizebox{.31\linewidth}{!}{\input{paper_experiments/inexact_prox_vs_sgd_mlp_plot.tex}} & 13 | \resizebox{.31\linewidth}{!}{\input{paper_experiments/exact_vs_inexact_prox_mlp_plot_val_acc.tex}} 14 | \end{tabular} 15 | \caption{Exact and inexact solvers for ProxProp compared with BackProp for an MLP on CIFAR-10.} 16 | \end{figure*} 17 | 18 | \begin{figure*}[t!] 
19 | \centering 20 | \begin{tabular}{cc} 21 | \resizebox{.47\linewidth}{!}{\input{paper_experiments/proxprop_vs_sgd_adam_convnet_epochs_plot.tex}} & 22 | \resizebox{.47\linewidth}{!}{\input{paper_experiments/proxprop_vs_sgd_adam_convnet_time_plot.tex}} \\ 23 | \resizebox{.47\linewidth}{!}{\input{paper_experiments/proxprop_vs_sgd_adam_convnet_epochs_plot_val.tex}} & 24 | \resizebox{.47\linewidth}{!}{\input{paper_experiments/proxprop_vs_sgd_adam_convnet_time_plot_val.tex}} \\ 25 | \end{tabular} 26 | \caption{ProxProp as a first-order oracle in combination with the Adam optimizer for a convolutional neural network on CIFAR-10.} 27 | \end{figure*} 28 | \end{document} 29 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from operator import mul 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | import torch.utils.data as data 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | 12 | from ProxProp import ProxPropLinear, ProxPropConv2d 13 | 14 | class ProxPropConvNet(nn.Module): 15 | def __init__(self, input_size, num_classes, tau_prox, optimization_mode='prox_cg1'): 16 | super(ProxPropConvNet, self).__init__() 17 | img_dim = input_size[0] 18 | self.layers = nn.Sequential( 19 | ProxPropConv2d(img_dim, 16, kernel_size=5, tau_prox=tau_prox, padding=2, optimization_mode=optimization_mode), 20 | nn.ReLU(), 21 | nn.MaxPool2d(kernel_size=2, stride=2), 22 | ProxPropConv2d(16, 20, kernel_size=5, tau_prox=tau_prox, padding=2, optimization_mode=optimization_mode), 23 | nn.ReLU(), 24 | nn.MaxPool2d(kernel_size=2, stride=2), 25 | ProxPropConv2d(20, 20, kernel_size=5, tau_prox=tau_prox, padding=2, optimization_mode=optimization_mode), 26 | nn.ReLU() 27 | ) 28 | self.final_fc = nn.Linear(input_size[1]*input_size[2]//16 * 20 , 10) 29 | 30 | 
def forward(self, x): 31 | x = self.layers(x) 32 | return self.final_fc(x.view(x.size(0), -1)) 33 | 34 | 35 | class ProxPropMLP(nn.Module): 36 | def __init__(self, input_size, hidden_sizes, num_classes, tau_prox=1., optimization_mode='prox_cg1'): 37 | super(ProxPropMLP, self).__init__() 38 | input_size_flat = reduce(mul, input_size, 1) 39 | self.layers = [] 40 | self.layers.append(ProxPropLinear(input_size_flat, hidden_sizes[0], tau_prox=tau_prox, optimization_mode=optimization_mode)) 41 | for k, _ in enumerate(hidden_sizes[:-1]): 42 | self.layers.append(ProxPropLinear(hidden_sizes[k], hidden_sizes[k+1], tau_prox=tau_prox, optimization_mode=optimization_mode)) 43 | self.layers = nn.ModuleList(self.layers) 44 | self.final_fc = nn.Linear(hidden_sizes[-1], num_classes) 45 | self.relu = nn.ReLU() 46 | 47 | def forward(self, x): 48 | x = x.view(x.size(0), -1) 49 | for layer in self.layers: 50 | x = layer(x) 51 | x = self.relu(x) 52 | x = self.final_fc(x) 53 | return x 54 | 55 | -------------------------------------------------------------------------------- /paper_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | trap "exit" INT 4 | set -e 5 | 6 | current_date_time=`date "+%Y-%m-%d %H:%M:%S"` 7 | echo "Reproducing experiments of the paper Proximal Backpropagation at $current_date_time" 8 | echo "Checking your python version..." 9 | ret=`python -c 'import sys; print(sys.version_info >= (3,7))'` 10 | if [ "$ret" = "True" ]; then # python prints the literal True or False; a bare [ $ret ] is true for any non-empty string 11 | echo "Python version is >= 3.7" 12 | else 13 | echo "We require python version >= 3.7. Note we assume that python3 is aliased by python." 14 | fi 15 | 16 | echo "Checking your pdflatex installation..." 17 | command -v pdflatex >/dev/null && echo "Found pdflatex." || { echo >&2 "We need pdflatex to render the results, but it seems like it's not installed.
Aborting."; exit 1; } 18 | 19 | package=pgfplots.sty 20 | kpsewhich $package >/dev/null && echo "Found $package" || { echo >&2 "We need the latex package $package to render the results, but it seems like it's not installed. Aborting."; exit 1; } 21 | 22 | package=inputenc.sty 23 | kpsewhich $package >/dev/null && echo "Found $package" || { echo >&2 "We need the latex package $package to render the results, but it seems like it's not installed. Aborting."; exit 1; } 24 | 25 | package=graphicx.sty 26 | kpsewhich $package >/dev/null && echo "Found $package" || { echo >&2 "We need the latex package $package to render the results, but it seems like it's not installed. Aborting."; exit 1; } 27 | 28 | echo "Starting experiments on GPU device 0..." 29 | OUT_DIR=paper_experiments 30 | mkdir -p $OUT_DIR 31 | 32 | python proxprop_pytorch.py --learning_rate 1 --tau_prox 5e-2 --optimization_mode prox_exact --outfile $OUT_DIR/MLP_NesterovSGD_exact_tauprox5e-2_lr1 --model MLP --optimizer sgd 33 | 34 | python proxprop_pytorch.py --learning_rate 5e-2 --num_epochs 1000 --optimization_mode gradient --outfile $OUT_DIR/MLP_NesterovSGD_gradient_lr5e-2 --model MLP --optimizer sgd 35 | 36 | python proxprop_pytorch.py --learning_rate 1 --tau_prox 5e-2 --optimization_mode prox_cg3 --outfile $OUT_DIR/MLP_NesterovSGD_cg3_tauprox5e-2_lr1 --model MLP --optimizer sgd 37 | 38 | python proxprop_pytorch.py --learning_rate 1 --tau_prox 5e-2 --optimization_mode prox_cg5 --outfile $OUT_DIR/MLP_NesterovSGD_cg5_tauprox5e-2_lr1 --model MLP --optimizer sgd 39 | 40 | python proxprop_pytorch.py --learning_rate 1 --tau_prox 5e-2 --optimization_mode prox_cg10 --outfile $OUT_DIR/MLP_NesterovSGD_cg10_tauprox5e-2_lr1 --model MLP --optimizer sgd 41 | 42 | python proxprop_pytorch.py --learning_rate 1e-3 --num_epochs 1000 --optimization_mode gradient --outfile $OUT_DIR/ConvNet_Adam_gradient_lr1e-3 --model ConvNet --optimizer adam 43 | 44 | python proxprop_pytorch.py --learning_rate 1e-3 --tau_prox 1 
--optimization_mode prox_cg3 --outfile $OUT_DIR/ConvNet_Adam_cg3_tauprox1_lr1e-3 --model ConvNet --optimizer adam 45 | 46 | python proxprop_pytorch.py --learning_rate 1e-3 --tau_prox 1 --optimization_mode prox_cg10 --outfile $OUT_DIR/ConvNet_Adam_cg10_tauprox1_lr1e-3 --model ConvNet --optimizer adam 47 | 48 | echo "Finished running experiments..." 49 | 50 | echo "Extracting data..." 51 | python proxprop_plots.py > /dev/null 52 | 53 | echo "Compiling plots..." 54 | pdflatex paper_plots.tex > /dev/null 55 | 56 | echo "Tidying up..." 57 | rm paper_plots.aux paper_plots.log 58 | rm paper_experiments/*.tex 59 | 60 | current_date_time=`date "+%Y-%m-%d %H:%M:%S"` 61 | echo "Finished reproducing paper experiments at $current_date_time. The resulting plots are in paper_plots.pdf." 62 | -------------------------------------------------------------------------------- /proxprop_pytorch.py: -------------------------------------------------------------------------------- 1 | import pickle, gzip 2 | import sys 3 | import os 4 | import argparse 5 | import time 6 | import datetime 7 | from functools import reduce 8 | from operator import mul 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.init as init 13 | import torch.utils.data as data 14 | import torch.nn.functional as F 15 | from util import load_dataset, train, compute_loss_and_accuracy 16 | from models import ProxPropMLP, ProxPropConvNet 17 | 18 | """ 19 | Define all default parameters here. 20 | """ 21 | model_default = 'ConvNet' 22 | optimizer_default = 'adam' 23 | batch_size_default = 500 24 | learning_rate_default = 1e-3 25 | weight_decay_default = 0. 26 | num_epochs_default = 50 27 | tau_prox_default = 1. 
28 | momentum_default = 0.95 29 | nesterov_default = True 30 | optimization_mode_default = 'prox_cg1' 31 | use_cuda_default = True 32 | dataset_default = 'cifar-10' 33 | num_training_samples_default = -1 34 | 35 | outfile_default = '' 36 | cuda_device_default = 0 37 | 38 | # make these parameters parsable with above defined default values 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument('--dataset', type=str, default=dataset_default) 41 | parser.add_argument('--num_training_samples', type=int, default=num_training_samples_default) 42 | parser.add_argument('--optimization_mode', type=str, default=optimization_mode_default) 43 | parser.add_argument('--no-nesterov', dest='nesterov', action='store_false', default=nesterov_default) 44 | parser.add_argument('--momentum', type=float, default=momentum_default) 45 | parser.add_argument('--tau_prox', type=float, default=tau_prox_default) 46 | parser.add_argument('--num_epochs', type=int, default=num_epochs_default) 47 | parser.add_argument('--weight_decay', type=float, default=weight_decay_default) 48 | parser.add_argument('--learning_rate', type=float, default=learning_rate_default) 49 | parser.add_argument('--batch_size', type=int, default=batch_size_default) 50 | parser.add_argument('--optimizer', type=str, default=optimizer_default) 51 | parser.add_argument('--model', type=str, default=model_default) 52 | parser.add_argument('--outfile', type=str, default=outfile_default) 53 | parser.add_argument('--cuda_device', type=int, default=cuda_device_default) 54 | parser.add_argument('--no-cuda', dest='use_cuda', action='store_false', default=use_cuda_default) 55 | args = parser.parse_args() 56 | params = vars(args) 57 | 58 | ########################################################################################################################## 59 | x_train, y_train, x_val, y_val, x_test, y_test = load_dataset(params['dataset'], num_training_samples=params['num_training_samples']) 60 | 61 | params['train_size'] 
= x_train.shape 62 | params['val_size'] = x_val.shape 63 | params['test_size'] = x_test.shape 64 | 65 | print('x_train dims: ' + str(x_train.shape)) 66 | print('x_val dims: ' + str(x_val.shape)) 67 | print('x_test dims: ' + str(x_test.shape)) 68 | print('Maximum value of training set: ' + str(np.max(x_train))) 69 | print('Minimum value of training set: ' + str(np.min(x_train))) 70 | 71 | device = torch.device('cuda' if params['use_cuda'] else 'cpu') 72 | 73 | print('Training parameters:') 74 | for k,v in params.items(): 75 | print(k, v) 76 | 77 | input_size = x_train.shape[1:] 78 | if params['model'] == 'MLP': 79 | model = ProxPropMLP(input_size, hidden_sizes=[4000, 1000, 4000], num_classes=10, tau_prox=params['tau_prox'], optimization_mode=params['optimization_mode']).to(device) 80 | elif params['model'] == 'ConvNet': 81 | model = ProxPropConvNet(input_size, 10, params['tau_prox'], optimization_mode=params['optimization_mode']).to(device) 82 | else: 83 | raise ValueError('The model {} you have provided is not valid.'.format(params['model'])) 84 | 85 | print('model: \n' + str(model)) 86 | 87 | loss_fn = torch.nn.CrossEntropyLoss() 88 | if params['optimizer'] == 'sgd': 89 | optimizer = torch.optim.SGD(model.parameters(), lr=params['learning_rate'], momentum=params['momentum'], weight_decay=params['weight_decay'], nesterov=params['nesterov']) 90 | elif params['optimizer'] == 'adam': 91 | optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'], weight_decay=params['weight_decay']) 92 | else: 93 | raise ValueError('The optimizer {} you have provided is not valid.'.format(params['optimizer'])) 94 | 95 | data = x_train, y_train, x_val, y_val, x_test, y_test 96 | training_metrics = train(model, loss_fn, optimizer, data, params['num_epochs'], params['batch_size'], device=device) 97 | 98 | if params['outfile'] != '': 99 | pickle_data = {} 100 | pickle_data['timestamp'] =datetime.datetime.fromtimestamp(int(time.time())).strftime('%Y-%m-%d %H:%M:%S') 101 | 
pickle_data['params'] = params 102 | pickle_data['training_metrics'] = training_metrics 103 | pickle.dump(pickle_data, open( os.path.join(params['outfile'] + '.p'), "wb" ) ) 104 | 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Proximal Backpropagation 2 | ================ 3 | Proximal Backpropagation (ProxProp) is a neural network training algorithm that takes *implicit* instead of *explicit* gradient steps to update the network parameters. 4 | We have analyzed this algorithm in our ICLR 2018 paper: 5 | 6 | **Proximal Backpropagation** (Thomas Frerix, Thomas Möllenhoff, Michael Moeller, Daniel Cremers; ICLR 2018) [https://arxiv.org/abs/1706.04638] 7 | 8 | tl;dr 9 | ------------------- 10 | - We provide a PyTorch implementation of ProxProp for Python 3 and PyTorch 1.0.1. 11 | - The results of our paper can be reproduced by executing the script `paper_experiments.sh`. 12 | - ProxProp is implemented as a `torch.nn.Module` (a 'layer') and can be combined with any other layer and first-order optimizer. 13 | While a ProxPropConv2d and a ProxPropLinear layer already exist, you can generate a ProxProp layer for your favorite linear layer with one line of code. 14 | 15 | Installation 16 | ------------------- 17 | 1. Make sure you have a running Python 3 (tested with Python 3.7) ecosystem. We recommend that you use a [conda](https://conda.io/docs/) install, as this is also the recommended option to get the latest PyTorch running. 18 | For this README and for the scripts, we assume that you have `conda` running with Python 3.7. 19 | 2. Clone this repository and switch to the directory. 20 | 3. Install the dependencies via `conda install --file conda_requirements.txt` and `pip install -r pip_requirements.txt`. 21 | 4. Install [PyTorch](http://pytorch.org/) with magma support. 22 | We have tested our code with PyTorch 1.0.1 and CUDA 10.0.
23 | You can install this setup via 24 | ``` 25 | conda install -c pytorch magma-cuda100 26 | conda install pytorch torchvision cudatoolkit=10.0 -c pytorch 27 | ``` 28 | 5. (optional, but necessary to reproduce paper experiments) Download the CIFAR-10 dataset by executing `get_data.sh` 29 | 30 | Training neural networks with ProxProp 31 | ------------------- 32 | ProxProp is implemented as a custom linear layer (`torch.nn.Module`) with its own backward pass to take implicit gradient steps on the network parameters. 33 | With this design choice it can be combined with any other layer, for which one takes explicit gradient steps. 34 | Furthermore, the resulting update direction can be used with any first-order optimizer that expects a suitable update direction in parameter space. 35 | In our [paper](https://arxiv.org/abs/1706.04638) we prove that ProxProp generates a descent direction and show experiments with Nesterov SGD and Adam. 36 | 37 | You can use our pre-defined layers `ProxPropConv2d` and `ProxPropLinear`, corresponding to `nn.Conv2d` and `nn.Linear`, by importing 38 | 39 | `from ProxProp import ProxPropConv2d, ProxPropLinear` 40 | 41 | Besides the usual layer parameters, as detailed in the [PyTorch docs](http://pytorch.org/docs/master/), you can provide: 42 | 43 | - `tau_prox`: step size for a proximal step; default is `tau_prox=1` 44 | - `optimization_mode`: can be one of `'prox_exact'`, `'prox_cg{N}'`, `'gradient'` for an exact proximal step, an approximate proximal step with `N` conjugate gradient steps and an explicit gradient step, respectively; default is `optimization_mode='prox_cg1'`. 45 | The `'gradient'` mode is for a fair comparison with SGD, as it incurs the same overhead as the other methods in exploiting a generic implementation with the provided PyTorch API. 46 | 47 | If you want to use ProxProp to optimize your favorite linear layer, you can generate the respective module with one line of code. 
48 | As an example for the `Conv3d` layer: 49 | 50 | ``` 51 | from ProxProp import proxprop_module_generator 52 | ProxPropConv3d = proxprop_module_generator(torch.nn.Conv3d) 53 | ``` 54 | 55 | This gives you a default implementation for the approximate conjugate gradient solver, which treats all parameters as a stacked vector. 56 | If you want to use the exact solver or want to use the conjugate gradient solver more efficiently, you have to provide the respective reshaping methods to `proxprop_module_generator`, as this requires specific knowledge of the layer's structure and cannot be implemented generically. 57 | As a template, take a look at the `ProxProp.py` file, where we have done this for the `ProxPropLinear` layer. 58 | 59 | By reusing the forward/backward implementations of existing PyTorch modules, ProxProp becomes readily accessible. 60 | However, we pay an overhead associated with generically constructing the backward pass using the PyTorch API. 61 | We have intentionally sided with genericity over speed. 62 | 63 | Reproduce paper experiments 64 | ------------------- 65 | To reproduce the paper experiments execute the script `paper_experiments.sh`. 66 | This will run our paper's experiments, store the results in the directory `paper_experiments/` and subsequently compile the results into the file `paper_plots.pdf`. 67 | We use an NVIDIA Titan X GPU; executing the script takes roughly 3 hours. 68 | 69 | Acknowledgement 70 | ------------------- 71 | We want to thank [Soumith Chintala](https://github.com/soumith) for helping us track down a mysterious bug and the whole PyTorch dev team for their continued development effort and great support to the community.
72 | 73 | Publication 74 | ------------------- 75 | If you use ProxProp, please acknowledge our paper by citing 76 | 77 | ``` 78 | @article{Frerix-et-al-18, 79 | title = {Proximal Backpropagation}, 80 | author={Thomas Frerix and Thomas Möllenhoff and Michael Moeller and Daniel Cremers}, 81 | journal={International Conference on Learning Representations}, 82 | year={2018}, 83 | url = {https://arxiv.org/abs/1706.04638} 84 | } 85 | ``` 86 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import pickle 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | import torch.utils.data as data 9 | import torch.nn.functional as F 10 | 11 | def normalize_to_unit_interval(x, normalize_by = None): 12 | if normalize_by == None: 13 | max_val = np.max(x) 14 | return x / max_val, max_val 15 | else: 16 | return x / normalize_by, normalize_by 17 | 18 | def load_CIFAR_batch(filename): 19 | """ load single batch of cifar """ 20 | with open(filename, 'rb') as f: 21 | datadict = pickle.load(f, encoding='latin1') 22 | X = datadict['data'] 23 | Y = datadict['labels'] 24 | X = X.reshape(10000, 3, 32, 32).astype('float32') 25 | Y = np.array(Y) 26 | return X, Y 27 | 28 | 29 | def load_dataset(dataset, num_training_samples=-1): 30 | if dataset == 'cifar-10': 31 | datapath = 'data/cifar-10-batches-py' 32 | xs = [] 33 | ys = [] 34 | for b in range(1,6): 35 | f = os.path.join(datapath, 'data_batch_%d' % (b, )) 36 | X, Y = load_CIFAR_batch(f) 37 | xs.append(X) 38 | ys.append(Y) 39 | x_all = np.concatenate(xs) 40 | y_all = np.concatenate(ys) 41 | x_train = x_all[:-5000] 42 | x_val = x_all[-5000:] 43 | y_train = y_all[:-5000] 44 | y_val = y_all[-5000:] 45 | del X, Y 46 | x_test, y_test = load_CIFAR_batch(os.path.join(datapath, 'test_batch')) 47 | 48 | x_train, normalize_by =
normalize_to_unit_interval(x_train) 49 | x_val, _ = normalize_to_unit_interval(x_val, normalize_by) 50 | x_test, _ = normalize_to_unit_interval(x_test, normalize_by) 51 | 52 | else: 53 | raise ValueError('Import for the dataset you have provided is not yet implemented.') 54 | 55 | if num_training_samples > 0: 56 | x_train = x_train[:num_training_samples] 57 | y_train = y_train[:num_training_samples] 58 | 59 | return x_train, y_train, x_val, y_val, x_test, y_test 60 | 61 | 62 | def compute_loss_and_accuracy(model, loss_fn, x, y, batch_size=64, device='cuda'): 63 | num_samples = x.shape[0] 64 | data = torch.utils.data.TensorDataset(torch.from_numpy(x), torch.from_numpy(y)) 65 | loader = torch.utils.data.DataLoader(dataset=data, batch_size=batch_size, shuffle=False) 66 | 67 | correct_samples = 0 68 | loss = 0. 69 | num_batches = 0 70 | model.train(False) 71 | with torch.no_grad(): 72 | for sample_x, sample_y in loader: 73 | num_batches += 1 74 | sample_x = sample_x.to(device) 75 | sample_y = sample_y.to(device) 76 | sample_out = model(sample_x) 77 | loss += loss_fn(sample_out, sample_y).item() 78 | _, y_pred = sample_out.max(dim=1) 79 | correct_samples += sample_y.numel() - torch.nonzero(y_pred - sample_y).numel() 80 | acc = float(correct_samples) / float(num_samples) 81 | loss = float(loss) / float(num_batches) 82 | model.train(True) 83 | return loss, acc 84 | 85 | def train(model, loss_fn, optimizer, data, num_epochs, batch_size, scheduler=None, device='cuda'): 86 | x_train, y_train, x_val, y_val, x_test, y_test = data 87 | train_data = torch.utils.data.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)) 88 | train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True) 89 | 90 | start_time = time.time() 91 | training_metrics = {} 92 | training_metrics['minibatch_avg_loss'] = [] 93 | training_metrics['full_batch_loss'] = [] 94 | training_metrics['full_batch_acc'] = [] 95 | training_metrics['val_loss'] = [] 96 | 
training_metrics['val_acc'] = [] 97 | training_metrics['epoch_time'] = [] 98 | 99 | #initial loss and accuracies 100 | epoch_full_batch_loss, epoch_full_batch_training_acc = compute_loss_and_accuracy(model, loss_fn, x_train, y_train, batch_size=batch_size, device=device) 101 | val_loss, val_acc = compute_loss_and_accuracy(model, loss_fn, x_val, y_val, batch_size=batch_size, device=device) 102 | training_metrics['minibatch_avg_loss'].append(epoch_full_batch_loss) 103 | training_metrics['full_batch_loss'].append(epoch_full_batch_loss) 104 | training_metrics['full_batch_acc'].append(epoch_full_batch_training_acc) 105 | training_metrics['val_loss'].append(val_loss) 106 | training_metrics['val_acc'].append(val_acc) 107 | training_metrics['epoch_time'].append(0) 108 | 109 | 110 | for epoch in range(num_epochs): 111 | epoch_start_time = time.time() 112 | epoch_summed_loss = 0 113 | epoch_batch_counter = 0 114 | for x_batch, y_batch in train_loader: 115 | x = x_batch.to(device) 116 | y = y_batch.to(device) 117 | 118 | x_out= model(x) 119 | loss = loss_fn(x_out, y) 120 | optimizer.zero_grad() 121 | loss.backward() 122 | if scheduler is not None: 123 | scheduler.step(epoch=epoch) 124 | optimizer.step() 125 | epoch_summed_loss += loss.item() 126 | epoch_batch_counter += 1 127 | 128 | epoch_end_time = time.time() 129 | epoch_time = epoch_end_time - epoch_start_time 130 | training_metrics['epoch_time'].append(epoch_time) 131 | epoch_full_batch_loss, epoch_full_batch_training_acc = compute_loss_and_accuracy(model, loss_fn, x_train, y_train, batch_size=batch_size, device=device) 132 | val_loss, val_acc = compute_loss_and_accuracy(model, loss_fn, x_val, y_val, device=device) 133 | epoch_avg_loss = epoch_summed_loss / epoch_batch_counter 134 | training_metrics['minibatch_avg_loss'].append(epoch_avg_loss) 135 | training_metrics['full_batch_loss'].append(epoch_full_batch_loss) 136 | training_metrics['full_batch_acc'].append(epoch_full_batch_training_acc) 137 | 
training_metrics['val_loss'].append(val_loss) 138 | training_metrics['val_acc'].append(val_acc) 139 | print(str(epoch) + ': mini batch avg loss: ' + str(epoch_avg_loss) + ', full batch loss: ' + str(epoch_full_batch_loss) + ', epoch time: ' + str(epoch_time) + 's') 140 | print('Trained in {0} seconds.'.format(int(time.time() - start_time))) 141 | test_loss, test_acc = compute_loss_and_accuracy(model, loss_fn, x_test, y_test, batch_size=256, device=device) 142 | training_metrics['test_loss_acc'] = (test_loss, test_acc) 143 | print('Avg. test loss: {0}, avg. test accuracy: {1}'.format(test_loss, test_acc)) 144 | return training_metrics 145 | -------------------------------------------------------------------------------- /proxprop_plots.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate PGFPlots TeX output plots from raw plotting data input saved as pickle. 3 | """ 4 | import os 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | import matplotlib.cm as cm 9 | import numpy as np 10 | from matplotlib2tikz import save as tikz_save 11 | import pickle 12 | from itertools import accumulate 13 | 14 | output_path = 'paper_experiments/' 15 | input_path = 'paper_experiments/' 16 | 17 | def load_data(filename, key, input_path=input_path): 18 | with open(os.path.join(input_path, filename), 'rb') as f: 19 | data = pickle.load(f) 20 | if key not in data['training_metrics']: 21 | raise ValueError('The provided key "{}" is not a key of the data dict.'.format(key)) 22 | else: 23 | return data['training_metrics'][key] 24 | 25 | #global properties of all plots - if not locally specified otherwise 26 | y_axis = 'val_acc' 27 | y_label = 'Validation Accuracy' 28 | label_fontsize = 'footnotesize' 29 | linewidth = 4 30 | axis_parameter_set = {'xticklabel style={font=\\%s}' % label_fontsize, 'yticklabel style={font=\\%s}' % label_fontsize, 'legend style={line width=%s, fill=gray!7}' % (linewidth/2)} 
31 | 32 | """ 33 | Exact vs. inexact solves in epochs 34 | """ 35 | y_axis_local = 'full_batch_loss' 36 | plt.figure() 37 | outfile = 'exact_vs_inexact_prox_mlp_plot.tex' 38 | exact_data = load_data('MLP_NesterovSGD_exact_tauprox5e-2_lr1.p', y_axis_local) 39 | sgd_data = load_data('MLP_NesterovSGD_gradient_lr5e-2.p', y_axis_local) 40 | cg3_data = load_data('MLP_NesterovSGD_cg3_tauprox5e-2_lr1.p', y_axis_local) 41 | cg5_data = load_data('MLP_NesterovSGD_cg5_tauprox5e-2_lr1.p', y_axis_local) 42 | cg10_data = load_data('MLP_NesterovSGD_cg10_tauprox5e-2_lr1.p', y_axis_local) 43 | plt.title('CIFAR-10, 3072-4000-1000-4000-10 MLP') 44 | x = range(50) 45 | plt.xlabel('Epochs') 46 | plt.ylabel('Full Batch Training Loss') 47 | plt.plot(x, sgd_data[:50], lw=linewidth, label='BackProp') 48 | plt.plot(x, cg3_data[:50], lw=linewidth, label='ProxProp (cg3)') 49 | plt.plot(x, cg5_data[:50], lw=linewidth, label='ProxProp (cg5)') 50 | plt.plot(x, cg10_data[:50], lw=linewidth, label='ProxProp (cg10)') 51 | plt.plot(x, exact_data[:50], lw=linewidth, label='ProxProp (exact)') 52 | plt.legend(frameon=False) 53 | tikz_save(os.path.join(output_path, outfile), extra_axis_parameters=axis_parameter_set) 54 | 55 | """ 56 | Exact vs. 
inexact solves in epochs validation accuracy 57 | """ 58 | y_axis_local = 'val_acc' 59 | plt.figure() 60 | outfile = 'exact_vs_inexact_prox_mlp_plot_val_acc.tex' 61 | exact_data = load_data('MLP_NesterovSGD_exact_tauprox5e-2_lr1.p', y_axis_local) 62 | sgd_data = load_data('MLP_NesterovSGD_gradient_lr5e-2.p', y_axis_local) 63 | cg3_data = load_data('MLP_NesterovSGD_cg3_tauprox5e-2_lr1.p', y_axis_local) 64 | cg5_data = load_data('MLP_NesterovSGD_cg5_tauprox5e-2_lr1.p', y_axis_local) 65 | cg10_data = load_data('MLP_NesterovSGD_cg10_tauprox5e-2_lr1.p', y_axis_local) 66 | plt.title('CIFAR-10, 3072-4000-1000-4000-10 MLP') 67 | x = range(50) 68 | plt.xlabel('Epochs') 69 | plt.ylabel('Validation Accuracy') 70 | plt.plot(x, sgd_data[:50], lw=linewidth, label='BackProp') 71 | plt.plot(x, cg3_data[:50], lw=linewidth, label='ProxProp (cg3)') 72 | plt.plot(x, cg5_data[:50], lw=linewidth, label='ProxProp (cg5)') 73 | plt.plot(x, cg10_data[:50], lw=linewidth, label='ProxProp (cg10)') 74 | plt.plot(x, exact_data[:50], lw=linewidth, label='ProxProp (exact)') 75 | plt.legend(frameon=False) 76 | tikz_save(os.path.join(output_path, outfile), extra_axis_parameters=axis_parameter_set) 77 | 78 | """ 79 | Inexact solve vs. 
SGD in time 80 | """ 81 | y_axis_local = 'full_batch_loss' 82 | plt.figure() 83 | outfile = 'inexact_prox_vs_sgd_mlp_plot.tex' 84 | exact_data = load_data('MLP_NesterovSGD_exact_tauprox5e-2_lr1.p', y_axis_local) 85 | sgd_data = load_data('MLP_NesterovSGD_gradient_lr5e-2.p', y_axis_local) 86 | cg3_data = load_data('MLP_NesterovSGD_cg3_tauprox5e-2_lr1.p', y_axis_local) 87 | cg5_data = load_data('MLP_NesterovSGD_cg5_tauprox5e-2_lr1.p', y_axis_local) 88 | cg10_data = load_data('MLP_NesterovSGD_cg10_tauprox5e-2_lr1.p', y_axis_local) 89 | plt.title('CIFAR-10, 3072-4000-1000-4000-10 MLP') 90 | epoch_time_exact = load_data('MLP_NesterovSGD_exact_tauprox5e-2_lr1.p', 'epoch_time') 91 | epoch_time_sgd = load_data('MLP_NesterovSGD_gradient_lr5e-2.p', 'epoch_time') 92 | epoch_time_cg3 = load_data('MLP_NesterovSGD_cg3_tauprox5e-2_lr1.p', 'epoch_time') 93 | epoch_time_cg5 = load_data('MLP_NesterovSGD_cg5_tauprox5e-2_lr1.p', 'epoch_time') 94 | epoch_time_cg10 = load_data('MLP_NesterovSGD_cg10_tauprox5e-2_lr1.p', 'epoch_time') 95 | x_exact = list(accumulate(epoch_time_exact)) 96 | x_cg3 = list(accumulate(epoch_time_cg3)) 97 | x_cg5 = list(accumulate(epoch_time_cg5)) 98 | x_cg10 = list(accumulate(epoch_time_cg10)) 99 | x_sgd = list(accumulate(epoch_time_sgd)) 100 | plt.xlabel('Time [s]') 101 | plt.ylabel('Full Batch Training Loss') 102 | x_max = 5 * 60 103 | x_sgd = [x for x in x_sgd if x <= x_max] 104 | x_cg3 = [x for x in x_cg3 if x <= x_max] 105 | x_cg5 = [x for x in x_cg5 if x <= x_max] 106 | x_cg10 = [x for x in x_cg10 if x <= x_max] 107 | x_exact = [x for x in x_exact if x <= x_max] 108 | sgd_data = sgd_data[:len(x_sgd)] 109 | cg3_data = cg3_data[:len(x_cg3)] 110 | cg5_data = cg5_data[:len(x_cg5)] 111 | cg10_data = cg10_data[:len(x_cg10)] 112 | exact_data = exact_data[:len(x_exact)] 113 | plt.plot(x_sgd, sgd_data, lw=linewidth, label='BackProp') 114 | plt.plot(x_cg3, cg3_data, lw=linewidth, label='ProxProp (cg3)') 115 | plt.plot(x_cg5, cg5_data, lw=linewidth, label='ProxProp 
(cg5)') 116 | plt.plot(x_cg10, cg10_data, lw=linewidth, label='ProxProp (cg10)') 117 | plt.plot(x_exact, exact_data, lw=linewidth, label='ProxProp (exact)') 118 | plt.legend(frameon=False) 119 | tikz_save(os.path.join(output_path, outfile), extra_axis_parameters=axis_parameter_set) 120 | 121 | """ 122 | ConvNet: ProxProp vs. SGD directions with Adam full batch loss in epochs 123 | """ 124 | y_axis = 'full_batch_loss' 125 | y_label = 'Full Batch Loss' 126 | plt.figure() 127 | outfile = 'proxprop_vs_sgd_adam_convnet_epochs_plot.tex' 128 | cg3_data = load_data('ConvNet_Adam_cg3_tauprox1_lr1e-3.p', y_axis) 129 | cg10_data = load_data('ConvNet_Adam_cg10_tauprox1_lr1e-3.p', y_axis) 130 | sgd_data = load_data('ConvNet_Adam_gradient_lr1e-3.p', y_axis) 131 | plt.title('CIFAR-10, Convolutional Neural Network') 132 | x = range(len(cg3_data)) 133 | plt.xlabel('Epochs') 134 | plt.ylabel(y_label) 135 | plt.plot(x, sgd_data[:51], lw=linewidth, label='Adam + BackProp') 136 | plt.plot(x, cg3_data, lw=linewidth, label='Adam + ProxProp (3 cg)') 137 | plt.plot([], []) # null plot to advance the color cycler 138 | plt.plot(x, cg10_data, lw=linewidth, label='Adam + ProxProp (10 cg)') 139 | plt.legend(frameon=False) 140 | tikz_save(os.path.join(output_path, outfile), extra_axis_parameters=axis_parameter_set) 141 | 142 | """ 143 | ConvNet: ProxProp vs. 
SGD directions with Adam full batch loss in time 144 | """ 145 | plt.figure() 146 | outfile = 'proxprop_vs_sgd_adam_convnet_time_plot.tex' 147 | cg3_data = load_data('ConvNet_Adam_cg3_tauprox1_lr1e-3.p', y_axis) 148 | cg10_data = load_data('ConvNet_Adam_cg10_tauprox1_lr1e-3.p', y_axis) 149 | sgd_data = load_data('ConvNet_Adam_gradient_lr1e-3.p', y_axis) 150 | plt.title('CIFAR-10, Convolutional Neural Network') 151 | epoch_time_cg3 = load_data('ConvNet_Adam_cg3_tauprox1_lr1e-3.p', 'epoch_time') 152 | x_cg3 = list(accumulate(epoch_time_cg3)) 153 | max_cg3_time = max(x_cg3) 154 | 155 | epoch_time_sgd = load_data('ConvNet_Adam_gradient_lr1e-3.p', 'epoch_time') 156 | x_sgd = list(accumulate(epoch_time_sgd)) 157 | x_sgd = [x for x in x_sgd if x <= max_cg3_time] 158 | 159 | epoch_time_cg10 = load_data('ConvNet_Adam_cg10_tauprox1_lr1e-3.p', 'epoch_time') 160 | x_cg10 = list(accumulate(epoch_time_cg10)) 161 | x_cg10 = [x for x in x_cg10 if x <= max_cg3_time] 162 | 163 | sgd_data = sgd_data[:len(x_sgd)] 164 | cg3_data = cg3_data[:len(x_cg3)] 165 | cg10_data = cg10_data[:len(x_cg10)] 166 | 167 | plt.xlabel('Time [s]') 168 | plt.ylabel(y_label) 169 | plt.plot(x_sgd, sgd_data, lw=linewidth, label='Adam + BackProp') 170 | plt.plot(x_cg3, cg3_data, lw=linewidth, label='Adam + ProxProp (3 cg)') 171 | plt.plot([], []) # null plot to advance the color cycler 172 | plt.plot(x_cg10, cg10_data, lw=linewidth, label='Adam + ProxProp (10 cg)') 173 | plt.legend(frameon=False) 174 | tikz_save(os.path.join(output_path, outfile), extra_axis_parameters=axis_parameter_set) 175 | 176 | """ 177 | ConvNet: ProxProp vs. 
SGD directions with Adam validation accuracy in epochs 178 | """ 179 | y_axis = 'val_acc' 180 | y_label = 'Validation Accuracy' 181 | plt.figure() 182 | outfile = 'proxprop_vs_sgd_adam_convnet_epochs_plot_val.tex' 183 | cg3_data = load_data('ConvNet_Adam_cg3_tauprox1_lr1e-3.p', y_axis) 184 | cg10_data = load_data('ConvNet_Adam_cg10_tauprox1_lr1e-3.p', y_axis) 185 | sgd_data = load_data('ConvNet_Adam_gradient_lr1e-3.p', y_axis) 186 | plt.title('CIFAR-10, Convolutional Neural Network') 187 | x = range(len(cg3_data)) 188 | plt.xlabel('Epochs') 189 | plt.ylabel(y_label) 190 | plt.plot(x, sgd_data[:51], lw=linewidth, label='Adam + BackProp') 191 | plt.plot(x, cg3_data, lw=linewidth, label='Adam + ProxProp (3 cg)') 192 | plt.plot([], []) # null plot to advance the color cycler 193 | plt.plot(x, cg10_data, lw=linewidth, label='Adam + ProxProp (10 cg)') 194 | plt.legend(frameon=False) 195 | tikz_save(os.path.join(output_path, outfile), extra_axis_parameters=axis_parameter_set) 196 | 197 | """ 198 | ConvNet: ProxProp vs. 
SGD directions with Adam validation accuracy in time 199 | """ 200 | y_axis = 'val_acc' 201 | y_label = 'Validation Accuracy' 202 | plt.figure() 203 | outfile = 'proxprop_vs_sgd_adam_convnet_time_plot_val.tex' 204 | cg3_data = load_data('ConvNet_Adam_cg3_tauprox1_lr1e-3.p', y_axis) 205 | cg10_data = load_data('ConvNet_Adam_cg10_tauprox1_lr1e-3.p', y_axis) 206 | sgd_data = load_data('ConvNet_Adam_gradient_lr1e-3.p', y_axis) 207 | plt.title('CIFAR-10, Convolutional Neural Network') 208 | epoch_time_cg3 = load_data('ConvNet_Adam_cg3_tauprox1_lr1e-3.p', 'epoch_time') 209 | x_cg3 = list(accumulate(epoch_time_cg3)) 210 | max_cg3_time = max(x_cg3) 211 | 212 | epoch_time_sgd = load_data('ConvNet_Adam_gradient_lr1e-3.p', 'epoch_time') 213 | x_sgd = list(accumulate(epoch_time_sgd)) 214 | x_sgd = [x for x in x_sgd if x <= max_cg3_time] 215 | 216 | epoch_time_cg10 = load_data('ConvNet_Adam_cg10_tauprox1_lr1e-3.p', 'epoch_time') 217 | x_cg10 = list(accumulate(epoch_time_cg10)) 218 | x_cg10 = [x for x in x_cg10 if x <= max_cg3_time] 219 | 220 | sgd_data = sgd_data[:len(x_sgd)] 221 | cg3_data = cg3_data[:len(x_cg3)] 222 | cg10_data = cg10_data[:len(x_cg10)] 223 | 224 | plt.xlabel('Time [s]') 225 | plt.ylabel(y_label) 226 | plt.plot(x_sgd, sgd_data, lw=linewidth, label='Adam + BackProp') 227 | plt.plot(x_cg3, cg3_data, lw=linewidth, label='Adam + ProxProp (3 cg)') 228 | plt.plot([], []) # null plot to advance the color cycler 229 | plt.plot(x_cg10, cg10_data, lw=linewidth, label='Adam + ProxProp (10 cg)') 230 | plt.legend(frameon=False) 231 | tikz_save(os.path.join(output_path, outfile), extra_axis_parameters=axis_parameter_set) 232 | -------------------------------------------------------------------------------- /ProxProp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict 3 | import torch 4 | from torch.autograd import Function 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 
from functools import partial
import numpy


def conjugate_gradient_block(A, B, x0=None, tol=1e-2, maxit=None, eps=1e-6):
    """
    Solve the linear system A X = B using conjugate gradient, where
    - A is an abstract linear operator implementing an n*n matrix
    - B is a matrix right hand side of size n*s
    - x0 is an initial guess of size n*s (defaults to zeros when None)

    Essentially, this runs #s classical conjugate gradient algorithms in 'parallel',
    and terminates when the worst of the #s residuals is below tol, or after
    maxit iterations when maxit is not None.

    Returns the tuple (X, k) of the approximate solution and the number of iterations.
    The initial guess x0 is never modified.
    """
    # Work on a private copy: the iterates are updated in place below and the
    # caller's initial guess may share storage with live parameters.
    X = torch.zeros_like(B) if x0 is None else x0.clone()
    R = B - A(X)
    # P must not alias R: R is updated in place, and the direction update
    # below needs the *previous* search direction.
    P = R.clone()
    Rs_old = torch.norm(R, 2., dim=0) ** 2.
    tol_scale = torch.mean(Rs_old)

    k = 0
    while True:
        k += 1
        AP = A(P)
        alpha = Rs_old / (torch.sum(P * AP, dim=0) + eps)
        X += P * alpha

        if k == maxit:
            break

        R -= AP * alpha

        Rs_new = torch.norm(R, 2., dim=0) ** 2.
        res = torch.max(Rs_new)
        # '<=' (not '<') so that an exactly-zero residual terminates even
        # when tol_scale is zero and no iteration limit was given.
        if res <= (tol ** 2.) * tol_scale:
            break

        P = R + P * (Rs_new / (Rs_old + eps))
        Rs_old = Rs_new

    return X, k


def _solve_linear_system(lhs, rhs):
    """
    Solve lhs @ X = rhs with whichever dense solver the installed torch provides.
    """
    linalg = getattr(torch, 'linalg', None)
    if linalg is not None and hasattr(linalg, 'solve'):
        return linalg.solve(lhs, rhs)
    if hasattr(torch, 'solve'):
        return torch.solve(rhs, lhs)[0]
    # very old torch releases only ship gesv
    return torch.gesv(rhs, lhs)[0]


def optimization_step(A, Y, Z, beta, apply_cg, mode='prox_cg1'):
    """
    Optimization step for several different modes:
    - 'prox_exact' takes an exact proximal step
    - 'prox_cgN' takes an approximate proximal step, generated by N conjugate gradient steps
    - 'gradient' performs a gradient step; this method recovers classical SGD

    Taking a proximal step amounts to solving:

        argmin_{X} 1/2 ||AX - Y||^2 + beta/2 ||X - Z||^2,

    i.e. solving a linear system (exactly or approximately).

    The method argument 'apply_cg' is the one taken from the respective ProxProp module.
    Raises a ValueError for an unknown mode.
    """
    apply_A = partial(apply_cg, A)

    if mode == 'prox_exact':
        # normal equations: (A^T A + beta I) X = A^T Y + beta Z
        AA = A.t().mm(A)
        I = torch.eye(A.size(1)).type_as(A)
        A_tilde = AA + beta * I
        b_tilde = A.t().mm(Y) + beta * Z
        X = _solve_linear_system(A_tilde, b_tilde)
    elif mode.startswith('prox_cg'):
        num_prox_steps = int(mode[len('prox_cg'):])
        apply_A_tilde = lambda x: apply_A(x, 3) + beta * x
        b_tilde = apply_A(Y, 2) + beta * Z
        X, _ = conjugate_gradient_block(apply_A_tilde, b_tilde, x0=Z, maxit=num_prox_steps)
    elif mode == 'gradient':
        X = Z - (apply_A(apply_A(Z, 1) - Y, 2))
    else:
        raise ValueError('The optimization mode "{}" you have specified is not valid.'.format(mode))
    return X


class ForwardBackwardFunctional(Function):
    """
    Generic forward/backward functional that is used in any ProxProp module.

    It is applied as apply(input, *parameters, layer), where layer is the
    ProxProp module that provides the forward operator and reshaping helpers.
    """
    @staticmethod
    def forward(ctx, *args):
        ctx.optimization_layer = args[-1]
        output = ctx.optimization_layer.apply_forward(args[0])
        ctx.save_for_backward(*args[:-1], output)
        return output

    @staticmethod
    def backward(ctx, grad_z):
        saved = ctx.saved_tensors
        layer_input = saved[0]
        params = list(saved[1:-1])
        output = saved[-1]
        layer = ctx.optimization_layer
        exact = layer.optimization_mode == 'prox_exact'

        # explicit gradient step on z
        z_updated = output - grad_z

        # prox step or gradient step on the network parameters
        if exact:
            A, Y, Z = layer.to_exact_solve_shape(layer_input, z_updated, *params)
        else:
            A = layer_input.detach()
            Y = z_updated.detach()
            Z = layer.to_cg_shape(params).detach()

        X_tensor = optimization_step(A, Y, Z, 1. / layer.tau_prox, layer.apply_cg,
                                     mode=layer.optimization_mode)

        if exact:
            params_updated = list(layer.from_exact_solve_shape(X_tensor).values())
        else:
            params_updated = list(layer.from_cg_shape(X_tensor).values())

        # write difference in grad fields
        grad_params = [old - new for old, new in zip(params, params_updated)]

        # explicit gradient step on a; differentiate through a detached copy so
        # the requires_grad flag of the caller's tensor is left untouched
        grad_input = None
        if ctx.needs_input_grad[0]:
            input_copy = layer_input.detach().requires_grad_()
            with torch.enable_grad():
                out_temp = layer.apply_forward(input_copy)
            grad_input = torch.autograd.grad(out_temp, input_copy, grad_z)[0]

        # one gradient slot per forward argument; the layer itself gets None
        return tuple([grad_input] + grad_params + [None])


def proxprop_module_generator(BaseModule, to_cg_shape=None, from_cg_shape=None, to_exact_solve_shape=None, from_exact_solve_shape=None):
    """
    Creates a ProxProp module for any given linear BaseModule.

    The methods to_cg_shape and from_cg_shape convert to the conjugate gradient format and back.
    By default, they flatten all variables to a vector. One may want to provide specific functions
    to use a special structure of the layer for better (block) conjugate gradient performance.
    The consistency of the implementation is automatically checked during initialization.

    The methods to_exact_solve_shape and from_exact_solve_shape cannot be provided by default, but are needed to take an exact proximal step.
    You can test your implementation for shape consistency by calling test_exact_solve_reshaping on a test input batch.

    The generated ProxProp module uses the BaseModule's forward method and can therefore leverage existing, efficient implementations.
    """
    class ProxPropModule(BaseModule):
        def __init__(self, *args, **kwargs):
            # strip the ProxProp specific keyword arguments before delegating
            tau_prox = kwargs.pop('tau_prox', 1.)
            optimization_mode = kwargs.pop('optimization_mode', 'prox_cg1')
            super().__init__(*args, **kwargs)
            self.optimization_mode = optimization_mode
            self.register_buffer('tau_prox', torch.Tensor([tau_prox]))
            self.forward_backward_functional = ForwardBackwardFunctional
            self._test_cg_reshaping()

        def _compare_two_params_dicts(self, d1, d2):
            """Assert that two name -> parameter dicts hold numerically equal values."""
            assert len(d1.items()) == len(d2.items())
            for name, p1 in d1.items():
                p2 = d2[name]
                p1_np = p1.detach().cpu().numpy()
                p2_np = p2.detach().cpu().numpy()
                assert numpy.allclose(p1_np, p2_np)

        def _test_cg_reshaping(self):
            """Check that from_cg_shape inverts to_cg_shape on the current parameters."""
            named_params_check = self.from_cg_shape(self.to_cg_shape(list(self.parameters())))
            module_named_params = dict(self.named_parameters())
            self._compare_two_params_dicts(module_named_params, named_params_check)

        def test_exact_solve_reshaping(self, x):
            """
            Checks an implementation of to/from_exact_solve_shape and can be called with an
            input tensor of the layer's input shape (including batch dimension).
            Returns True if the round trip reproduces the parameters, False otherwise.
            """
            try:
                y = self.apply_forward(x)
                params = list(self.parameters())
                A, Y, Z = self.to_exact_solve_shape(x, y, *params)
                named_params_check = self.from_exact_solve_shape(Z)
                module_named_params = dict(self.named_parameters())
                self._compare_two_params_dicts(module_named_params, named_params_check)
                return True
            except NotImplementedError:
                print('At least one of the exact solve reshaping methods is not implemented.')
                return False
            except Exception as e:
                # deliberate best effort: report the problem and signal failure
                print(e)
                return False

        def forward(self, input):
            args = [input] + list(self.parameters()) + [self]
            return self.forward_backward_functional.apply(*args)

        def apply_forward(self, x):
            return super().forward(x)

        def apply_adjoint(self, forward_out, x):
            forward_out.backward(x)

        def to_cg_shape(self, params_list):
            """
            Default implementation. Flattens all parameters to a vector.
            Expects the module's parameter tensors in a list as provided by list(self.parameters()).
            Returns a tensor with dimensions expected by the conjugate gradient solver.
            """
            # reshape (instead of view) also accepts non-contiguous tensors
            return torch.cat([p.reshape(-1) for p in params_list])

        def from_cg_shape(self, x):
            """
            Default implementation. Assumes flattened parameters and reshapes to module parameter shape.
            Expects a tensor with shape used by the conjugate gradient solver.
            Returns an OrderedDict containing the module's parameters in their native shape and as
            nn.Parameter() objects.
            """
            prev_var_counter = 0
            params_cg = OrderedDict()
            for name, p in self.named_parameters():
                var_counter = prev_var_counter + p.numel()
                p_cg = x[prev_var_counter:var_counter].view(p.size())
                params_cg[name] = torch.nn.Parameter(p_cg)
                prev_var_counter = var_counter
            return params_cg

        def to_exact_solve_shape(self, x, y, *params):
            """
            Needs to be implemented for the exact solve of the proximal step.

            Prepares the tensors to solve the proximal step in the form argmin_{X} 1/2 ||AX - Y||^2 + beta/2 ||X - Z||^2,
            where X are the updated parameters to solve for.
            Expects the current input data x, the already updated non-linear activations from the above layer y and an
            argument list of module parameters *params.
            Returns (A, Y, Z), where Y are the already updated non-linear activations from the layer above in the right shape.
            """
            raise NotImplementedError

        def from_exact_solve_shape(self, exact_solve_out):
            """
            Needs to be implemented for the exact solve of the proximal step.

            Reshapes the output of the exact solve method to the native parameter shape for the parameters
            as nn.Parameter() objects stored in an OrderedDict.
            """
            raise NotImplementedError

        def _temporary_parameters(self, x):
            """
            Build the parameter dict that temporarily replaces self._parameters.
            Entries that from_cg_shape does not return (e.g. a None bias) are kept,
            so the BaseModule's forward still finds every attribute it expects.
            """
            temp_params = OrderedDict(self._parameters)
            temp_params.update(self.from_cg_shape(x))
            return temp_params

        def apply_cg(self, A, x, mode):
            """
            Abstract linear operator used for the conjugate gradient solver.
            Returns Ax for mode=1, A^Tx for mode=2 and A^T(Ax) for mode=3.

            This method uses the efficient forward implementation of the BaseModule.
            The tradeoff for this generic implementation is that we have to assign temporary values to the module's parameters.
            The original parameters are restored even if the forward/backward pass raises.
            """
            if mode == 1:
                params_backup = self._parameters
                self._parameters = self._temporary_parameters(x)
                try:
                    with torch.enable_grad():
                        return self.apply_forward(A)
                finally:
                    self._parameters = params_backup

            elif mode == 2:
                # A^T x is read off the parameter gradients of <forward(A), x>
                self.zero_grad()
                try:
                    with torch.enable_grad():
                        self.apply_adjoint(self.apply_forward(A), x)
                    return self.to_cg_shape([p.grad for p in self.parameters()])
                finally:
                    self.zero_grad()

            elif mode == 3:
                params_backup = self._parameters
                self._parameters = self._temporary_parameters(x)
                try:
                    self.zero_grad()
                    with torch.enable_grad():
                        forward_out = self.apply_forward(A.requires_grad_())
                        # the forward output acts as a constant cotangent here
                        self.apply_adjoint(forward_out, forward_out.detach())
                    result = self.to_cg_shape([p.grad for p in self.parameters()])
                    self.zero_grad()
                    return result
                finally:
                    self._parameters = params_backup

            else:
                raise ValueError('Mode {} is not valid. Provide 1 for Ax, 2 for A^Tx and 3 for A^T(Ax).'.format(mode))

    ProxPropModule.__name__ = 'ProxProp_{}'.format(BaseModule.__name__)
    proxprop_module = ProxPropModule

    # user supplied reshaping functions override the defaults
    overrides = (
        ('to_cg_shape', to_cg_shape),
        ('from_cg_shape', from_cg_shape),
        ('to_exact_solve_shape', to_exact_solve_shape),
        ('from_exact_solve_shape', from_exact_solve_shape),
    )
    for method_name, method in overrides:
        if method is not None:
            setattr(proxprop_module, method_name, method)

    return proxprop_module


# generate ProxProp Conv2d module
ProxPropConv2d = proxprop_module_generator(nn.Conv2d)


# generate ProxProp Linear module
def linear_to_cg_shape(self, params_list):
    """
    Reshape to use the matrix version of the conjugate gradient solver.

    Returns W^T with the bias appended as a last row; without a bias only
    (a copy of) W^T is returned.
    """
    W = params_list[0]
    if len(params_list) < 2 or params_list[1] is None:
        # copy so the result never aliases the weight (or its gradient)
        return W.t().clone()
    b = params_list[1]
    return torch.cat((W.t(), torch.unsqueeze(b, 0).type_as(W)), 0)

def linear_from_cg_shape(self, x):
    """
    Reshape the output of the matrix version of the conjugate gradient solver.
    """
    params_cg = OrderedDict()
    if self.bias is None:
        params_cg['weight'] = torch.nn.Parameter(x.t())
    else:
        n_in = self.weight.size(1)
        W, b = x[:n_in], x[n_in:]
        params_cg['weight'] = torch.nn.Parameter(W.t())
        # squeeze only the row dimension so a single output unit keeps shape (1,)
        params_cg['bias'] = torch.nn.Parameter(torch.squeeze(b, 0))
    return params_cg

def linear_to_exact_solve_shape(self, x, z_updated, W, b=None):
    """
    Suitable reshape for the exact solve method.

    With a bias, the input is augmented by a column of ones so that the bias
    becomes the last row of the unknown matrix.
    """
    if b is None:
        return x, z_updated, W.t()
    Z = torch.cat((W.t(), torch.unsqueeze(b, 0).type_as(W)), 0)
    A = torch.cat((x, torch.ones(x.size(0), 1).type_as(x)), 1)
    return A, z_updated, Z

def linear_from_exact_solve_shape(self, exact_solve_out):
    """
    Reshape from exact solve method.
    """
    params_out = OrderedDict()
    if self.bias is None:
        params_out['weight'] = nn.Parameter(exact_solve_out.t())
    else:
        n_in = self.weight.size(1)
        W, b = exact_solve_out[:n_in], exact_solve_out[n_in:]
        params_out['weight'] = nn.Parameter(W.t())
        # squeeze only the row dimension so a single output unit keeps shape (1,)
        params_out['bias'] = nn.Parameter(b.squeeze(0))
    return params_out

ProxPropLinear = proxprop_module_generator(nn.Linear, to_cg_shape=linear_to_cg_shape, from_cg_shape=linear_from_cg_shape, to_exact_solve_shape=linear_to_exact_solve_shape, from_exact_solve_shape=linear_from_exact_solve_shape)