├── .coveragerc ├── .gitignore ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.md ├── chop ├── __init__.py ├── adversary.py ├── constraints.py ├── optim.py ├── penalties.py ├── stochastic.py └── utils │ ├── __init__.py │ ├── data.py │ ├── image.py │ ├── logging.py │ └── utils.py ├── doc ├── Makefile ├── conf.py ├── index.rst ├── make.bat └── sphinx_ext │ └── github_link.py ├── environment.yml ├── examples ├── README.txt ├── adversarial_robustness │ ├── README.txt │ ├── attack_benchmark.py │ ├── plot_train_robust_cifar10.py │ ├── plot_universal_adversarial_examples.py │ └── plot_visualizing_adversarial_attacks.py ├── plot_bounded_cone.py ├── plot_logistic_regression_L2_penalized.py ├── plot_optim_dynamics.py ├── plot_robust_PCA.py ├── plot_stochastic_dynamics.py ├── training_L1_constrained_net_on_CIFAR10.py └── training_constrained_net_on_mnist.py ├── pyproject.toml ├── setup.py └── tests ├── __init__.py ├── test_adversary.py ├── test_constraints.py ├── test_optim.py ├── test_penalties.py ├── test_stochastic.py └── test_utils.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = chop 4 | 5 | [report] 6 | exclude_lines = 7 | pragma: no cover 8 | def __repr__ 9 | if self.debug: 10 | if settings.DEBUG 11 | raise AssertionError 12 | raise NotImplementedError 13 | if 0: 14 | if __name__ == .__main__.: 15 | if verbose: 16 | 17 | omit = 18 | **/utils/data.py 19 | **/utils/logging.py 20 | **/utils/image.py 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/vscode,python,jupyternotebooks 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=vscode,python,jupyternotebooks 4 | 5 | ### JupyterNotebooks ### 6 | # gitignore template for Jupyter Notebooks 7 | # website: http://jupyter.org/ 8 | 9 | .ipynb_checkpoints 10 | */.ipynb_checkpoints/* 11 | 12 | # IPython 13 | profile_default/ 14 | ipython_config.py 15 | 16 | # Remove previous ipynb_checkpoints 17 | # git rm -r .ipynb_checkpoints/ 18 | 19 | ### Python ### 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | pip-wheel-metadata/ 43 | share/python-wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | MANIFEST 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .nox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | *.py,cover 70 | .hypothesis/ 71 | .pytest_cache/ 72 | pytestdebug.log 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | db.sqlite3 82 | db.sqlite3-journal 83 | 84 | # Flask stuff: 85 | instance/ 86 | .webassets-cache 87 | 88 | # Scrapy stuff: 89 | .scrapy 90 | 91 | # Sphinx documentation 92 | docs/_build/ 93 | doc/_build/ 94 | 95 | # PyBuilder 96 | target/ 97 | 98 | # Jupyter Notebook 99 | 100 | # IPython 101 | 102 | # pyenv 103 | .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 110 | #Pipfile.lock 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | pythonenv* 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # profiling data 154 | .prof 155 | 156 | ### vscode ### 157 | .vscode/ 158 | !.vscode/settings.json 159 | !.vscode/tasks.json 160 | !.vscode/launch.json 161 | !.vscode/extensions.json 162 | *.code-workspace 163 | 164 | # End of https://www.toptal.com/developers/gitignore/api/vscode,python,jupyternotebooks 165 | 166 | doc/ 167 | data/ 168 | models/ 169 | logging/ 170 | runs/ 171 | **/plots/ 172 | *.lprof 173 | *.png 174 | *.zip 175 | *.csv 176 | *.json 177 | *.pt -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | python: 4 | - "3.6" 5 | - "3.7" 6 | - "3.8" 7 | - "3.9" 8 | 9 | before_install: 10 | # Here we just install Miniconda, which you shouldn't have to change. 11 | - wget http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh 12 | - chmod +x miniconda.sh 13 | - ./miniconda.sh -b -p $HOME/miniconda 14 | - source "$HOME/miniconda/etc/profile.d/conda.sh" 15 | - hash -r 16 | - conda config --set always_yes yes --set changeps1 no 17 | - conda update -q conda 18 | - conda info -a 19 | 20 | install: 21 | # Environment setup. 22 | - conda env create -f environment.yml python=$TRAVIS_PYTHON_VERSION 23 | - conda activate chop 24 | - python setup.py develop 25 | # The following is only needed for examples and coverage tests 26 | - conda install matplotlib 27 | - pip install tqdm 28 | - pip install tensorboardX cox requests 29 | - pip install advertorch copt 30 | - pip install git+https://github.com/RobustBench/robustbench 31 | - pip install coveralls coverage pytest-cov 32 | # For sphinx 33 | - pip install sphinx sphinx-gallery 34 | - pip install memory_profiler 35 | 36 | script: 37 | - py.test -v --cov=chop 38 | 39 | after_success: 40 | - coveralls 41 | 42 | cache: 43 | directories: 44 | - $HOME/chop_data/ 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | New BSD License 2 | 3 | Copyright (c) 2020-2030 Geoffrey Négiar, Fabian Pedregosa. 4 | All rights reserved. 5 | 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | a. Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | b. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | c. Neither the name of the Developers nor the names of 16 | its contributors may be used to endorse or promote products 17 | derived from this software without specific prior written 18 | permission. 19 | 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 29 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 30 | OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH 31 | DAMAGE. 32 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.rst 2 | recursive-include doc * 3 | recursive-include tests *.py 4 | recursive-include examples *.py README.txt 5 | include COPYING 6 | include README.md 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorCH OPtimize (CHOP): a library for continuous and constrained optimization built on PyTorch 2 | 3 | ...with applications to adversarially attacking and training neural networks. 4 | 5 | [![Build Status](https://travis-ci.org/openopt/chop.svg?branch=master)](https://travis-ci.org/openopt/chop) 6 | [![Coverage Status](https://coveralls.io/repos/github/openopt/chop/badge.svg?branch=master)](https://coveralls.io/github/openopt/chop?branch=master) 7 | [![DOI](https://zenodo.org/badge/310693245.svg)](https://zenodo.org/badge/latestdoi/310693245) 8 | 9 | :warning: This library is not actively maintained anymore, and I won't be handling new issues in a timely manner. Contact me if you'd like to contribute. :warning: 10 | 11 | ## Stochastic Algorithms 12 | 13 | We define stochastic optimizers in the `chop.stochastic` module. These follow PyTorch Optimizer conventions, similar to the `torch.optim` module. 14 | These can be used to 15 | - train structured models; 16 | - compute universal adversarial perturbations over a dataset. 17 | 18 | ## Full Gradient Algorithms 19 | 20 | We also define full-gradient algorithms which operate on a batch of optimization problems in the `chop.optim` module. These are used for adversarial attacks, using the `chop.Adversary` wrapper. 21 | 22 | ## Installing 23 | 24 | Run the following: 25 | 26 | ``` 27 | pip install chop-pytorch 28 | ``` 29 | or 30 | ``` 31 | pip install git+https://github.com/openopt/chop.git 32 | ``` 33 | for the latest development version. 34 | 35 | Welcome to `chop`! 36 | 37 | ## Examples: 38 | 39 | See `examples` directory and our [webpage](http://openo.pt/chop/auto_examples/index.html). 40 | 41 | ## Tests 42 | 43 | Run the tests with `pytests tests`. 44 | 45 | ## Citing 46 | 47 | If this software is useful to your research, please consider citing it as 48 | 49 | ``` 50 | @article{chop, 51 | author = {Geoffrey Negiar, Fabian Pedregosa}, 52 | title = {CHOP: continuous optimization built on Pytorch}, 53 | year = 2020, 54 | url = {https://github.com/openopt/chop} 55 | } 56 | ``` 57 | 58 | ## Affiliations 59 | 60 | Geoffrey Négiar was in the Mahoney lab and the El Ghaoui lab at UC Berkeley at the time this package was developped. 61 | 62 | Fabian Pedregosa is at Google Research. 63 | -------------------------------------------------------------------------------- /chop/__init__.py: -------------------------------------------------------------------------------- 1 | """chop: constrained optimization for PyTorch""" 2 | __version__ = "0.0.3" 3 | 4 | from . import utils 5 | from . import stochastic 6 | from . import optim 7 | from . import constraints 8 | from . import penalties 9 | from . import adversary 10 | from .adversary import Adversary 11 | -------------------------------------------------------------------------------- /chop/adversary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adversary utility classes. 3 | ============================ 4 | Contains classes for generating adversarial examples and evaluating models 5 | on adversarial examples. 6 | """ 7 | import torch 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | 12 | from chop import utils 13 | 14 | 15 | class Adversary: 16 | """ 17 | Class for generating adversarial examples given a model and data. 18 | """ 19 | def __init__(self, method): 20 | """ 21 | Args: 22 | method: callable 23 | Optimization method to be used by the adversary. 24 | """ 25 | self.method = method 26 | 27 | def perturb(self, data, target, model, criterion, 28 | max_iter=20, 29 | use_best=False, 30 | initializer=None, 31 | callback=None, 32 | *optimizer_args, 33 | **optimizer_kwargs): 34 | """Perturbs the batch of datapoints with true label target, 35 | using specified optimization method. 36 | 37 | Args: 38 | data: torch.Tensor shape: (batch_size, *) 39 | batch of datapoints 40 | 41 | target: torch.Tensor shape: (batch_size,) 42 | 43 | model: torch.nn.Module 44 | model to attack 45 | 46 | max_iter: int 47 | Maximum number of iterations for the optimization method. 48 | 49 | use_best: bool 50 | if True, Return best perturbation so far. 51 | Otherwise, return the last perturbation obtained. 52 | 53 | initializer: callable (optional) 54 | callable which returns a starting point. 55 | Typically a random generator on the constraint set. 56 | Takes shape as only argument. 57 | 58 | callback: callable (optional) 59 | called at each iteration of the optimization method. 60 | 61 | *optimizer_args: tuple 62 | extra arguments for the optimization method 63 | 64 | *optimizer_kwargs: dict 65 | extra keyword arguments for the optimization method 66 | 67 | Returns: 68 | adversarial_loss: torch.Tensor of shape (batch_size,) 69 | vector of losses obtained on the batch 70 | 71 | delta: torch.Tensor of shape (batch_size, *) 72 | perturbation found""" 73 | 74 | device = data.device 75 | batch_size = data.size(0) 76 | 77 | @utils.closure 78 | def loss(delta): 79 | return -criterion(model(data + delta), target) 80 | 81 | if initializer is None: 82 | delta0 = torch.zeros_like(data, device=device) 83 | 84 | else: 85 | delta0 = initializer(data.shape) 86 | 87 | class UseBest: 88 | def __init__(self): 89 | self.best = torch.zeros_like(data, device=device) 90 | self.best_loss = -np.inf * torch.ones(batch_size, device=device) 91 | 92 | def __call__(self, kwargs): 93 | mask = (-kwargs['fval'] > self.best_loss) 94 | self.best_loss[mask] = -kwargs['fval'][mask] 95 | self.best[mask] = kwargs['x'][mask].detach().clone() 96 | 97 | if callback is not None: 98 | return callback(kwargs) 99 | 100 | 101 | cb = UseBest() if use_best else callback 102 | 103 | sol = self.method(loss, delta0, max_iter=max_iter, 104 | *optimizer_args, callback=cb, 105 | **optimizer_kwargs) 106 | 107 | if use_best: 108 | return cb.best_loss, cb.best 109 | 110 | return -sol.fval, sol.x 111 | 112 | def attack_dataset(self, loader, model, criterion, 113 | step=None, max_iter=20, 114 | use_best=False, 115 | initializer=None, 116 | callback=None, 117 | verbose=1, 118 | device=None, 119 | *optimizer_args, 120 | **optimizer_kwargs): 121 | 122 | """Returns a generator of losses, perturbations over 123 | loader.""" 124 | 125 | iterator = enumerate(loader) 126 | if verbose == 1: 127 | iterator = tqdm(iterator, total=len(iterator)) 128 | 129 | for k, (data, target) in iterator: 130 | data.to(device) 131 | target.to(device) 132 | 133 | raise NotImplementedError("The optimization method needs to take " 134 | "arguments which may differ per " 135 | "datapoint.") 136 | adv_loss, delta = self.perturb(data, target, model, criterion, step, 137 | max_iter, use_best, initializer, callback, 138 | *optimizer_args, **optimizer_kwargs) 139 | 140 | def run_evaluation(self, loader, model, criterion): 141 | raise NotImplementedError() 142 | # for adv_loss, delta in self.attack_dataset(loader, model, criterion,): -------------------------------------------------------------------------------- /chop/constraints.py: -------------------------------------------------------------------------------- 1 | """ 2 | Constraints. 3 | =========== 4 | This module contains classes representing constraints. 5 | The methods on each constraint object function batch-wise. 6 | Reshaping will be of order if the constraints are used on the parameters of a model. 7 | This uses an API similar to the one for 8 | the COPT project, https://github.com/openopt/copt. 9 | Part of this code is adapted from https://github.com/ZIB-IOL.""" 10 | 11 | from copy import deepcopy 12 | from collections import defaultdict 13 | import warnings 14 | import dataclasses 15 | import torch 16 | 17 | import numpy as np 18 | from scipy.stats import expon 19 | from torch.distributions import Laplace, Normal 20 | from chop import utils 21 | 22 | 23 | @torch.no_grad() 24 | def get_avg_init_norm(layer, param_type=None, ord=2, repetitions=100): 25 | """Computes the average norm of default layer initialization""" 26 | output = 0 27 | for _ in range(repetitions): 28 | layer.reset_parameters() 29 | warnings.warn("torch.norm is deprecated. Think about updating this.") 30 | output += torch.norm(getattr(layer, param_type), p=ord).item() 31 | return float(output) / repetitions 32 | 33 | 34 | def is_bias(name, param): 35 | return ('bias' in name) or (param.ndim < 2) 36 | 37 | 38 | @torch.no_grad() 39 | def make_model_constraints(model, ord=2, value=300, mode='initialization', constrain_bias=False): 40 | """Create Ball constraints for each layer of model. Ball radius depends on mode (either radius or 41 | factor to multiply average initialization norm with)""" 42 | constraints = [] 43 | 44 | # Compute average init norms if necessary 45 | init_norms = dict() 46 | 47 | if (ord == 'nuc') and constrain_bias: 48 | msg = "'nuc' constraints cannot constrain bias." 49 | warnings.warn(msg) 50 | constrain_bias = False 51 | 52 | if mode == 'initialization': 53 | for layer in model.modules(): 54 | if hasattr(layer, 'reset_parameters'): 55 | for param_type in [entry for entry in ['weight', 'bias'] if (hasattr(layer, entry) and 56 | type(getattr(layer, entry)) != type( 57 | None))]: 58 | param = getattr(layer, param_type) 59 | shape = param.shape 60 | # TODO: figure out how to set the constraint size for NuclearNormBall constraint 61 | avg_norm = get_avg_init_norm(layer, param_type=param_type, ord=2) 62 | if avg_norm == 0.0: 63 | # Catch unlikely case that weight/bias is 0-initialized (e.g. BatchNorm does this) 64 | avg_norm = 1.0 65 | init_norms[shape] = avg_norm 66 | 67 | for name, param in model.named_parameters(): 68 | if is_bias(name, param): 69 | constraint = None 70 | else: 71 | print(name) 72 | if mode == 'radius': 73 | alpha = value 74 | elif mode == 'initialization': 75 | alpha = value * init_norms[param.shape] 76 | else: 77 | msg = f"Unknown mode {mode}." 78 | raise ValueError(msg) 79 | if (type(ord) == int) or (ord == np.inf): 80 | constraint = make_LpBall(alpha, p=ord) 81 | elif ord == 'nuc': 82 | constraint = NuclearNormBall(alpha) 83 | else: 84 | msg = f"ord {ord} is not supported." 85 | raise ValueError(msg) 86 | constraints.append(constraint) 87 | return constraints 88 | 89 | 90 | @torch.no_grad() 91 | def make_feasible(model, proxes): 92 | """ 93 | Projects all parameters of model onto the associated constraint set, 94 | using its prox operator (really a projection here). 95 | This function operates in-place. 96 | 97 | Args: 98 | model: torch.nn.Module 99 | Model to make feasible 100 | prox: [callable] 101 | List of projection operators 102 | """ 103 | for param, prox in zip(model.parameters(), proxes): 104 | if prox is not None: 105 | param.copy_(prox(param.unsqueeze(0)).squeeze(0)) 106 | 107 | @torch.no_grad() 108 | def euclidean_proj_simplex(v, s=1.): 109 | r""" Compute the Euclidean projection on a positive simplex 110 | Solves the optimization problem (using the algorithm from [1]): 111 | ..math:: 112 | min_w 0.5 * || w - v ||_2^2 , s.t. \sum_i w_i = s, w_i >= 0 113 | Parameters 114 | ---------- 115 | v: (n,) numpy array, 116 | n-dimensional vector to project 117 | s: float, optional, default: 1, 118 | radius of the simplex 119 | Returns 120 | ------- 121 | w: (n,) numpy array, 122 | Euclidean projection of v on the simplex 123 | Notes 124 | ----- 125 | The complexity of this algorithm is in O(n log(n)) as it involves sorting v. 126 | Better alternatives exist for high-dimensional sparse vectors (cf. [1]) 127 | However, this implementation still easily scales to millions of dimensions. 128 | References 129 | ---------- 130 | [1] Efficient Projections onto the .1-Ball for Learning in High Dimensions 131 | John Duchi, Shai Shalev-Shwartz, Yoram Singer, and Tushar Chandra. 132 | International Conference on Machine Learning (ICML 2008) 133 | http://www.cs.berkeley.edu/~jduchi/projects/DuchiSiShCh08.pdf 134 | """ 135 | assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s 136 | (n,) = v.shape 137 | # check if we are already on the simplex 138 | if v.sum() == s and (v >= 0).all(): 139 | return v 140 | # get the array of cumulative sums of a sorted (decreasing) copy of v 141 | u, _ = torch.sort(v, descending=True) 142 | cssv = torch.cumsum(u, dim=-1) 143 | # get the number of > 0 components of the optimal solution 144 | rho = (u * torch.arange(1, n + 1, device=v.device) > (cssv - s)).sum() - 1 145 | # compute the Lagrange multiplier associated to the simplex constraint 146 | theta = (cssv[rho] - s) / (rho + 1.0) 147 | # compute the projection by thresholding v using theta 148 | w = torch.clamp(v - theta, min=0) 149 | return w 150 | 151 | 152 | @torch.no_grad() 153 | def euclidean_proj_l1ball(v, s=1.): 154 | """ Compute the Euclidean projection on a L1-ball 155 | Solves the optimization problem (using the algorithm from [1]): 156 | ..math:: 157 | min_w 0.5 * || w - v ||_2^2 , s.t. || w ||_1 <= s 158 | 159 | Args: 160 | 161 | v: (n,) numpy array, 162 | n-dimensional vector to project 163 | s: float, optional, default: 1, 164 | radius of the L1-ball 165 | 166 | Returns: 167 | w: (n,) numpy array, 168 | Euclidean projection of v on the L1-ball of radius s 169 | Notes 170 | ----- 171 | Solves the problem by a reduction to the positive simplex case 172 | See also 173 | -------- 174 | euclidean_proj_simplex 175 | """ 176 | assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s 177 | if len(v.shape) > 1: 178 | raise ValueError 179 | # compute the vector of absolute values 180 | u = abs(v) 181 | # check if v is already a solution 182 | if u.sum() <= s: 183 | # L1-norm is <= s 184 | return v 185 | # v is not already a solution: optimum lies on the boundary (norm == s) 186 | # project *u* on the simplex 187 | w = euclidean_proj_simplex(u, s=s) 188 | # compute the solution to the original problem on v 189 | w *= torch.sign(v) 190 | return w 191 | 192 | 193 | class LpBall: 194 | def __init__(self, alpha): 195 | if not 0. <= alpha: 196 | raise ValueError("Invalid constraint size alpha: {}".format(alpha)) 197 | self.alpha = alpha 198 | self.active_set = defaultdict(float) 199 | 200 | @torch.no_grad() 201 | def fw_gap(self, grad, iterate): 202 | update_direction, _ = self.lmo(-grad, iterate) 203 | return utils.bdot(-grad, update_direction) 204 | 205 | @torch.no_grad() 206 | def random_point(self, shape): 207 | """ 208 | Sample uniformly from the constraint set. 209 | L1 and L2 are implemented here. 210 | Linf implemented in the subclass. 211 | https://arxiv.org/abs/math/0503650 212 | """ 213 | if self.p == 2: 214 | distrib = Normal(0, 1) 215 | elif self.p == 1: 216 | distrib = Laplace(0, 1) 217 | x = distrib.sample(shape) 218 | e = expon(.5).rvs() 219 | denom = torch.sqrt(e + (x ** 2).sum()) 220 | return self.alpha * x / denom 221 | 222 | def __mul__(self, other): 223 | """Scales the constraint by a scalar""" 224 | ret = deepcopy(self) 225 | ret.alpha *= other 226 | return ret 227 | 228 | def __rmul__(self, other): 229 | return self.__mul__(other) 230 | 231 | def __imul__(self, other): 232 | self.alpha *= other 233 | return self 234 | 235 | def __truediv__(self, other): 236 | ret = deepcopy(self) 237 | ret.alpha /= other 238 | return ret 239 | 240 | @torch.no_grad() 241 | def make_feasible(self, model): 242 | """Projects all parameters of model into the constraint set.""" 243 | 244 | for idx, (name, param) in enumerate(model.named_parameters()): 245 | param.copy_(self.prox(param)) 246 | 247 | def is_feasible(self, x, rtol=1e-5, atol=1e-7): 248 | """Checks if x is a feasible point of the constraint.""" 249 | p_norms = (x ** self.p).reshape(x.size(0), -1).sum(-1) 250 | return p_norms.pow(1. / self.p) <= self.alpha * (1. + rtol) + atol 251 | 252 | 253 | class LinfBall(LpBall): 254 | p = np.inf 255 | 256 | @torch.no_grad() 257 | def prox(self, x, step_size=None): 258 | """Projection onto the L-infinity ball. 259 | 260 | Args: 261 | x: torch.Tensor of shape (batchs_size, *) 262 | tensor to project 263 | step_size: Any 264 | Not used here 265 | 266 | Returns: 267 | p: torch.Tensor, same shape as x 268 | projection of x onto the L-infinity ball. 269 | """ 270 | if torch.max(abs(x)) <= self.alpha: 271 | return x 272 | return torch.clamp(x, min=-self.alpha, max=self.alpha) 273 | 274 | @torch.no_grad() 275 | def lmo(self, grad, iterate): 276 | """Linear Maximization Oracle. 277 | Return s - iterate with s solving the linear problem 278 | 279 | ..math:: 280 | max_{||s||_\infty <= alpha} 281 | 282 | Args: 283 | grad: torch.Tensor of shape (batch_size, *) 284 | usually -gradient 285 | iterate: torch.Tensor of shape (batch_size, *) 286 | usually the iterate of the considered algorithm 287 | 288 | Returns: 289 | update_direction: torch.Tensor, same shape as grad and iterate, 290 | s - iterate, where s is the vertex of the constraint most correlated 291 | with u 292 | max_step_size: torch.Tensor of shape (batch_size,) 293 | 1. for a Frank-Wolfe step. 294 | """ 295 | update_direction = -iterate.clone().detach() 296 | update_direction += self.alpha * torch.sign(grad) 297 | return update_direction, torch.ones(iterate.size(0), device=iterate.device, dtype=iterate.dtype) 298 | 299 | @torch.no_grad() 300 | def random_point(self, shape): 301 | """Returns a point of given shape uniformly at random from the constraint set.""" 302 | z = torch.zeros(*shape) 303 | z.uniform_(-self.alpha, self.alpha) 304 | return z 305 | 306 | @torch.no_grad() 307 | def lmo_pairwise(self, grad, iterate, active_set): 308 | fw_direction = self.lmo(grad, iterate) + iterate.clone().detach() 309 | 310 | away_direction = min(self.active_set.keys(), 311 | key=lambda v: torch.tensor(v).dot(grad)) 312 | max_step = self.active_set[away_direction] 313 | away_direction = torch.tensor(away_direction) 314 | return fw_direction - away_direction, max_step 315 | 316 | def is_feasible(self, x, rtol=1e-5, atol=1e-7): 317 | return abs(x).reshape(x.size(0), -1).max(dim=-1)[0] <= self.alpha * (1. + rtol) + atol 318 | 319 | 320 | class L1Ball(LpBall): 321 | p = 1 322 | 323 | @torch.no_grad() 324 | def lmo(self, grad, iterate): 325 | """Linear Maximization Oracle. 326 | Return s - iterate with s solving the linear problem 327 | 328 | ..math:: 329 | max_{||s||_1 <= alpha} 330 | 331 | Args: 332 | grad: torch.Tensor of shape (batch_size, *) 333 | usually -gradient 334 | iterate: torch.Tensor of shape (batch_size, *) 335 | usually the iterate of the considered algorithm 336 | 337 | Returns: 338 | update_direction: torch.Tensor, same shape as grad and iterate, 339 | s - iterate, where s is the vertex of the constraint most correlated 340 | with u 341 | max_step_size: torch.Tensor of shape (batch_size,) 342 | 1. for a Frank-Wolfe step. 343 | """ 344 | update_direction = -iterate.clone().detach() 345 | abs_grad = abs(grad) 346 | batch_size = iterate.size(0) 347 | flatten_abs_grad = abs_grad.view(batch_size, -1) 348 | flatten_largest_mask = (flatten_abs_grad == flatten_abs_grad.max(-1, True)[0]) 349 | largest = torch.where(flatten_largest_mask.view_as(abs_grad)) 350 | 351 | update_direction[largest] += self.alpha * torch.sign( 352 | grad[largest]) 353 | 354 | return update_direction, torch.ones(iterate.size(0), device=iterate.device, dtype=iterate.dtype) 355 | 356 | @torch.no_grad() 357 | def prox(self, x, step_size=None): 358 | """Projection onto the L1 ball. 359 | 360 | Args: 361 | x: torch.Tensor of shape (batchs_size, *) 362 | tensor to project 363 | step_size: Any 364 | Not used here 365 | 366 | Returns: 367 | p: torch.Tensor, same shape as x 368 | projection of x onto the L1 ball. 369 | """ 370 | shape = x.shape 371 | flattened_x = x.view(shape[0], -1) 372 | # TODO vectorize this 373 | projected = [euclidean_proj_l1ball(row, s=self.alpha) for row in flattened_x] 374 | x = torch.stack(projected) 375 | return x.view(*shape) 376 | 377 | 378 | class L2Ball(LpBall): 379 | p = 2 380 | 381 | @torch.no_grad() 382 | def prox(self, x, step_size=None): 383 | """Projection onto the L2 ball. 384 | 385 | Args: 386 | x: torch.Tensor of shape (batchs_size, *) 387 | tensor to project 388 | step_size: Any 389 | Not used here 390 | 391 | Returns: 392 | p: torch.Tensor, same shape as x 393 | projection of x onto the L2 ball. 394 | """ 395 | norms = utils.bnorm(x) 396 | mask = norms > self.alpha 397 | projected = x.clone().detach() 398 | projected[mask] = self.alpha * utils.bdiv(projected[mask], norms[mask]) 399 | return projected 400 | 401 | 402 | @torch.no_grad() 403 | def lmo(self, grad, iterate): 404 | """Linear Maximization Oracle. 405 | Return s - iterate with s solving the linear problem 406 | 407 | ..math:: 408 | max_{||s||_2 <= alpha} 409 | 410 | Args: 411 | grad: torch.Tensor of shape (batch_size, *) 412 | usually -gradient 413 | iterate: torch.Tensor of shape (batch_size, *) 414 | usually the iterate of the considered algorithm 415 | 416 | Returns: 417 | update_direction: torch.Tensor, same shape as grad and iterate, 418 | s - iterate, where s is the vertex of the constraint most correlated 419 | with u 420 | max_step_size: torch.Tensor of shape (batch_size,) 421 | 1. for a Frank-Wolfe step. 422 | """ 423 | update_direction = -iterate.clone().detach() 424 | grad_norms = torch.norm(grad.view(grad.size(0), -1), p=2, dim=-1) 425 | update_direction += self.alpha * (grad.view(grad.size(0), -1).T 426 | / grad_norms).T.view_as(iterate) 427 | return update_direction, torch.ones(iterate.size(0), device=iterate.device, dtype=iterate.dtype) 428 | 429 | 430 | def make_LpBall(alpha, p=1): 431 | if p == 1: 432 | return L1Ball(alpha) 433 | elif p == 2: 434 | return L2Ball(alpha) 435 | 436 | elif p == np.inf: 437 | return LinfBall(alpha) 438 | 439 | raise NotImplementedError("We have only implemented ord={1, 2, np.inf} for now.") 440 | 441 | 442 | class Simplex: 443 | 444 | def __init__(self, alpha): 445 | if alpha >= 0: 446 | self.alpha = alpha 447 | else: 448 | raise ValueError("alpha must be a non negative number.") 449 | 450 | @torch.no_grad() 451 | def prox(self, x, step_size=None): 452 | shape = x.shape 453 | flattened_x = x.view(shape[0], -1) 454 | projected = [euclidean_proj_simplex(row, s=self.alpha) for row in flattened_x] 455 | x = torch.stack(projected) 456 | return x.view(*shape) 457 | 458 | @torch.no_grad() 459 | def lmo(self, grad, iterate): 460 | batch_size = grad.size(0) 461 | shape = iterate.shape 462 | max_vals, max_idx = grad.reshape(batch_size, -1).max(-1) 463 | 464 | update_direction = -iterate.clone().detach().reshape(batch_size, -1) 465 | update_direction[range(batch_size), max_idx] += self.alpha 466 | update_direction = update_direction.reshape(*shape) 467 | 468 | return update_direction, torch.ones(iterate.size(0), device=iterate.device, dtype=iterate.dtype) 469 | 470 | def is_feasible(self, x, rtol=1e-5, atol=1e-7): 471 | batch_size = x.size(0) 472 | reshaped_x = x.reshape(batch_size, -1) 473 | return torch.logical_and(reshaped_x.min(dim=-1)[0] + atol >= 0, 474 | reshaped_x.sum(-1) <= self.alpha * (1. + rtol) + atol) 475 | 476 | 477 | class NuclearNormBall: 478 | """ 479 | Nuclear norm constraint, i.e. sum of absolute eigenvalues. 480 | Also known as the Schatten-1 norm. 481 | We consider the last two dimensions of the input are the ones we compute the Nuclear Norm on. 482 | """ 483 | def __init__(self, alpha): 484 | if not 0. <= alpha: 485 | raise ValueError("Invalid constraint size alpha: {}".format(alpha)) 486 | self.alpha = alpha 487 | 488 | @torch.no_grad() 489 | def lmo(self, grad, iterate): 490 | """ 491 | Computes the LMO for the Nuclear Norm Ball on the last two dimensions. 492 | Returns :math: `s - $iterate$` where 493 | 494 | ..math:: 495 | s = \argmax_u u^\top grad. 496 | 497 | Args: 498 | grad: torch.Tensor of shape (*, m, n) 499 | iterate: torch.Tensor of shape (*, m, n) 500 | Returns: 501 | update_direction: torch.Tensor of shape (*, m, n) 502 | """ 503 | update_direction = -iterate.clone().detach() 504 | u, _, v = utils.power_iteration(grad) 505 | outer = u.unsqueeze(-1) * v.unsqueeze(-2) 506 | update_direction += self.alpha * outer 507 | return update_direction, torch.ones(iterate.size(0), device=iterate.device, dtype=iterate.dtype) 508 | 509 | @torch.no_grad() 510 | def prox(self, x, step_size=None): 511 | """ 512 | Projection operator on the Nuclear Norm constraint set. 513 | """ 514 | U, S, V = torch.svd(x) 515 | # Project S on the alpha-L1 ball 516 | ball = L1Ball(self.alpha) 517 | 518 | S_proj = ball.prox(S.view(-1, S.size(-1))).view_as(S) 519 | 520 | VT = V.transpose(-2, -1) 521 | return torch.matmul(U, torch.matmul(torch.diag_embed(S_proj), VT)) 522 | 523 | def is_feasible(self, x, atol=1e-5, rtol=1e-5): 524 | norms = torch.linalg.norm(x, dim=(-2, -1), ord='nuc') 525 | return (norms <= self.alpha * (1. + rtol) + atol) 526 | 527 | 528 | class GroupL1Ball: 529 | 530 | # TODO: init is shared with the penalty GroupL1 object. Factorize the code. 531 | def __init__(self, alpha, groups): 532 | if alpha >= 0: 533 | self.alpha = alpha 534 | else: 535 | raise ValueError("alpha must be nonnegative.") 536 | # TODO: implement ValueErrors 537 | # groups must be indices and non overlapping 538 | if not isinstance(groups[0], torch.Tensor): 539 | groups = [torch.tensor(group) for group in groups] 540 | while groups[0].dim() < 2: 541 | groups = [group.unsqueeze(-1) for group in groups] 542 | 543 | self.groups = [] 544 | for g in groups: 545 | self.groups.append((...,) + tuple(g.T)) 546 | 547 | def get_group_norms(self, x): 548 | """Compute the vector of L2 norms within groups""" 549 | group_norms = [] 550 | for g in self.groups: 551 | subtensor = x[g] 552 | 553 | group_norms.append(torch.linalg.norm(subtensor, dim=-1)) 554 | 555 | group_norms = torch.stack(group_norms, dim=-1) 556 | return group_norms 557 | 558 | @torch.no_grad() 559 | def lmo(self, grad, iterate): 560 | update_direction = -iterate.detach().clone() 561 | # find group with largest L2 norm 562 | group_norms = self.get_group_norms(grad) 563 | max_groups = torch.argmax(group_norms, dim=-1) 564 | 565 | for k, max_group in enumerate(max_groups): 566 | idx = (k, *self.groups[max_group]) 567 | update_direction[idx] += (self.alpha * grad[idx] 568 | / group_norms[k, max_group]) 569 | 570 | return update_direction, torch.ones(iterate.size(0), device=iterate.device, dtype=iterate.dtype) 571 | 572 | 573 | @torch.no_grad() 574 | def prox(self, x, step_size=None): 575 | """Proximal operator for the GroupL1 constraint""" 576 | 577 | group_norms = self.get_group_norms(x) 578 | l1ball = L1Ball(self.alpha) 579 | normalized_group_norms = l1ball.prox(group_norms) 580 | 581 | output = x.detach().clone() 582 | 583 | # renormalize each group 584 | for k, g in enumerate(self.groups): 585 | renorm = normalized_group_norms[:, k] / group_norms[:, k] 586 | renorm[torch.isnan(renorm)] = 1. 587 | output[g] = utils.bmul(output[g], renorm) 588 | 589 | return output 590 | 591 | def is_feasible(self, x, rtol=1e-5, atol=1e-7): 592 | group_norms = self.get_group_norms(x) 593 | return torch.linalg.norm(group_norms, ord=1, dim=-1) <= (self.alpha * (1. + rtol) 594 | + atol) 595 | 596 | 597 | class Box: 598 | """ 599 | Box constraint. 600 | Args: 601 | a: float or None 602 | min of the box constraint 603 | b: float or None 604 | max of the box constraint 605 | """ 606 | def __init__(self, a=None, b=None): 607 | """ 608 | """ 609 | 610 | if a is None and b is None: 611 | raise ValueError("One of a, b should not be None.") 612 | if a is None: 613 | a = -np.inf 614 | elif b is None: 615 | b = np.inf 616 | else: 617 | if b < a: 618 | raise ValueError(f"This constraint supposes that a <= b. Got {a}, {b}.") 619 | self.a = a 620 | self.b = b 621 | 622 | def prox(self, x, step_size=None): 623 | """Projection operator on the constraint. 624 | Args: 625 | x: torch.Tensor 626 | step_size: Any 627 | Returns: 628 | x_thresh: torch.Tensor 629 | x clamped between a and b. 630 | """ 631 | return torch.clamp(x, min=self.a, max=self.b) 632 | 633 | def is_feasible(self, x, rtol=1e-5, atol=1e-7): 634 | reshaped_x = x.reshape(x.size(0), -1) 635 | return torch.logical_and(reshaped_x.min(-1)[0] >= self.a * (1. + rtol) - atol, 636 | reshaped_x.max(-1)[0] <= self.b * (1. + rtol) + atol) 637 | 638 | 639 | class Cone: 640 | """ 641 | Represents second order cones of revolution centered in vector `u` (batch-wise), and angle :math: `\hat alpha`. 642 | This constraint therefore really represents a batch of cones, which share the same half-angle. 643 | The are all pointed in 0 (the origin). 644 | Formally, the set is the following: 645 | ..math:: 646 | \{x \in R^d,~ \|(uu^\top - Id)x\| \leq \alpha u^\top x \} 647 | Note that :math: `\cos(\hat \alpha) = 1 / (1 + \alpha^2)`. 648 | The standard second order cone (ice-cream cone) is given by 649 | `u = (0, ..., 0, 1)`, `cos_alpha=.5`. 650 | Args: 651 | u: torch.Tensor 652 | batch-wise directions centering the cones 653 | cos_angle: float 654 | cosine of the half-angle of the cone. 655 | """ 656 | def __init__(self, u, cos_angle=.05): 657 | batch_size = u.size(0) 658 | # normalize the cone directions 659 | self.directions = utils.bmul(u, 1. / torch.norm(u.reshape(batch_size, -1), dim=-1)) 660 | self.cos_angle = cos_angle 661 | self.alpha = np.sqrt(1. / cos_angle - 1) 662 | 663 | def proj_u(self, x, step_size=None): 664 | """ 665 | Projects x on self.directions batch-wise 666 | Args: 667 | x: torch.Tensor of shape (batch_size, *) 668 | vectors to project 669 | step_size: Any 670 | Not used 671 | Returns: 672 | proj_x: torch.Tensor of shape (batch_size, *) 673 | batch-wise projection of x onto self.directions 674 | """ 675 | 676 | return utils.bmul(utils.bdot(x, self.directions), self.directions) 677 | 678 | 679 | @torch.no_grad() 680 | def prox(self, x, step_size=None): 681 | """ 682 | Projects `x` batch-wise onto the cone constraint. 683 | Args: 684 | x: torch.Tensor of shape (batch_size, *) 685 | batch of vectors to project 686 | step_size: Any 687 | Not used 688 | Returns: 689 | proj_x: torch.Tensor of shape (batch_size, *) 690 | batch-wise projection of `x` onto the cone constraint. 691 | """ 692 | batch_size = x.size(0) 693 | uTx = utils.bdot(self.directions, x) 694 | p_u = self.proj_u(x) 695 | p_orth_u = x - p_u 696 | norm_p_orth_u = torch.norm(p_orth_u.reshape(batch_size, -1), dim=-1) 697 | identity_idx = (norm_p_orth_u <= self.alpha * uTx) 698 | zero_idx = (self.alpha * norm_p_orth_u <= - uTx) 699 | project_idx = ~torch.logical_or(identity_idx, zero_idx) 700 | 701 | res = x.detach().clone() 702 | res[zero_idx] = 0. 703 | res[project_idx] = utils.bmul((self.alpha * norm_p_orth_u[project_idx] + uTx[project_idx]) / (1. + self.alpha ** 2), 704 | (self.alpha * utils.bmul(p_orth_u[project_idx], 1 / norm_p_orth_u[project_idx]) 705 | + self.directions[project_idx])) 706 | return res 707 | 708 | def is_feasible(self, x, rtol=1e-5, atol=1e-7): 709 | cosines = utils.bdot(x, self.directions) 710 | return abs(cosines) >= utils.bnorm(x) * self.cos_angle * (1. + rtol) + atol 711 | 712 | 713 | @dataclasses.dataclass 714 | class Polytope: 715 | """Constraint defined as the convex hull of a set of vertices. 716 | 717 | Attributes: 718 | vertices: the vertices of the polytope. Shape (batch_size, *) 719 | """ 720 | 721 | vertices: torch.Tensor 722 | 723 | @torch.no_grad() 724 | def lmo(self, grad, iterate): 725 | """Linear Oracle. 726 | 727 | Returns s - iterate, such that 728 | s = argmax_{s\\in vertices} \\langle s, grad \\rangle""" 729 | 730 | batch_size = grad.size(0) 731 | 732 | update_direction = -iterate.detach().clone() 733 | 734 | similarities = utils.bmv(self.vertices, grad) 735 | top_vertex_index = torch.argmax(similarities, dim=-1) 736 | update_direction += self.vertices[range(batch_size), top_vertex_index] 737 | return update_direction, torch.ones(iterate.size(0), device=iterate.device, dtype=iterate.dtype) 738 | 739 | @torch.no_grad() 740 | def lmo_pairwise(self, grad, iterate, active_set): 741 | """Outputs the parwise update direction. 742 | 743 | Returns u-v where 744 | u = argmax_{u\in vertices} \\langle u, grad \\rangle 745 | v = argmax_{v\in active set} \\langle v, -grad \\rangle 746 | 747 | Args: 748 | iterate: current iterate, not used here 749 | grad: the direction to move towards 750 | active_set: previously chosen vertices, and associated weights. 751 | """ 752 | vertices = self.vertices.squeeze() 753 | similarities = vertices @ grad.squeeze() 754 | # FW direction 755 | fw_idx = torch.argmax(similarities).item() 756 | update_direction = vertices[fw_idx].detach().clone() 757 | 758 | # Away direction 759 | active_set_idx = torch.tensor(list(active_set.keys())) 760 | mask = torch.zeros_like(similarities, dtype=torch.bool) 761 | mask[active_set_idx] = 1. 762 | masked_similarities = mask * similarities 763 | masked_similarities[mask == 0] = float('inf') 764 | away_idx = torch.argmin(masked_similarities).item() 765 | update_direction -= vertices[away_idx] 766 | max_step_size = active_set[away_idx] 767 | 768 | return update_direction.unsqueeze(0), fw_idx, away_idx, max_step_size 769 | -------------------------------------------------------------------------------- /chop/optim.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full-gradient optimizers. 3 | ========================= 4 | 5 | This module contains full gradient optimizers in PyTorch. 6 | These optimizers expect to be called on variables of shape (batch_size, *), 7 | and will perform the optimization point-wise over the batch. 8 | 9 | This API is inspired by the COPT project 10 | https://github.com/openopt/copt. 11 | """ 12 | 13 | from numbers import Number 14 | import warnings 15 | from collections import defaultdict 16 | 17 | import torch 18 | 19 | import numpy as np 20 | 21 | from scipy import optimize 22 | from chop import utils 23 | from chop import constraints 24 | 25 | 26 | def minimize_three_split( 27 | closure, 28 | x0, 29 | prox1=None, 30 | prox2=None, 31 | tol=1e-6, 32 | max_iter=1000, 33 | verbose=0, 34 | callback=None, 35 | line_search=True, 36 | step=None, 37 | max_iter_backtracking=100, 38 | backtracking_factor=0.7, 39 | h_Lipschitz=None, 40 | *args_prox 41 | ): 42 | 43 | """Davis-Yin three operator splitting method. 44 | This algorithm can solve problems of the form 45 | 46 | minimize_x f(x) + g(x) + h(x) 47 | 48 | where f is a smooth function and g and h are (possibly non-smooth) 49 | functions for which the proximal operator is known. 50 | 51 | Remark: this method returns x = prox1(...). If g and h are two indicator 52 | functions, this method only garantees that x is feasible for the first. 53 | Therefore if one of the constraints is a hard constraint, 54 | make sure to pass it to prox1. 55 | 56 | Args: 57 | closure: callable 58 | Returns the function values and gradient of the objective function. 59 | With return_gradient=False, returns only the function values. 60 | Shape of return value: (batch_size, *) 61 | 62 | x0 : torch.Tensor(shape: (batch_size, *)) 63 | Initial guess 64 | 65 | prox1 : callable or None 66 | prox1(x, step_size, *args) returns the proximal operator of g at xa 67 | with parameter step_size. 68 | step_size can be a scalar or of shape (batch_size,). 69 | 70 | prox2 : callable or None 71 | prox2(x, step_size, *args) returns the proximal operator of g at xa 72 | with parameter step_size. 73 | alpha can be a scalar or of shape (batch_size,). 74 | 75 | tol: float 76 | Tolerance of the stopping criterion. 77 | 78 | max_iter : int 79 | Maximum number of iterations. 80 | 81 | verbose : int 82 | Verbosity level, from 0 (no output) to 2 (output on each iteration) 83 | 84 | callback : callable. 85 | callback function (optional). 86 | Called with locals() at each step of the algorithm. 87 | The algorithm will exit if callback returns False. 88 | 89 | line_search : boolean 90 | Whether to perform line-search to estimate the step sizes. 91 | 92 | step_size : float or tensor(shape: (batch_size,)) or None 93 | Starting value(s) for the line-search procedure. 94 | if None, step_size will be estimated for each datapoint in the batch. 95 | 96 | max_iter_backtracking: int 97 | maximun number of backtracking iterations. Used in line search. 98 | 99 | backtracking_factor: float 100 | the amount to backtrack by during line search. 101 | 102 | args_prox: iterable 103 | (optional) Extra arguments passed to the prox functions. 104 | 105 | kwargs_prox: dict 106 | (optional) Extra keyword arguments passed to the prox functions. 107 | 108 | 109 | Returns: 110 | res : OptimizeResult 111 | The optimization result represented as a 112 | ``scipy.optimize.OptimizeResult`` object. Important attributes are: 113 | ``x`` the solution tensor, ``success`` a Boolean flag indicating if 114 | the optimizer exited successfully and ``message`` which describes 115 | the cause of the termination. See `scipy.optimize.OptimizeResult` 116 | for a description of other attributes. 117 | """ 118 | 119 | success = torch.zeros(x0.size(0), dtype=bool) 120 | if not max_iter_backtracking > 0: 121 | raise ValueError("Line search iterations need to be greater than 0") 122 | 123 | LS_EPS = np.finfo(np.float).eps 124 | 125 | if prox1 is None: 126 | 127 | @torch.no_grad() 128 | def prox1(x, s=None, *args): 129 | return x 130 | 131 | if prox2 is None: 132 | @torch.no_grad() 133 | def prox2(x, s=None, *args): 134 | return x 135 | 136 | x = x0.detach().clone().requires_grad_(True) 137 | batch_size = x.size(0) 138 | 139 | if step is None: 140 | line_search = True 141 | step_size = 1.0 / utils.init_lipschitz(closure, x) 142 | 143 | elif isinstance(step, Number): 144 | step_size = step * torch.ones(batch_size, 145 | device=x.device, 146 | dtype=x.dtype) 147 | 148 | else: 149 | raise ValueError("step must be float or None.") 150 | 151 | z = prox2(x, step_size, *args_prox) 152 | z = z.clone().detach() 153 | z.requires_grad_(True) 154 | 155 | fval, grad = closure(z) 156 | 157 | x = prox1(z - utils.bmul(step_size, grad), step_size, *args_prox) 158 | u = torch.zeros_like(x) 159 | 160 | for it in range(max_iter): 161 | z.requires_grad_(True) 162 | fval, grad = closure(z) 163 | with torch.no_grad(): 164 | x = prox1(z - utils.bmul(step_size, u + grad), step_size, *args_prox) 165 | incr = x - z 166 | norm_incr = torch.norm(incr.view(incr.size(0), -1), dim=-1) 167 | rhs = fval + utils.bdot(grad, incr) + ((norm_incr ** 2) / (2 * step_size)) 168 | ls_tol = closure(x, return_jac=False) 169 | mask = torch.bitwise_and(norm_incr > 1e-7, line_search) 170 | ls = mask.detach().clone() 171 | # TODO: optimize code in this loop using mask 172 | for it_ls in range(max_iter_backtracking): 173 | if not(mask.any()): 174 | break 175 | rhs[mask] = fval[mask] + utils.bdot(grad[mask], incr[mask]) 176 | rhs[mask] += utils.bmul(norm_incr[mask] ** 2, 1. / (2 * step_size[mask])) 177 | 178 | ls_tol[mask] = closure(x, return_jac=False)[mask] - rhs[mask] 179 | mask &= (ls_tol > LS_EPS) 180 | step_size[mask] *= backtracking_factor 181 | 182 | z = prox2(x + utils.bmul(step_size, u), step_size, *args_prox) 183 | u += utils.bmul(x - z, 1. / step_size) 184 | certificate = utils.bmul(norm_incr, 1. / step_size) 185 | 186 | if callback is not None: 187 | if callback(locals()) is False: 188 | break 189 | 190 | success = torch.bitwise_and(certificate < tol, it > 0) 191 | if success.all(): 192 | break 193 | 194 | return optimize.OptimizeResult(x=x, success=success, nit=it, fval=fval, certificate=certificate) 195 | 196 | 197 | def minimize_pgd_madry(closure, x0, prox, lmo, step=None, max_iter=200, prox_args=(), callback=None): 198 | x = x0.detach().clone() 199 | batch_size = x.size(0) 200 | 201 | if step is None: 202 | # estimate lipschitz constant 203 | # TODO: this is not the optimal step-size (if there even is one.) 204 | # I don't recommend to use this. 205 | L = utils.init_lipschitz(closure, x0) 206 | step_size = 1. / L 207 | 208 | elif isinstance(step, Number): 209 | step_size = torch.ones(batch_size, device=x.device) * step 210 | 211 | elif isinstance(step, torch.Tensor): 212 | step_size = step 213 | 214 | else: 215 | raise ValueError(f"step must be a number or a torch Tensor, got {step} instead") 216 | 217 | for it in range(max_iter): 218 | x.requires_grad = True 219 | _, grad = closure(x) 220 | with torch.no_grad(): 221 | update_direction, _ = lmo(-grad, x) 222 | update_direction += x 223 | x = prox(x + utils.bmul(step_size, update_direction), 224 | step_size, *prox_args) 225 | 226 | if callback is not None: 227 | if callback(locals()) is False: 228 | break 229 | 230 | fval, grad = closure(x) 231 | return optimize.OptimizeResult(x=x, nit=it, fval=fval, grad=grad) 232 | 233 | 234 | def backtracking_pgd(closure, prox, step_size, x, grad, increase=1.01, decrease=.6, max_iter_backtracking=1000): 235 | 236 | batch_size = x.size(0) 237 | rhs = -np.inf * torch.ones(batch_size) 238 | lhs = np.inf * torch.ones(batch_size) 239 | 240 | need_to_backtrack = lhs > rhs 241 | 242 | while (~need_to_backtrack).any(): 243 | step_size[~need_to_backtrack] *= increase 244 | 245 | while need_to_backtrack.any(): 246 | with torch.no_grad(): 247 | x_candidate = prox(x - utils.bmul(step_size, grad), step_size) 248 | 249 | lhs = closure(x_candidate, return_jac=False) 250 | rhs = (closure(x, return_jac=False) - utils.bdot(grad, x - x_candidate) 251 | + utils.bmul(1. / (2 * step_size), 252 | torch.norm((x - x_candidate).view(x.size(0), -1), 253 | dim=-1))) ** 2 254 | 255 | 256 | def minimize_pgd(closure, x0, prox=None, step='backtracking', max_iter=200, 257 | max_iter_backtracking=1000, 258 | backtracking_factor=.6, 259 | tol=1e-8, 260 | *prox_args, 261 | callback=None): 262 | """ 263 | Performs Projected Gradient Descent on batch of objectives of form: 264 | f(x) + g(x). 265 | We suppose we have access to gradient computation for f through closure, 266 | and to the proximal operator of g in prox. 267 | 268 | Args: 269 | closure: callable 270 | 271 | x0: torch.Tensor of shape (batch_size, *). 272 | 273 | prox: callable 274 | proximal operator of g 275 | 276 | step: 'backtracking' or float or torch.tensor of shape (batch_size,) or None. 277 | step size to be used. If None, will be estimated at the beginning 278 | using line search. 279 | If 'backtracking', will be estimated at each step using backtracking line search. 280 | 281 | max_iter: int 282 | number of iterations to perform. 283 | 284 | max_iter_backtracking: int 285 | max number of iterations in the backtracking line search 286 | 287 | backtracking_factor: float 288 | factor by which to multiply the step sizes during line search 289 | 290 | tol: float 291 | stops the algorithm when the certificate is smaller than tol 292 | for all datapoints in the batch 293 | 294 | prox_args: tuple 295 | (optional) additional args for prox 296 | 297 | callback: callable 298 | (optional) Any callable called on locals() at the end of each iteration. 299 | Often used for logging. 300 | """ 301 | x = x0.detach().clone() 302 | batch_size = x.size(0) 303 | 304 | if prox is None: 305 | def prox(x, s=None): 306 | return x 307 | 308 | if step is None: 309 | # estimate lipschitz constant 310 | L = utils.init_lipschitz(closure, x0) 311 | step_size = 1. / L 312 | 313 | elif step == 'backtracking': 314 | L = 1.8 * utils.init_lipschitz(closure, x0) 315 | step_size = 1. / L 316 | 317 | elif type(step) == float: 318 | step_size = step * torch.ones(batch_size, device=x.device) 319 | 320 | else: 321 | raise ValueError("step must be float or backtracking or None") 322 | 323 | for it in range(max_iter): 324 | 325 | fval, grad = closure(x) 326 | x_next = prox(x - utils.bmul(step_size, grad), step_size, *prox_args) 327 | update_direction = x_next - x 328 | 329 | if step == 'backtracking': 330 | step_size *= 1.1 331 | mask = torch.ones(batch_size, dtype=bool, device=x.device) 332 | 333 | with torch.no_grad(): 334 | for _ in range(max_iter_backtracking): 335 | f_next = closure(x_next, return_jac=False) 336 | rhs = (fval 337 | + utils.bdot(grad, update_direction) 338 | + utils.bmul(utils.bdot(update_direction, 339 | update_direction), 340 | 1. / (2. * step_size)) 341 | ) 342 | mask = f_next > rhs 343 | 344 | if not mask.any(): 345 | break 346 | 347 | step_size[mask] *= backtracking_factor 348 | x_next = prox(x - utils.bmul(step_size, 349 | grad), 350 | step_size[mask], 351 | *prox_args) 352 | update_direction[mask] = x_next[mask] - x[mask] 353 | else: 354 | warnings.warn("Maximum number of line-search iterations " 355 | "reached.") 356 | 357 | with torch.no_grad(): 358 | cert = torch.norm(utils.bmul(update_direction, 1. / step_size), 359 | dim=-1) 360 | x.copy_(x_next) 361 | if (cert < tol).all(): 362 | break 363 | 364 | if callback is not None: 365 | if callback(locals()) is False: 366 | break 367 | 368 | fval, grad = closure(x) 369 | return optimize.OptimizeResult(x=x, nit=it, fval=fval, grad=grad, 370 | certificate=cert) 371 | 372 | 373 | def minimize_frank_wolfe(closure, x0, lmo, step='sublinear', 374 | max_iter=200, callback=None, *args, **kwargs): 375 | """Performs the Frank-Wolfe algorithm on a batch of objectives of the form 376 | min_x f(x) 377 | s.t. x in C 378 | 379 | where we have access to the Linear Minimization Oracle (LMO) of the constraint set C, 380 | and the gradient of f through closure. 381 | 382 | Args: 383 | closure: callable 384 | gives function values and the jacobian of f. 385 | 386 | x0: torch.Tensor of shape (batch_size, *). 387 | initial guess 388 | 389 | lmo: callable 390 | Returns update_direction, max_step_size 391 | 392 | step: float or 'sublinear' 393 | step-size scheme to be used. 394 | 395 | max_iter: int 396 | max number of iterations. 397 | 398 | callback: callable 399 | (optional) Any callable called on locals() at the end of each iteration. 400 | Often used for logging. 401 | 402 | Returns: 403 | 404 | result: optimize.OptimizeResult object 405 | Holds the result of the optimization, and certificates of convergence. 406 | """ 407 | x = x0.detach().clone() 408 | batch_size = x.size(0) 409 | if not (isinstance(step, Number) or step == 'sublinear'): 410 | raise ValueError(f"step must be a float or 'sublinear', got {step} instead.") 411 | 412 | if isinstance(step, Number): 413 | step_size = step * torch.ones(batch_size, device=x.device, dtype=x.dtype) 414 | 415 | cert = np.inf * torch.ones(batch_size, device=x.device) 416 | 417 | for it in range(max_iter): 418 | 419 | x.requires_grad = True 420 | fval, grad = closure(x) 421 | update_direction, max_step_size = lmo(-grad, x) 422 | cert = utils.bdot(-grad, update_direction) 423 | 424 | if step == 'sublinear': 425 | step_size = 2. / (it + 2) * torch.ones(batch_size, dtype=x.dtype, device=x.device) 426 | 427 | with torch.no_grad(): 428 | step_size = torch.min(step_size, max_step_size) 429 | x += utils.bmul(update_direction, step_size) 430 | 431 | if callback is not None: 432 | if callback(locals()) is False: 433 | break 434 | 435 | fval, grad = closure(x) 436 | return optimize.OptimizeResult(x=x, nit=it, fval=fval, grad=grad, 437 | certificate=cert) 438 | 439 | 440 | def update_active_set(active_set, 441 | fw_idx, away_idx, 442 | step_size): 443 | 444 | max_step_size = active_set[away_idx] 445 | active_set[fw_idx] += step_size 446 | active_set[away_idx] -= step_size 447 | 448 | if active_set[away_idx] == 0.: 449 | # drop step: remove vertex from active set 450 | del active_set[away_idx] 451 | if active_set[away_idx] < 0.: 452 | raise ValueError(f"The step size used is too large. " 453 | f"{step_size: .3f} vs. {max_step_size:.3f}") 454 | 455 | return active_set 456 | 457 | 458 | def backtracking_fw( 459 | x, 460 | fval, 461 | old_fval, 462 | closure, 463 | certificate, 464 | lipschitz_t, 465 | max_step_size, 466 | update_direction, 467 | norm_update_direction, 468 | tol=torch.finfo(torch.float32).eps 469 | ): 470 | """Performs backtracking line search for Frank-Wolfe algorithms.""" 471 | 472 | ratio_decrease = .9 473 | ratio_increase = 2. 474 | max_linesearch_iter = 100 475 | 476 | if old_fval is not None: 477 | tmp = (certificate ** 2) / (2 * (old_fval - fval) * norm_update_direction) 478 | lipschitz_t = max(min(tmp, lipschitz_t), lipschitz_t * ratio_decrease) 479 | 480 | for _ in range(max_linesearch_iter): 481 | step_size_t = certificate / (norm_update_direction * lipschitz_t) 482 | if step_size_t < max_step_size: 483 | rhs = -0.5 * step_size_t * certificate 484 | else: 485 | step_size_t = max_step_size 486 | rhs = ( 487 | -step_size_t * certificate 488 | + 0.5 * (step_size_t ** 2) * lipschitz_t * norm_update_direction 489 | ) 490 | fval_next, grad_next = closure(x + step_size_t * update_direction) 491 | if fval_next - fval <= rhs + tol: 492 | # .. sufficient decrease condition verified .. 493 | break 494 | else: 495 | lipschitz_t *= ratio_increase 496 | else: 497 | warnings.warn( 498 | "Exhausted line search iterations in minimize_frank_wolfe", RuntimeWarning 499 | ) 500 | return step_size_t, lipschitz_t, fval_next, grad_next 501 | 502 | 503 | def minimize_pairwise_frank_wolfe( 504 | closure, 505 | x0_idx, 506 | polytope, 507 | step='backtracking', 508 | lipschitz=None, 509 | max_iter=200, 510 | tol=1e-6, 511 | callback=None 512 | ): 513 | """Minimize using Pairwise Frank-Wolfe. 514 | 515 | WARNING: This implementation is different from other functions in this file. 516 | As of now, it does not handle batched problems, and only handles Polytope constraints. 517 | 518 | Args: 519 | closure: closure of the function to minimize 520 | x0_idx: the starting vertex on the polytope 521 | polytope: the polytope constraint 522 | step: backtracking line search as defined in [1] 523 | lipschitz: an initial Lipschitz estimate of the gradient 524 | max_iter: maximum number of iterations 525 | tol: tolerance on the Frank-Wolfe gap 526 | callback: a callable callback function 527 | """ 528 | 529 | if not isinstance(polytope, constraints.Polytope): 530 | raise ValueError("polytope must be a `chop.constraints.Polytope`.") 531 | 532 | if polytope.vertices.size(0) != 1: 533 | raise NotImplementedError("This optimizer can only handle one problem instance at a time for now.") 534 | 535 | x = polytope.vertices[:, x0_idx].detach().clone() 536 | active_set = defaultdict(float) 537 | 538 | active_set[x0_idx] = 1. 539 | 540 | cert = float('inf') 541 | 542 | x.requires_grad = True 543 | fval, grad = closure(x) 544 | old_fval = None 545 | 546 | 547 | lipschitz_t = None 548 | step_size = None 549 | 550 | if lipschitz is not None: 551 | lipschitz_t = lipschitz 552 | 553 | for it in range(max_iter): 554 | update_direction, fw_idx, away_idx, max_step_size = polytope.lmo_pairwise(-grad, x, active_set) 555 | norm_update_direction = torch.linalg.norm(update_direction) ** 2 556 | cert = utils.bdot(update_direction, -grad) 557 | 558 | if lipschitz_t is None: 559 | eps = 1e-3 560 | grad_eps = closure(x + eps * update_direction)[1] 561 | lipschitz_t = torch.linalg.norm(grad-grad_eps) / ( 562 | eps * torch.sqrt(norm_update_direction) 563 | ) 564 | print(f"Estimated L_t = {lipschitz_t}") 565 | 566 | if cert <= tol: 567 | break 568 | 569 | if step == 'DR': 570 | step_size = min( 571 | cert / (norm_update_direction * lipschitz_t), max_step_size 572 | ) 573 | fval_next, grad_next = closure(x + step_size * update_direction) 574 | elif step == 'backtracking': 575 | step_size, lipschitz_t, fval_next, grad_next = backtracking_fw( 576 | x, fval, old_fval, closure, cert, lipschitz_t, max_step_size, 577 | update_direction, norm_update_direction 578 | ) 579 | 580 | elif step == 'sublinear': 581 | step_size = 2. / (it + 2.) 582 | step_size = min(step_size, max_step_size) 583 | fval_next, grad_next = closure(x + step_size * update_direction) 584 | 585 | if callback is not None: 586 | if callback(locals()) is False: # pylint: disable=g-bool-id-comparison 587 | break 588 | 589 | with torch.no_grad(): 590 | x.add_(step_size * update_direction) 591 | 592 | update_active_set( 593 | active_set, 594 | fw_idx, 595 | away_idx, 596 | step_size) 597 | 598 | old_fval = fval 599 | fval, grad = fval_next, grad_next 600 | if callback is not None: 601 | callback(locals()) 602 | 603 | return optimize.OptimizeResult(x=x.data, nit=it, certificate=cert, active_set=active_set, 604 | fval=fval, grad=grad) 605 | 606 | 607 | def minimize_alternating_fw_prox(closure, x0, y0, prox=None, lmo=None, lipschitz=1e-3, 608 | step='sublinear', line_search=None, max_iter=200, callback=None, 609 | *args, **kwargs): 610 | """ 611 | Implements algorithm from [Garber et al. 2018] 612 | https://arxiv.org/abs/1802.05581 613 | 614 | to solve the following problem 615 | 616 | 617 | ..math:: 618 | \min_{x, y} f(x + y) + R_x(x) + R_y(y). 619 | 620 | We suppose that $f$ is $L$-smooth and that 621 | we have access to the following operators: 622 | 623 | - a generalized LMO for $R_y$: 624 | ..math:: 625 | gLMO(w) = \text{argmin}_w R_y(w) + \langle w, \nabla f(x_t + y_t) \rangle 626 | 627 | - a prox operator for $R_x$: 628 | ..math:: 629 | prox(v) = \text{argmin}_v R_x(v) + \langle v, \nabla f(x_t+ y_t) \rangle + \frac{\gamma_t L}{2} \|v + w_t - (x_t + y_t)\|^2 630 | 631 | Args: 632 | x0: torch.Tensor of shape (batch_size, *) 633 | starting point for x 634 | 635 | y0: torch.Tensor of shape (batch_size, *) 636 | starting point for y 637 | 638 | prox: function 639 | proximal operator for R_x 640 | 641 | lmo: function 642 | generalized LMO operator for R_y. If R_y is an indicator function, 643 | it reduces to the usual LMO operator. 644 | 645 | lipschitz: float 646 | initial guess of the lipschitz constant of f 647 | 648 | step: float or 'sublinear' 649 | step-size scheme to be used. 650 | 651 | max_iter: int 652 | max number of iterations. 653 | 654 | callback: callable 655 | (optional) Any callable called on locals() at the end of each iteration. 656 | Often used for logging. 657 | 658 | Returns: 659 | 660 | result: optimize.OptimizeResult object 661 | Holds the result of the optimization, and certificates of convergence. 662 | """ 663 | 664 | x = x0.detach().clone() 665 | y = y0.detach().clone() 666 | batch_size = x.size(0) 667 | 668 | if x.shape != y.shape: 669 | raise ValueError(f"x, y should have the same shape. Got {x.shape}, {y.shape}.") 670 | 671 | if not (isinstance(step, Number) or step == 'sublinear'): 672 | raise ValueError(f"step must be a float or 'sublinear', got {step} instead.") 673 | 674 | if isinstance(step, Number): 675 | step_size = step * torch.ones(batch_size, device=x.device, dtype=x.dtype) 676 | 677 | # TODO: add error catching for L0 678 | Lt = lipschitz 679 | 680 | for it in range(max_iter): 681 | 682 | if step == 'sublinear': 683 | step_size = 2. / (it + 2) * torch.ones(batch_size, device=x.device) 684 | 685 | x.requires_grad_(True) 686 | y.requires_grad_(True) 687 | z = x + y 688 | 689 | f_val, grad = closure(z) 690 | 691 | # estimate Lipschitz constant with backtracking line search 692 | Lt = utils.init_lipschitz(closure, z, L0=Lt) 693 | 694 | y_update, max_step_size = lmo(-grad, y) 695 | with torch.no_grad(): 696 | w = y_update + y 697 | prox_step_size = utils.bmul(step_size, Lt) 698 | v = prox(z - w - utils.bdiv(grad, prox_step_size), prox_step_size) 699 | 700 | with torch.no_grad(): 701 | if line_search is None: 702 | step_size = torch.min(step_size, max_step_size) 703 | else: 704 | step_size = line_search(locals()) 705 | 706 | y += utils.bmul(step_size, y_update) 707 | x_update = v - x 708 | x += utils.bmul(step_size, x_update) 709 | 710 | if callback is not None: 711 | if callback(locals()) is False: 712 | break 713 | 714 | fval, grad = closure(x + y) 715 | # TODO: add a certificate of optimality 716 | result = optimize.OptimizeResult(x=x, y=y, nit=it, fval=fval, grad=grad, certificate=None) 717 | return result 718 | -------------------------------------------------------------------------------- /chop/penalties.py: -------------------------------------------------------------------------------- 1 | """ 2 | Penalties. 3 | ========= 4 | This module contains classes representing penalties / regularizers. 5 | They function batch-wise, similar to objects in `chop.constraints`. 6 | Reshaping will be of order if the penalties are used on the parameters of a model. 7 | 8 | Code inspired from https://github.com/openopt/copt/. 9 | 10 | The proximal operators are derived e.g. in https://www.di.ens.fr/~fbach/opt_book.pdf. 11 | """ 12 | 13 | from numbers import Number 14 | from numpy.core.fromnumeric import nonzero 15 | import torch 16 | import torch.nn.functional as F 17 | 18 | from chop import utils 19 | 20 | 21 | class L1: 22 | """L1 penalty. Batch-wise function. For each element in the batch, 23 | the L1 penalty is given by 24 | ..math:: 25 | \Omega(x) = \alpha \|x\|_1 26 | """ 27 | 28 | def __init__(self, alpha: float): 29 | """ 30 | Args: 31 | alpha: float 32 | Size of the penalty. Must be non-negative. 33 | """ 34 | if alpha < 0: 35 | raise ValueError("alpha must be non negative.") 36 | self.alpha = alpha 37 | 38 | def __call__(self, x): 39 | """ 40 | Returns the value of the penalty on x, batch_size. 41 | 42 | Args: 43 | x: torch.Tensor 44 | x has shape (batch_size, *) 45 | """ 46 | batch_size = x.size(0) 47 | return self.alpha * abs(x.view(batch_size, -1)).sum(dim=-1) 48 | 49 | def prox(self, x, step_size=None): 50 | """Proximal operator for the L1 norm penalty. This is given by soft-thresholding. 51 | 52 | Args: 53 | x: torch.Tensor 54 | x has shape (batch_size, *) 55 | step_size: float or torch.Tensor of shape (batch_size,) 56 | 57 | """ 58 | if isinstance(step_size, Number): 59 | step_size = step_size * torch.ones(x.size(0), device=x.device, dtype=x.dtype) 60 | return utils.bmul(torch.sign(x), F.relu(abs(x) - self.alpha * step_size.view((-1,) + (1,) * (x.dim() - 1)))) 61 | 62 | 63 | class GroupL1: 64 | """ 65 | Group LASSO penalty. Batch-wise function. 66 | """ 67 | 68 | def __init__(self, alpha, groups): 69 | """ 70 | Args: 71 | alpha: float 72 | Size of the penalty. Must be non-negative. 73 | 74 | groups: iterable of iterables 75 | Each element of groups will be used to index the given tensor to compute 76 | the penalty on. See example. 77 | 78 | Examples: 79 | Our input is of shape (batch_size, 4), and we want to split the features in two groups. 80 | The first contains the first two features, and the second the latter two. This is done by: 81 | $ groups = [(0, 1), (2, 3)] 82 | 83 | In this case, since the groups are of equal size, we could have used 84 | $ groups = torch.tensor([[0, 1], 85 | $ [2, 3]]) 86 | 87 | If the input is of shape (batch_size, 4, 2), and we want to split 88 | our features in 2 groups (left half and right half of the image), then each group 89 | is an iterable over the coordinates contained in it. 90 | 91 | $ groups = [((0, 0), (0, 1), (1, 0), (1, 1)), 92 | $ ((2, 0), (2, 1), (3, 0), (3, 1))] 93 | 94 | The same convention is used for higher dimension inputs. 95 | If the provided coordinates are of smaller dimension, they will be prepended 96 | by an Ellipsis. 97 | 98 | Todo: 99 | * Test the Ellipsis behavior. 100 | """ 101 | self.alpha = alpha 102 | 103 | # TODO: implement ValueErrors 104 | # groups must be indices and non overlapping 105 | if not isinstance(groups[0], torch.Tensor): 106 | groups = [torch.tensor(group) for group in groups] 107 | while groups[0].dim() < 2: 108 | groups = [group.unsqueeze(-1) for group in groups] 109 | 110 | self.groups = groups 111 | 112 | def __call__(self, x): 113 | group_norms = torch.stack([torch.linalg.norm(x[(...,) + tuple(g.T)], 114 | dim=-1) 115 | for g in self.groups]) 116 | 117 | return self.alpha * group_norms.sum(dim=0) 118 | 119 | @torch.no_grad() 120 | def prox(self, x, step_size=None): 121 | """ 122 | Returns the proximal operator for the (non overlapping) Group L1 norm. 123 | Args: 124 | x: torch.Tensor of shape (batch_size, *) 125 | 126 | step_size: float or torch.Tensor of shape (batch_size,) 127 | 128 | """ 129 | out = x.detach().clone() 130 | if isinstance(step_size, Number): 131 | step_size *= torch.ones(x.size(0), dtype=x.dtype, device=x.device) 132 | 133 | for g in self.groups: 134 | norm = torch.linalg.norm(x[(...,) + tuple(g.T)].view(x.size(0), -1), dim=-1) 135 | nonzero_norm = torch.nonzero(norm) 136 | out[(nonzero_norm, ...) + tuple(g.T)] = utils.bmul(out[(nonzero_norm, ...) + tuple(g.T)], 137 | F.relu(1 - utils.bmul(self.alpha * step_size[nonzero_norm], 138 | 1. / norm[nonzero_norm]))) 139 | return out 140 | -------------------------------------------------------------------------------- /chop/stochastic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Stochastic optimizers. 3 | ========================= 4 | 5 | This module contains stochastic first order optimizers. 6 | These are meant to be used in replacement of optimizers such as SGD, Adam etc, 7 | for training a model over batches of a dataset. 8 | The API in this module is inspired by torch.optim. 9 | 10 | """ 11 | 12 | import warnings 13 | 14 | import collections 15 | 16 | import torch 17 | from torch.optim import Optimizer 18 | import numpy as np 19 | 20 | 21 | EPS = np.finfo(np.float32).eps 22 | 23 | 24 | def backtracking_step_size( 25 | x, 26 | f_t, 27 | old_f_t, 28 | f_grad, 29 | certificate, 30 | lipschitz_t, 31 | max_step_size, 32 | update_direction, 33 | norm_update_direction, 34 | ): 35 | """Backtracking step-size finding routine for FW-like algorithms 36 | 37 | Args: 38 | x: array-like, shape (n_features,) 39 | Current iterate 40 | f_t: float 41 | Value of objective function at the current iterate. 42 | old_f_t: float 43 | Value of objective function at previous iterate. 44 | f_grad: callable 45 | Callable returning objective function and gradient at 46 | argument. 47 | certificate: float 48 | FW gap 49 | lipschitz_t: float 50 | Current value of the Lipschitz estimate. 51 | max_step_size: float 52 | Maximum admissible step-size. 53 | update_direction: array-like, shape (n_features,) 54 | Update direction given by the FW variant. 55 | norm_update_direction: float 56 | Squared L2 norm of update_direction 57 | Returns: 58 | step_size_t: float 59 | Step-size to be used to compute the next iterate. 60 | lipschitz_t: float 61 | Updated value for the Lipschitz estimate. 62 | f_next: float 63 | Objective function evaluated at x + step_size_t d_t. 64 | grad_next: array-like 65 | Gradient evaluated at x + step_size_t d_t. 66 | """ 67 | ratio_decrease = 0.9 68 | ratio_increase = 2.0 69 | max_ls_iter = 100 70 | if old_f_t is not None: 71 | tmp = (certificate ** 2) / ( 72 | 2 * (old_f_t - f_t) * norm_update_direction 73 | ) 74 | lipschitz_t = max(min(tmp, lipschitz_t), lipschitz_t * ratio_decrease) 75 | for _ in range(max_ls_iter): 76 | step_size_t = certificate / (norm_update_direction * lipschitz_t) 77 | if step_size_t < max_step_size: 78 | rhs = -0.5 * step_size_t * certificate 79 | else: 80 | step_size_t = max_step_size 81 | rhs = ( 82 | -step_size_t * certificate 83 | + 0.5 84 | * (step_size_t ** 2) 85 | * lipschitz_t 86 | * norm_update_direction 87 | ) 88 | f_next, grad_next = f_grad(x + step_size_t * update_direction) 89 | if f_next - f_t <= rhs + EPS: 90 | # .. sufficient decrease condition verified .. 91 | break 92 | else: 93 | lipschitz_t *= ratio_increase 94 | else: 95 | warnings.warn( 96 | "Exhausted line search iterations in minimize_frank_wolfe", 97 | RuntimeWarning, 98 | ) 99 | return step_size_t, lipschitz_t, f_next, grad_next 100 | 101 | 102 | def normalize_gradient(grad, normalization): 103 | if normalization == "none": 104 | return grad 105 | elif normalization == "Linf": 106 | grad = grad / abs(grad).max() 107 | 108 | elif normalization == "sign": 109 | grad = torch.sign(grad) 110 | 111 | elif normalization == "L2": 112 | grad = grad / torch.norm(grad) 113 | 114 | return grad 115 | 116 | 117 | class PGD(Optimizer): 118 | """Proximal Gradient Descent 119 | 120 | Args: 121 | params: [torch.Parameter] 122 | List of parameters to optimize over 123 | prox: [callable or None] 124 | List of prox operators, one per parameter. 125 | lr: float 126 | Learning rate 127 | momentum: float in [0, 1] 128 | 129 | normalization: str 130 | Type of gradient normalization to be used. 131 | Possible values are 'none', 'L2', 'Linf', 'sign'. 132 | 133 | """ 134 | 135 | name = "PGD" 136 | POSSIBLE_NORMALIZATIONS = {"none", "L2", "Linf", "sign"} 137 | 138 | def __init__( 139 | self, params, prox=None, lr=0.1, momentum=0.9, normalization="none" 140 | ): 141 | if prox is None: 142 | prox = [None] * len(list(params)) 143 | 144 | self.prox = [] 145 | for prox_el in prox: 146 | if prox_el is not None: 147 | self.prox.append( 148 | lambda x, s=None: prox_el(x.unsqueeze(0)).squeeze() 149 | ) 150 | else: 151 | self.prox.append(lambda x, s=None: x) 152 | 153 | if not (type(lr) == float or lr == "sublinear"): 154 | raise ValueError("lr must be float or 'sublinear'.") 155 | self.lr = lr 156 | 157 | if type(momentum) == float: 158 | if not (0.0 <= momentum <= 1.0): 159 | raise ValueError("Momentum must be in [0., 1.].") 160 | self.momentum = momentum 161 | 162 | if normalization in self.POSSIBLE_NORMALIZATIONS: 163 | self.normalization = normalization 164 | else: 165 | raise ValueError( 166 | f"Normalization must be in {self.POSSIBLE_NORMALIZATIONS}" 167 | ) 168 | defaults = dict( 169 | prox=self.prox, 170 | name=self.name, 171 | momentum=self.momentum, 172 | lr=self.lr, 173 | normalization=self.normalization, 174 | ) 175 | super(PGD, self).__init__(params, defaults) 176 | 177 | @property 178 | @torch.no_grad() 179 | def certificate(self): 180 | """A generator over the current convergence certificate estimate 181 | for each optimized parameter.""" 182 | for groups in self.param_groups: 183 | for p in groups["params"]: 184 | state = self.state[p] 185 | yield state["certificate"] 186 | 187 | @torch.no_grad() 188 | def step(self, closure=None): 189 | loss = None 190 | if closure is not None: 191 | with torch.enable_grad(): 192 | loss = closure() 193 | idx = 0 194 | for groups in self.param_groups: 195 | for p in groups["params"]: 196 | if p.grad is None: 197 | continue 198 | 199 | grad = p.grad 200 | 201 | if grad.is_sparse: 202 | raise RuntimeError( 203 | "We do not yet support sparse gradients." 204 | ) 205 | 206 | state = self.state[p] 207 | # Initialization 208 | if len(state) == 0: 209 | state["step"] = 0.0 210 | state["grad_estimate"] = torch.zeros_like( 211 | p, memory_format=torch.preserve_format 212 | ) 213 | 214 | state["step"] += 1.0 215 | state["grad_estimate"].add_( 216 | grad - state["grad_estimate"], alpha=1.0 - self.momentum 217 | ) 218 | 219 | grad_est = normalize_gradient( 220 | state["grad_estimate"], self.normalization 221 | ) 222 | 223 | if self.lr == "sublinear": 224 | step_size = 1.0 / (state["step"] + 1.0) 225 | else: 226 | step_size = self.lr 227 | 228 | new_p = self.prox[idx](p - step_size * grad_est, 1.0) 229 | state["certificate"] = torch.norm((p - new_p) / step_size) 230 | p.copy_(new_p) 231 | idx += 1 232 | return loss 233 | 234 | 235 | class PGDMadry(Optimizer): 236 | """PGD from [1]. 237 | 238 | Args: 239 | params: [torch.Tensor] 240 | list of parameters to optimize 241 | 242 | lmo: [callable] 243 | list of lmo operators for each parameter 244 | 245 | prox: [callable or None] or None 246 | list of prox operators for each parameter 247 | 248 | lr: float > 0 249 | learning rate 250 | 251 | References: 252 | Madry, Aleksander, and Makelov, Aleksandar, and Schmidt, Ludwig, 253 | and Tsipras, Dimitris, and Vladu, Adrian. Towards Deep Learning Models 254 | Resistant to Adversarial Attacks. ICLR 2018. 255 | """ 256 | 257 | name = "PGD-Madry" 258 | 259 | def __init__(self, params, lmo, prox=None, lr=1e-2): 260 | self.prox = [] 261 | for prox_el in prox: 262 | if prox_el is None: 263 | 264 | def prox_el(x, s=None): 265 | return x 266 | 267 | def _prox(x, s=None): 268 | return prox_el(x.unsqueeze(0), s).squeeze() 269 | 270 | self.prox.append(_prox) 271 | 272 | self.lmo = [] 273 | for lmo_el in lmo: 274 | 275 | def _lmo(u, x): 276 | update_direction, max_step_size = lmo_el( 277 | u.unsqueeze(0), x.unsqueeze(0) 278 | ) 279 | return update_direction.squeeze(dim=0), max_step_size 280 | 281 | self.lmo.append(_lmo) 282 | 283 | if not (type(lr) == float or lr == "sublinear"): 284 | raise ValueError("lr must be float or 'sublinear'.") 285 | 286 | self.lr = lr 287 | defaults = dict(prox=self.prox, lmo=self.lmo, name=self.name) 288 | super(PGDMadry, self).__init__(params, defaults) 289 | 290 | @property 291 | @torch.no_grad() 292 | def certificate(self): 293 | """A generator over the current convergence certificate estimate 294 | for each optimized parameter.""" 295 | for groups in self.param_groups: 296 | for p in groups["params"]: 297 | state = self.state[p] 298 | yield state["certificate"] 299 | 300 | @torch.no_grad() 301 | def step(self, step_size=None, closure=None): 302 | loss = None 303 | if closure is not None: 304 | with torch.enable_grad(): 305 | loss = closure() 306 | idx = 0 307 | for groups in self.param_groups: 308 | for p in groups["params"]: 309 | if p.grad is None: 310 | continue 311 | grad = p.grad 312 | if grad.is_sparse: 313 | raise RuntimeError( 314 | "We do not yet support sparse gradients." 315 | ) 316 | # Keep track of the step 317 | state = self.state[p] 318 | if len(state) == 0: 319 | state["step"] = 0.0 320 | state["step"] += 1.0 321 | 322 | if self.lr == "sublinear": 323 | step_size = 1.0 / (state["step"] + 1.0) 324 | else: 325 | step_size = self.lr 326 | lmo_res, _ = self.lmo[idx](-p.grad, p) 327 | normalized_grad = lmo_res + p 328 | new_p = self.prox[idx](p + step_size * normalized_grad) 329 | state["certificate"] = torch.norm((p - new_p) / step_size) 330 | p.copy_(new_p) 331 | idx += 1 332 | return loss 333 | 334 | 335 | class S3CM(Optimizer): 336 | """ 337 | Stochastic Three Composite Minimization (S3CM) 338 | 339 | Args: 340 | params: [torch.Tensor] 341 | list of parameters to optimize 342 | 343 | prox1: [callable or None] or None 344 | Proximal operator for first constraint set. 345 | 346 | prox2: [callable or None] or None 347 | Proximal operator for second constraint set. 348 | 349 | lr: float > 0 350 | Learning rate 351 | 352 | normalization: str in {'none', 'L2', 'Linf', 'sign'} 353 | Normalizes the gradient. 'L2', 'Linf' divide the gradient by the corresponding norm. 354 | 'sign' uses the sign of the gradient. 355 | 356 | References: 357 | Yurtsever, Alp, and Vu, Bang Cong, and Cevher, Volkan. 358 | "Stochastic Three-Composite Convex Minimization" NeurIPS 2016 359 | """ 360 | 361 | name = "S3CM" 362 | POSSIBLE_NORMALIZATIONS = {"none", "L2", "Linf", "sign"} 363 | 364 | def __init__( 365 | self, params, prox1=None, prox2=None, lr=0.1, normalization="none" 366 | ): 367 | if not type(lr) == float: 368 | raise ValueError("lr must be a float.") 369 | 370 | self.lr = lr 371 | if normalization in self.POSSIBLE_NORMALIZATIONS: 372 | self.normalization = normalization 373 | else: 374 | raise ValueError( 375 | f"Normalization must be in {self.POSSIBLE_NORMALIZATIONS}" 376 | ) 377 | 378 | if prox1 is None: 379 | prox1 = [None] * len(params) 380 | if prox2 is None: 381 | prox2 = [None] * len(params) 382 | 383 | self.prox1 = [] 384 | self.prox2 = [] 385 | 386 | for prox1_, prox2_ in zip(prox1, prox2): 387 | if prox1_ is None: 388 | 389 | def prox1_(x, s=None): 390 | return x 391 | 392 | if prox2_ is None: 393 | 394 | def prox2_(x, s=None): 395 | return x 396 | 397 | self.prox1.append( 398 | lambda x, s=None: prox1_(x.unsqueeze(0), s).squeeze(dim=0) 399 | ) 400 | self.prox2.append( 401 | lambda x, s=None: prox2_(x.unsqueeze(0), s).squeeze(dim=0) 402 | ) 403 | 404 | defaults = dict( 405 | lr=self.lr, 406 | prox1=self.prox1, 407 | prox2=self.prox2, 408 | normalization=self.normalization, 409 | ) 410 | super(S3CM, self).__init__(params, defaults) 411 | 412 | @torch.no_grad() 413 | def step(self, closure=None): 414 | loss = None 415 | if closure is not None: 416 | with torch.enable_grad(): 417 | loss = closure() 418 | idx = 0 419 | for group in self.param_groups: 420 | for p in group["params"]: 421 | if p.grad is None: 422 | continue 423 | grad = p.grad 424 | 425 | grad = normalize_gradient(grad, self.normalization) 426 | 427 | if grad.is_sparse: 428 | raise RuntimeError( 429 | "S3CM does not yet support sparse gradients." 430 | ) 431 | state = self.state[p] 432 | # initialization 433 | if len(state) == 0: 434 | state["step"] = 0 435 | state["iterate_1"] = p.clone().detach() 436 | state["iterate_2"] = self.prox2[idx](p, self.lr) 437 | state["dual"] = ( 438 | state["iterate_1"] - state["iterate_2"] 439 | ) / self.lr 440 | 441 | state["iterate_2"] = self.prox2[idx]( 442 | state["iterate_1"] + self.lr * state["dual"], self.lr 443 | ) 444 | state["dual"].add_( 445 | (state["iterate_1"] - state["iterate_2"]) / self.lr 446 | ) 447 | state["iterate_1"] = self.prox1[idx]( 448 | state["iterate_2"] - self.lr * (grad + state["dual"]), 449 | self.lr, 450 | ) 451 | 452 | p.copy_(state["iterate_2"]) 453 | idx += 1 454 | 455 | 456 | class PairwiseFrankWolfe(Optimizer): 457 | """Pairwise Frank-Wolfe algorithm""" 458 | 459 | name = "Pairwise-FW" 460 | 461 | def __init__(self, params, lmo_pairwise, lr=0.1, momentum=0.9): 462 | if not (type(lr) == float or lr == "sublinear"): 463 | raise ValueError("lr must be float or 'sublinear'.") 464 | 465 | def _lmo(u, x): 466 | update_direction, max_step_size = lmo_pairwise( 467 | u.unsqueeze(0), x.unsqueeze(0) 468 | ) 469 | return update_direction.squeeze(dim=0), max_step_size 470 | 471 | self.lmo = _lmo 472 | self.lr = lr 473 | self.momentum = momentum 474 | defaults = dict( 475 | lmo=self.lmo, name=self.name, lr=self.lr, momentum=self.momentum 476 | ) 477 | super(PairwiseFrankWolfe, self).__init__(params, defaults) 478 | 479 | raise NotImplementedError 480 | 481 | 482 | class FrankWolfe(Optimizer): 483 | """Class for the Stochastic Frank-Wolfe algorithm given in Mokhtari et al. 484 | This is essentially Frank-Wolfe with Momentum. 485 | We use the tricks from [1] for gradient normalization. 486 | 487 | Args: 488 | params: [torch.Tensor] 489 | Parameters to optimize over. 490 | 491 | lmo: [callable] 492 | List of LMO operators. 493 | 494 | lr: float 495 | Learning rate 496 | 497 | momentum: float in [0, 1] 498 | Amount of momentum to be used in gradient estimator 499 | 500 | weight_decay: float > 0 501 | Amount of L2 regularization to be added 502 | 503 | normalization: str in {'gradient', 'none'} 504 | Gradient normalization to be used. 'gradient' option is described in [1]. 505 | 506 | References: 507 | Pokutta, Sebastian, and Spiegel, Christoph and Zimmer, Max, 508 | Deep Neural Network Training with Frank Wolfe. 2020. 509 | """ 510 | 511 | name = "Frank-Wolfe" 512 | POSSIBLE_NORMALIZATIONS = {"gradient", "none"} 513 | 514 | def __init__( 515 | self, 516 | params, 517 | lmo, 518 | lr=0.1, 519 | momentum=0.9, 520 | weight_decay=0.0, 521 | normalization="none", 522 | track_active_set=False, 523 | ): 524 | 525 | self.track_active_set = track_active_set 526 | 527 | lmo_candidates = [] 528 | for oracle in lmo: 529 | if oracle is None: 530 | # Then FW will not be used on this parameter 531 | _lmo = None 532 | else: 533 | 534 | def _lmo(u, x): 535 | update_direction, max_step_size = oracle( 536 | u.unsqueeze(0), x.unsqueeze(0) 537 | ) 538 | return update_direction.squeeze(dim=0), max_step_size 539 | 540 | lmo_candidates.append(_lmo) 541 | 542 | self.lmo = [] 543 | useable_params = [] 544 | for param, oracle in zip(params, lmo): 545 | if oracle: 546 | useable_params.append(param) 547 | self.lmo.append(oracle) 548 | else: 549 | msg = ( 550 | f"No LMO was provided for parameter {param}. " 551 | f"Frank-Wolfe will not optimize this parameter. " 552 | f"Please use another optimizer." 553 | ) 554 | warnings.warn(msg) 555 | 556 | if type(lr) == float: 557 | if not (0.0 < lr <= 1.0): 558 | raise ValueError("lr must be in (0., 1.].") 559 | self.lr = lr 560 | if type(momentum) == float: 561 | if not (0.0 <= momentum <= 1.0): 562 | raise ValueError("Momentum must be in [0., 1.].") 563 | self.momentum = momentum 564 | if not (weight_decay >= 0): 565 | raise ValueError("weight_decay should be nonnegative.") 566 | self.weight_decay = weight_decay 567 | if normalization not in self.POSSIBLE_NORMALIZATIONS: 568 | raise ValueError( 569 | f"Normalization must be in {self.POSSIBLE_NORMALIZATIONS}." 570 | ) 571 | self.normalization = normalization 572 | defaults = dict( 573 | lmo=self.lmo, 574 | name=self.name, 575 | lr=self.lr, 576 | momentum=self.momentum, 577 | weight_decay=weight_decay, 578 | normalization=self.normalization, 579 | ) 580 | super(FrankWolfe, self).__init__(useable_params, defaults) 581 | 582 | @property 583 | @torch.no_grad() 584 | def certificate(self): 585 | """A generator over the current convergence certificate estimate 586 | for each optimized parameter.""" 587 | for group in self.param_groups: 588 | for p in group["params"]: 589 | state = self.state[p] 590 | yield state["certificate"] 591 | 592 | @torch.no_grad() 593 | def step(self, closure=None): 594 | """Performs a single optimization step. 595 | Arguments: 596 | closure (callable, optional): A closure that reevaluates the model 597 | and returns the loss 598 | """ 599 | loss = None 600 | if closure is not None: 601 | with torch.enable_grad(): 602 | loss = closure() 603 | idx = 0 604 | for group in self.param_groups: 605 | for p in group["params"]: 606 | if p.grad is None: 607 | continue 608 | grad = p.grad + self.weight_decay * p 609 | if grad.is_sparse: 610 | raise RuntimeError( 611 | "SFW does not yet support sparse gradients." 612 | ) 613 | state = self.state[p] 614 | if len(state) == 0: 615 | state["step"] = 0 616 | state["grad_estimate"] = torch.zeros_like( 617 | p, memory_format=torch.preserve_format 618 | ) 619 | if self.track_active_set: 620 | state["active_set"] = collections.defaultdict(float) 621 | state["active_set"][p.detach().clone()] = 1.0 622 | 623 | if self.lr == "sublinear": 624 | step_size = 1.0 / (state["step"] + 1.0) 625 | elif type(self.lr) == float: 626 | step_size = self.lr 627 | else: 628 | raise ValueError("lr must be float or 'sublinear'.") 629 | 630 | if self.momentum is None: 631 | rho = (1.0 / (state["step"] + 1)) ** (1 / 3) 632 | momentum = 1.0 - rho 633 | else: 634 | momentum = self.momentum 635 | 636 | state["step"] += 1.0 637 | 638 | state["grad_estimate"].add_( 639 | grad - state["grad_estimate"], alpha=1.0 - momentum 640 | ) 641 | update_direction, _ = self.lmo[idx](-state["grad_estimate"], p) 642 | state["certificate"] = ( 643 | -state["grad_estimate"] * update_direction 644 | ).sum() 645 | if self.normalization == "gradient": 646 | grad_norm = torch.norm(state["grad_estimate"]) 647 | step_size = min( 648 | 1.0, 649 | step_size 650 | * grad_norm 651 | / torch.linalg.norm(update_direction), 652 | ) 653 | elif self.normalization == "none": 654 | pass 655 | p.add_(step_size * update_direction) 656 | 657 | # Update active set 658 | if self.track_active_set: 659 | active_set = state["active_set"] 660 | curr_atom = update_direction + p 661 | for vertex in list(active_set): 662 | active_set[vertex] *= 1 - step_size 663 | if active_set[vertex] < 1e-7: 664 | del active_set[vertex] 665 | active_set[curr_atom] += step_size 666 | state["active_set"] = active_set 667 | idx += 1 668 | return loss 669 | -------------------------------------------------------------------------------- /chop/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from . import logging 3 | from . import data 4 | from . import image 5 | -------------------------------------------------------------------------------- /chop/utils/image.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image utilities. 3 | ================ 4 | Utility functions for image manipulation and visualization. 5 | """ 6 | 7 | from itertools import product 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | 11 | 12 | def matplotlib_imshow(img, one_channel=False, ax=None, **kwargs): 13 | if ax is None: 14 | ax = plt 15 | if one_channel: 16 | img = img.mean(dim=0) 17 | npimg = img.detach().cpu().numpy() 18 | if one_channel: 19 | ax.imshow(npimg, cmap="gray", **kwargs) 20 | else: 21 | ax.imshow(np.transpose(npimg, (1, 2, 0)), **kwargs) 22 | 23 | 24 | def matplotlib_imshow_batch(batch, labels=None, one_channel=False, axes=None, normalize=False, range=(0., 1.), 25 | title="", negative=False, **kwargs): 26 | npimgs = [img.detach().cpu().numpy() for img in batch] 27 | if labels is None: 28 | labels = [""] * batch.size(0) 29 | axes[0].set_title(title) 30 | for ax, img, label in zip(axes, npimgs, labels): 31 | if one_channel: 32 | if normalize: 33 | img = normalize_image(img, range, True, negative) 34 | ax.imshow(img, cmap='gray', **kwargs) 35 | else: 36 | img = np.transpose(img, (1, 2, 0)) 37 | if normalize: 38 | img = normalize_image(img, range, False, negative) 39 | ax.imshow(img, **kwargs) 40 | 41 | ax.set_ylabel(label) 42 | ax.set_xticks([]) 43 | ax.set_yticks([]) 44 | # ax.axis('off') 45 | 46 | 47 | 48 | def normalize_image(img, range=(0., 1.), one_channel=False, negative=False): 49 | """ 50 | Linearly normalizes an image to be in range. 51 | 52 | Supposes that img is in numpy image shape: (m, n, channels) or (m, n) 53 | """ 54 | new_min, new_max = range 55 | old_min, old_max = img.min(), img.max() 56 | new = new_min + (img - old_min) * (new_max - new_min) / (old_max - old_min) 57 | if negative: 58 | new = 1. - new 59 | return new 60 | 61 | 62 | def group_patches(x_patch_size=8, y_patch_size=8, x_image_size=32, y_image_size=32, n_channels=3): 63 | groups = [] 64 | for m in range(int(x_image_size / x_patch_size)): 65 | for p in range(int(y_image_size / y_patch_size)): 66 | groups.append([(c, m * x_patch_size + i, p * y_patch_size + j) 67 | for c, i, j in product(range(n_channels), 68 | range(x_patch_size), 69 | range(y_patch_size))]) 70 | return groups 71 | 72 | -------------------------------------------------------------------------------- /chop/utils/logging.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logging utilities. 3 | ================== 4 | Utility functions and objects for logging. 5 | """ 6 | 7 | from datetime import datetime 8 | 9 | 10 | class Trace: 11 | """Trace callback""" 12 | 13 | def __init__(self, closure=None, log_x=True, log_grad=False, freq=1, callable=None): 14 | self.freq = int(freq) 15 | self.log_iterates = log_x 16 | self.closure = closure 17 | 18 | self.trace_x = [] 19 | self.trace_time = [] 20 | self.trace_step_size = [] 21 | if log_grad: 22 | self.trace_grad = [] 23 | if callable is not None: 24 | self.callable = callable 25 | self.trace_callable = [] 26 | self.trace_f = [] 27 | self.start = datetime.now() 28 | self._counter = 0 29 | 30 | def __call__(self, kwargs): 31 | if self.closure is None: 32 | self.closure = kwargs['closure'] 33 | 34 | if self._counter % self.freq == 0: 35 | self.trace_x.append(kwargs['x'].detach().clone().data) 36 | self.trace_f.append(self.closure(kwargs['x'], return_jac=False).clone().data) 37 | try: 38 | self.trace_callable.append(self.callable(kwargs)) 39 | except AttributeError: 40 | pass 41 | 42 | try: 43 | self.trace_grad.append() 44 | except AttributeError: 45 | pass 46 | 47 | delta = (datetime.now() - self.start).total_seconds() 48 | self.trace_time.append(delta) 49 | self.trace_step_size.append(kwargs['step_size']) 50 | 51 | self._counter += 1 52 | -------------------------------------------------------------------------------- /chop/utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | General utility functions. 3 | ========================= 4 | 5 | """ 6 | 7 | from functools import wraps 8 | import torch 9 | 10 | 11 | def get_func_and_jac(func, x, *args, **kwargs): 12 | """Computes the jacobian of a batch-wise separable function func of x. 13 | func returns a torch.Tensor of shape (batch_size,) when 14 | x is a torch.Tensor of shape (batch_size, *). 15 | Adapted from 16 | https://gist.github.com/sbarratt/37356c46ad1350d4c30aefbd488a4faa 17 | by Shane Baratt""" 18 | 19 | batch_size = x.size(0) 20 | if x.is_leaf: 21 | x.requires_grad_(True) 22 | else: 23 | x.retain_grad() 24 | output = func(x, *args, **kwargs) 25 | if output.dim() == 0: 26 | output = output.unsqueeze(0) 27 | output.backward(torch.ones(batch_size, device=x.device)) 28 | return output.data, x.grad.data 29 | 30 | 31 | def closure(f): 32 | @wraps(f) 33 | def wrapper(x, return_jac=True, *args, **kwargs): 34 | """Adds jacobian computation when calling function. 35 | When return_jac is True, returns (value, jacobian) 36 | instead of just value.""" 37 | x.requires_grad_(True) 38 | if not return_jac: 39 | val = f(x, *args, **kwargs) 40 | if val.ndim == 0: 41 | val = torch.tensor([val], device=val.device) 42 | return val 43 | 44 | # Reset gradients 45 | x.grad = None 46 | return get_func_and_jac(f, x, *args, **kwargs) 47 | 48 | return wrapper 49 | 50 | 51 | def init_lipschitz(closure, x0, L0=1e-3, n_it=100): 52 | """Estimates the Lipschitz constant of closure 53 | for each datapoint in the batch using backtracking line-search. 54 | 55 | Args: 56 | closure: callable 57 | returns func_val, jacobian 58 | 59 | x0: torch.tensor of shape (batch_size, *) 60 | 61 | L0: float 62 | initial guess 63 | 64 | n_it: int 65 | number of iterations 66 | 67 | Returns: 68 | Lt: torch.tensor of shape (batch_size,) 69 | """ 70 | 71 | Lt = L0 * torch.ones(x0.size(0), device=x0.device, dtype=x0.dtype) 72 | 73 | f0, grad = closure(x0) 74 | xt = x0 - bmul((1. / Lt), grad) 75 | 76 | ft = closure(xt, return_jac=False) 77 | 78 | for _ in range(n_it): 79 | mask = (ft > f0) 80 | Lt[mask] *= 10. 81 | xt = x0 - bmul(1. / Lt, grad) 82 | ft = closure(xt, return_jac=False) 83 | 84 | if not mask.any(): 85 | break 86 | return Lt 87 | 88 | 89 | def bdot(tensor, other): 90 | """Returns the batch-wise dot product between tensor and other. 91 | Supposes that the shapes are (batch_size, *). 92 | This includes matrix inner products.""" 93 | 94 | t1 = tensor.view(tensor.size(0), -1) 95 | t2 = other.view(other.size(0), -1) 96 | return (t1 * t2).sum(dim=-1) 97 | 98 | 99 | def bmul(tensor, other): 100 | """Batch multiplies tensor and other""" 101 | return torch.mul(tensor.T, other.T).T 102 | 103 | 104 | def bdiv(tensor, other): 105 | """Batch divides tensor by other""" 106 | return bmul(tensor, 1. / other) 107 | 108 | 109 | def bnorm(tensor, *args, **kwargs): 110 | """Batch vector norms for tensor""" 111 | batch_size = tensor.size(0) 112 | return torch.linalg.norm(tensor.reshape(batch_size, -1), dim=-1, *args, **kwargs) 113 | 114 | 115 | def bmm(tensor, other): 116 | *batch_dims, m, n = tensor.shape 117 | *_, n2, p = other.shape 118 | if n2 != n: 119 | raise ValueError(f"Make sure shapes are compatible. Got " 120 | f"{tensor.shape}, {other.shape}.") 121 | t1 = tensor.view(-1, m, n) 122 | t2 = other.view(-1, n, p) 123 | return torch.bmm(t1, t2).view(*batch_dims, m, p) 124 | 125 | 126 | def bmv(tensor, vector): 127 | return bmm(tensor, vector.unsqueeze(-1)).squeeze(-1) 128 | 129 | 130 | # TODO: tolerance parameter 131 | def power_iteration(mat, n_iter: int=10, tol: float=1e-6): 132 | """ 133 | Obtains the largest singular value of a matrix, batch wise, 134 | and the associated left and right singular vectors. 135 | 136 | Args: 137 | mat: torch.Tensor of shape (*, M, N) 138 | n_iter: int 139 | number of iterations to perform 140 | tol: float 141 | Tolerance. Not used for now. 142 | """ 143 | if n_iter < 1 or type(n_iter) != int: 144 | raise ValueError("n_iter must be a positive integer.") 145 | *batch_shapes, m, n = mat.shape 146 | matT = torch.transpose(mat, -1, -2) 147 | vec_shape = (*batch_shapes, n) 148 | # Choose a random vector 149 | # to decrease the chance that our 150 | # initial right vector 151 | # is orthogonal to the first singular vector 152 | v_k = torch.normal(torch.zeros(vec_shape, device=mat.device), 153 | torch.ones(vec_shape, device=mat.device)) 154 | 155 | for _ in range(n_iter): 156 | u_k = bmv(mat, v_k) 157 | # get singular value 158 | sigma_k = torch.norm(u_k.view(-1, m), dim=-1).view(*batch_shapes) 159 | # normalize u 160 | u_k = bdiv(u_k, sigma_k) 161 | 162 | v_k = bmv(matT, u_k) 163 | norm_vk = torch.norm(v_k.view(-1, n), dim=-1).view(*batch_shapes) 164 | 165 | # normalize v 166 | v_k = bmul(v_k, 1. / norm_vk) 167 | return u_k, sigma_k, v_k 168 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | upload: 23 | gsutil -m rsync -r _build/html gs://openo.pt/chop 24 | 25 | clean: 26 | rm -rf $(BUILDDIR)/* 27 | rm -rf auto_examples/ -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | import os 17 | import sys 18 | sys.path.insert(0, os.path.abspath("..")) 19 | sys.path.insert(0, os.path.abspath("sphinx_ext")) 20 | 21 | # TODO: perhaps replace with https://github.com/westurner/sphinxcontrib-srclinks 22 | from github_link import make_linkcode_resolve 23 | 24 | 25 | # -- Project information ----------------------------------------------------- 26 | 27 | project = 'chop' 28 | copyright = '2020, chop developers' 29 | author = 'chop developers' 30 | 31 | 32 | # -- General configuration --------------------------------------------------- 33 | 34 | # Add any Sphinx extension module names here, as strings. They can be 35 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 36 | # ones. 37 | extensions = [ 38 | "sphinx.ext.autosummary", 39 | "sphinx.ext.intersphinx", 40 | "sphinx.ext.todo", 41 | "sphinx.ext.napoleon", 42 | "sphinx.ext.ifconfig", 43 | "sphinx.ext.mathjax", 44 | "sphinx.ext.linkcode", 45 | "sphinx_gallery.gen_gallery", 46 | ] 47 | 48 | 49 | sphinx_gallery_conf = { 50 | # path to your examples scripts 51 | "examples_dirs": "../examples", 52 | "doc_module": "chop", 53 | # path where to save gallery generated examples 54 | "gallery_dirs": "auto_examples", 55 | "backreferences_dir": os.path.join("modules", "generated"), 56 | "show_memory": True, 57 | "reference_url": {"chop": None}, 58 | } 59 | 60 | 61 | mathjax_config = { 62 | "TeX": { 63 | "Macros": { 64 | "argmin": "\\DeclareMathOperator*{\\argmin}{\\mathbf{arg\\,min}}", 65 | "argmax": "\\DeclareMathOperator*{\\argmin}{\\mathbf{arg\\,max}}", 66 | "bs": "\\newcommand{\\bs}[1]{\\boldsymbol{#1}}", 67 | }, 68 | }, 69 | "tex2jax" : { 70 | "inlineMath": [['$', '$'], ['\(', '\)']], 71 | } 72 | } 73 | 74 | # The following is used by sphinx.ext.linkcode to provide links to github 75 | linkcode_resolve = make_linkcode_resolve( 76 | "copt", 77 | u"https://github.com/openopt/" "copt/blob/{revision}/" "{package}/{path}#L{lineno}", 78 | ) 79 | 80 | 81 | autosummary_generate = True 82 | autodoc_default_options = {"members": True, "inherited-members": True} 83 | 84 | 85 | # Add any paths that contain templates here, relative to this directory. 86 | templates_path = ['_templates'] 87 | 88 | # List of patterns, relative to source directory, that match files and 89 | # directories to ignore when looking for source files. 90 | # This pattern also affects html_static_path and html_extra_path. 91 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 92 | 93 | 94 | # -- Options for HTML output ------------------------------------------------- 95 | 96 | # The theme to use for HTML and HTML Help pages. See the documentation for 97 | # a list of builtin themes. 98 | # 99 | html_theme = 'alabaster' 100 | 101 | # Add any paths that contain custom static files (such as style sheets) here, 102 | # relative to this directory. They are copied after the builtin static files, 103 | # so a file named "default.css" will overwrite the builtin "default.css". 104 | html_static_path = ['_static'] -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | 2 | Welcome to chop's documentation! 3 | ================================ 4 | 5 | .. currentmodule:: chop 6 | 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | optim 12 | stochastic 13 | constraints 14 | adversary 15 | penalties 16 | utils 17 | 18 | 19 | .. toctree:: 20 | :maxdepth: 2 21 | :caption: Contents: 22 | 23 | 24 | Where to go from here? 25 | ---------------------- 26 | 27 | To know more about chop, check out our :ref:`example gallery ` or browse through the module reference using the left navigation bar. 28 | 29 | 30 | .. toctree:: 31 | :maxdepth: 2 32 | :hidden: 33 | 34 | auto_examples/index 35 | 36 | Last change: |today| 37 | 38 | Indices and tables 39 | ================== 40 | 41 | * :ref:`genindex` 42 | * :ref:`modindex` 43 | * :ref:`search` 44 | -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /doc/sphinx_ext/github_link.py: -------------------------------------------------------------------------------- 1 | from operator import attrgetter 2 | import inspect 3 | import subprocess 4 | import os 5 | import sys 6 | from functools import partial 7 | 8 | REVISION_CMD = 'git rev-parse --short HEAD' 9 | 10 | 11 | def _get_git_revision(): 12 | try: 13 | revision = subprocess.check_output(REVISION_CMD.split()).strip() 14 | except (subprocess.CalledProcessError, OSError): 15 | print('Failed to execute git to get revision') 16 | return None 17 | return revision.decode('utf-8') 18 | 19 | 20 | def _linkcode_resolve(domain, info, package, url_fmt, revision): 21 | """Determine a link to online source for a class/method/function 22 | 23 | This is called by sphinx.ext.linkcode 24 | 25 | An example with a long-untouched module that everyone has 26 | >>> _linkcode_resolve('py', {'module': 'tty', 27 | ... 'fullname': 'setraw'}, 28 | ... package='tty', 29 | ... url_fmt='http://hg.python.org/cpython/file/' 30 | ... '{revision}/Lib/{package}/{path}#L{lineno}', 31 | ... revision='xxxx') 32 | 'http://hg.python.org/cpython/file/xxxx/Lib/tty/tty.py#L18' 33 | """ 34 | 35 | if revision is None: 36 | return 37 | if domain not in ('py', 'pyx'): 38 | return 39 | if not info.get('module') or not info.get('fullname'): 40 | return 41 | 42 | class_name = info['fullname'].split('.')[0] 43 | if type(class_name) != str: 44 | # Python 2 only 45 | class_name = class_name.encode('utf-8') 46 | module = __import__(info['module'], fromlist=[class_name]) 47 | obj = attrgetter(info['fullname'])(module) 48 | 49 | try: 50 | fn = inspect.getsourcefile(obj) 51 | except Exception: 52 | fn = None 53 | if not fn: 54 | try: 55 | fn = inspect.getsourcefile(sys.modules[obj.__module__]) 56 | except Exception: 57 | fn = None 58 | if not fn: 59 | return 60 | 61 | fn = os.path.relpath(fn, 62 | start=os.path.dirname(__import__(package).__file__)) 63 | try: 64 | lineno = inspect.getsourcelines(obj)[1] 65 | except Exception: 66 | lineno = '' 67 | return url_fmt.format(revision=revision, package=package, 68 | path=fn, lineno=lineno) 69 | 70 | 71 | def make_linkcode_resolve(package, url_fmt): 72 | """Returns a linkcode_resolve function for the given URL format 73 | 74 | revision is a git commit reference (hash or name) 75 | 76 | package is the name of the root module of the package 77 | 78 | url_fmt is along the lines of ('https://github.com/USER/PROJECT/' 79 | 'blob/{revision}/{package}/' 80 | '{path}#L{lineno}') 81 | """ 82 | revision = _get_git_revision() 83 | return partial(_linkcode_resolve, revision=revision, package=package, 84 | url_fmt=url_fmt) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: chop 2 | channels: 3 | - pytorch 4 | dependencies: 5 | - python=3 6 | - numpy 7 | - pandas 8 | - pytorch 9 | - cudatoolkit 10 | - torchvision 11 | - pip 12 | - pip: 13 | - tables 14 | - dill 15 | - easydict -------------------------------------------------------------------------------- /examples/README.txt: -------------------------------------------------------------------------------- 1 | .. _general_examples: 2 | 3 | Example Gallery 4 | =============== 5 | 6 | Miscellaneous examples 7 | ---------------------- 8 | 9 | Miscellaneous and introductory examples for chop. -------------------------------------------------------------------------------- /examples/adversarial_robustness/README.txt: -------------------------------------------------------------------------------- 1 | .. _adversarial_examples: 2 | 3 | 4 | Adversarial Robustness 5 | ---------------------- 6 | 7 | Examples on adversarial robustness. -------------------------------------------------------------------------------- /examples/adversarial_robustness/attack_benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | Benchmark of attacks. 3 | ======================== 4 | """ 5 | import torch 6 | from tqdm import tqdm 7 | 8 | import chop 9 | from chop.optim import minimize_frank_wolfe, minimize_pgd, minimize_pgd_madry, minimize_three_split 10 | from chop.data import load_cifar10 11 | from chop.adversary import Adversary 12 | 13 | from robustbench.utils import load_model 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | batch_size = 50 19 | n_examples = 10000 20 | loaders = load_cifar10(test_batch_size=batch_size, data_dir='~/datasets') 21 | loader = loaders.test 22 | 23 | model_name = 'Engstrom2019Robustness' 24 | model = load_model(model_name, norm='Linf').to(device) 25 | criterion = torch.nn.CrossEntropyLoss(reduction='none') 26 | 27 | # Define the perturbation constraint set 28 | max_iter = 20 29 | alpha = 8 / 255. 30 | constraint = chop.constraints.LinfBall(alpha) 31 | 32 | 33 | print(f"Evaluating model {model_name} on L{constraint.p} ball({alpha}).") 34 | 35 | n_correct = 0 36 | n_correct_adv_pgd_madry = 0 37 | n_correct_adv_pgd = 0 38 | n_correct_adv_split = 0 39 | n_correct_adv_fw = 0 40 | 41 | adversary_pgd = Adversary(minimize_pgd) 42 | adversary_pgd_madry = Adversary(minimize_pgd_madry) 43 | adversary_split = Adversary(minimize_three_split) 44 | adversary_fw = Adversary(minimize_frank_wolfe) 45 | 46 | for k, (data, target) in tqdm(enumerate(loader), total=len(loader)): 47 | data = data.to(device) 48 | target = target.to(device) 49 | 50 | def image_constraint_prox(delta, step_size=None): 51 | """Projects perturbation delta 52 | so that 0. <= data + delta <= 1.""" 53 | 54 | adv_img = torch.clamp(data + delta, 0, 1) 55 | delta = adv_img - data 56 | return delta 57 | 58 | def prox(delta, step_size=None): 59 | delta = constraint.prox(delta, step_size) 60 | delta = image_constraint_prox(delta, step_size) 61 | return delta 62 | 63 | 64 | _, delta_pgd = adversary_pgd.perturb(data, target, model, criterion, 65 | use_best=True, 66 | step='backtracking', 67 | prox=prox, 68 | max_iter=max_iter) 69 | 70 | delta_pgd_madry = torch.zeros_like(data) 71 | # _, delta_pgd_madry = adversary_pgd_madry.perturb(data, target, model, 72 | # criterion, 73 | # use_best=False, 74 | # prox=prox, 75 | # lmo=constraint.lmo, 76 | # step=2. / max_iter, 77 | # max_iter=max_iter) 78 | 79 | delta_split = torch.zeros_like(data) 80 | # _, delta_split = adversary_split.perturb(data, target, model, 81 | # criterion, 82 | # use_best=False, 83 | # prox1=constraint.prox, 84 | # prox2=image_constraint_prox, 85 | # max_iter=max_iter) 86 | 87 | delta_fw = torch.zeros_like(data) 88 | # _, delta_fw = adversary_fw.perturb(data, target, model, criterion, 89 | # lmo=constraint.lmo, 90 | # step='sublinear', 91 | # max_iter=max_iter 92 | # ) 93 | 94 | label = torch.argmax(model(data), dim=-1) 95 | n_correct += (label == target).sum().item() 96 | 97 | adv_label_pgd_madry = torch.argmax(model(data + delta_pgd_madry), dim=-1) 98 | n_correct_adv_pgd_madry += (adv_label_pgd_madry == target).sum().item() 99 | 100 | adv_label_pgd = torch.argmax(model(data + delta_pgd), dim=-1) 101 | n_correct_adv_pgd += (adv_label_pgd == target).sum().item() 102 | 103 | adv_label_split = torch.argmax(model(data + delta_split), dim=-1) 104 | n_correct_adv_split += (adv_label_split == target).sum().item() 105 | 106 | adv_label_fw = torch.argmax(model(data + delta_fw), dim=-1) 107 | n_correct_adv_fw += (adv_label_fw == target).sum().item() 108 | 109 | 110 | accuracy = n_correct / n_examples 111 | accuracy_adv_pgd_madry = n_correct_adv_pgd_madry / n_examples 112 | accuracy_adv_pgd = n_correct_adv_pgd / n_examples 113 | accuracy_adv_split = n_correct_adv_split / n_examples 114 | accuracy_adv_fw = n_correct_adv_fw / n_examples 115 | 116 | print(f"Accuracy: {accuracy:.4f}") 117 | print(f"RobustAccuracy PGD Madry: {accuracy_adv_pgd_madry:.4f}") 118 | print(f"RobustAccuracy PGD: {accuracy_adv_pgd:.4f}") 119 | print(f"RobustAccuracy Splitting: {accuracy_adv_split:.4f}") 120 | print(f"RobustAccuracy FW: {accuracy_adv_fw:.4f}") 121 | -------------------------------------------------------------------------------- /examples/adversarial_robustness/plot_train_robust_cifar10.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of robust training on CIFAR10. 3 | ========================================= 4 | """ 5 | import matplotlib.pyplot as plt 6 | from chop.adversary import Adversary 7 | import torch 8 | from tqdm import tqdm 9 | from easydict import EasyDict 10 | 11 | import chop 12 | 13 | from torch.optim import SGD 14 | 15 | from torchvision import models 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | n_epochs = 100 20 | batch_size = 128 21 | batch_size_test = 100 22 | 23 | dataset = chop.utils.data.CIFAR10('~/datasets') 24 | loaders = dataset.loaders(batch_size, batch_size_test) 25 | 26 | n_train = len(loaders.train.dataset) 27 | n_test = len(loaders.test.dataset) 28 | 29 | model = models.resnet18(pretrained=False) 30 | model.to(device) 31 | 32 | criterion = torch.nn.CrossEntropyLoss() 33 | 34 | optimizer = SGD(model.parameters(), lr=.1, momentum=.9, weight_decay=5e-4) 35 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs) 36 | 37 | # Define the perturbation constraint set 38 | max_iter_train = 7 39 | max_iter_test = 20 40 | alpha = 8. / 255 41 | constraint = chop.constraints.LinfBall(alpha) 42 | criterion_adv = torch.nn.CrossEntropyLoss(reduction='none') 43 | 44 | print(f"Training on L{constraint.p} ball({alpha}).") 45 | 46 | 47 | adversary = Adversary(chop.optim.minimize_pgd_madry) 48 | 49 | results = EasyDict(train_acc=[], test_acc=[], 50 | train_acc_adv=[], test_acc_adv=[], 51 | train_adv_loss=[], 52 | test_adv_loss=[]) 53 | 54 | for _ in range(n_epochs): 55 | 56 | # Train 57 | n_correct = 0 58 | n_correct_adv = 0 59 | 60 | model.train() 61 | 62 | for k, (data, target) in tqdm(enumerate(loaders.train)): 63 | data = data.to(device) 64 | target = target.to(device) 65 | 66 | @torch.no_grad() 67 | def image_constraint_prox(delta, step_size=None): 68 | """Projects perturbation delta 69 | so that 0. <= data + delta <= 1.""" 70 | 71 | adv_img = torch.clamp(data + delta, 0, 1) 72 | delta = adv_img - data 73 | return delta 74 | 75 | @torch.no_grad() 76 | def prox(delta, step_size=None): 77 | delta = constraint.prox(delta, step_size) 78 | delta = image_constraint_prox(delta, step_size) 79 | return delta 80 | 81 | _, delta = adversary.perturb(data, target, model, 82 | criterion_adv, 83 | prox=prox, 84 | lmo=constraint.lmo, 85 | step=2. / max_iter_train, 86 | max_iter=max_iter_train) 87 | 88 | optimizer.zero_grad() 89 | 90 | output = model(data) 91 | output_adv = model(data + delta) 92 | loss = criterion(output, target) 93 | loss.backward() 94 | 95 | optimizer.step() 96 | 97 | pred = torch.argmax(output, dim=-1) 98 | pred_adv = torch.argmax(output_adv, dim=-1) 99 | 100 | n_correct += (pred == target).sum().item() 101 | n_correct_adv += (pred_adv == target).sum().item() 102 | 103 | results.train_acc.append(100. * n_correct / n_train) 104 | results.train_acc_adv.append(100. * n_correct_adv / n_train) 105 | print(f"Train Accuracy: {results.train_acc[-1] :.1f}%") 106 | print(f"Train Adv Accuracy: {results.train_acc_adv[-1]:.1f}%") 107 | 108 | # Test 109 | n_correct = 0 110 | n_correct_adv = 0 111 | 112 | model.eval() 113 | 114 | for k, (data, target) in tqdm(enumerate(loaders.test)): 115 | data = data.to(device) 116 | target = target.to(device) 117 | 118 | @torch.no_grad() 119 | def image_constraint_prox(delta, step_size=None): 120 | """Projects perturbation delta 121 | so that 0. <= data + delta <= 1.""" 122 | 123 | adv_img = torch.clamp(data + delta, 0, 1) 124 | delta = adv_img - data 125 | return delta 126 | 127 | @torch.no_grad() 128 | def prox(delta, step_size=None): 129 | delta = constraint.prox(delta, step_size) 130 | delta = image_constraint_prox(delta, step_size) 131 | return delta 132 | 133 | _, delta = adversary.perturb(data, target, model, 134 | criterion_adv, 135 | prox=prox, 136 | lmo=constraint.lmo, 137 | step=2. / max_iter_test, 138 | max_iter=max_iter_test) 139 | 140 | with torch.no_grad(): 141 | output = model(data) 142 | output_adv = model(data + delta) 143 | 144 | pred = torch.argmax(output, dim=-1) 145 | pred_adv = torch.argmax(output_adv, dim=-1) 146 | 147 | n_correct += (pred == target).sum().item() 148 | n_correct_adv += (pred_adv == target).sum().item() 149 | 150 | results.test_acc.append(100. * n_correct / n_test) 151 | results.test_acc_adv.append(100. * n_correct_adv / n_test) 152 | 153 | print(f"Test Accuracy: {results.test_acc[-1]:.1f}%") 154 | print(f"Test Adv Accuracy: {results.test_acc_adv[-1]:.1f}%") 155 | 156 | 157 | fig, ax = plt.subplots(nrows=2, sharex=True) 158 | 159 | ax[0].set_title("Clean data accuracies") 160 | ax[0].plot(results.train_acc, label='Train Acc') 161 | ax[0].plot(results.test_acc, label='Test Acc') 162 | ax[1].set_title("Adversarial data accuracies") 163 | ax[1].plot(results.train_acc_adv, label='Train Acc Adv') 164 | ax[1].plot(results.test_acc_adv, label='Test Acc Adv') 165 | plt.legend() 166 | plt.show() 167 | -------------------------------------------------------------------------------- /examples/adversarial_robustness/plot_universal_adversarial_examples.py: -------------------------------------------------------------------------------- 1 | """ 2 | Universal Adversarial Examples 3 | ================================ 4 | 5 | This example shows how to generate and plot universal adversarial examples for 6 | CIFAR-10. 7 | 8 | We solve the following problem: 9 | 10 | ..math: 11 | \max_{\delta \in \mathcal B} \frac{1}{n} \sum_{i=1}^n \ell(h_\theta(x_i + rho(\delta - x_i)), y_i) 12 | """ 13 | 14 | import numpy as np 15 | from tqdm import tqdm 16 | import matplotlib.pyplot as plt 17 | 18 | import torch 19 | from robustbench.utils import load_model 20 | 21 | import chop 22 | from chop.utils.image import matplotlib_imshow 23 | from chop.utils.data import CIFAR10, NormalizingModel 24 | 25 | 26 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | 28 | 29 | data_dir = "~/datasets/" 30 | dataset = CIFAR10(data_dir, normalize=False) 31 | 32 | classes = dataset.classes 33 | 34 | # CIFAR10 model 35 | model = load_model('Standard') # Can be changed to any model from the robustbench model zoo 36 | 37 | # Add an initial layer to normalize data. 38 | # This allows us to use the [0, 1] image constraint set 39 | model = NormalizingModel(model, dataset) 40 | 41 | model = model.to(device) 42 | 43 | # Attack criterion 44 | criterion = torch.nn.CrossEntropyLoss() 45 | 46 | n_epochs = 1 47 | restarts = 5 48 | batch_size = 250 49 | 50 | loaders = dataset.loaders(batch_size, batch_size) 51 | 52 | # Optimize freely over a patch in the image 53 | length = 8 54 | x_start = 12 55 | y_start = 12 56 | rho = torch.zeros(3, 32, 32).to(device) 57 | rho[:, x_start:x_start+length, y_start:y_start+length] += 1. 58 | 59 | model.eval() 60 | 61 | losses = [] 62 | test_acc = [] 63 | test_acc_adv = [] 64 | 65 | 66 | def apply_perturbation(data, delta): 67 | return data + rho * (delta - data) 68 | 69 | 70 | best_loss = -np.inf 71 | 72 | for _ in range(restarts): 73 | # Random initialization 74 | delta = torch.rand(3, 32, 32).to(device) 75 | delta.requires_grad_(True) 76 | delta_opt = chop.stochastic.PGD([delta], 77 | prox=[chop.constraints.Box(0, 1).prox], 78 | lr=.2, normalization='Linf') 79 | 80 | for it in range(n_epochs): 81 | for data, target in tqdm(loaders.train): 82 | data = data.to(device) 83 | target = target.to(device) 84 | 85 | def loss_fun(delta): 86 | adv_data = apply_perturbation(data, delta) 87 | return -criterion(model(adv_data), target) 88 | 89 | delta_opt.zero_grad() 90 | loss_val = loss_fun(delta) 91 | loss_val.backward() 92 | delta_opt.step() 93 | losses.append(-loss_val.item()) 94 | 95 | if -loss_val.item() > best_loss: 96 | best_loss = -loss_val.item() 97 | best_delta = delta.detach().clone() 98 | 99 | correct = 0 100 | correct_adv = 0 101 | n_datapoints = 0 102 | for data, target in tqdm(loaders.test): 103 | n_datapoints += len(data) 104 | data = data.to(device) 105 | target = target.to(device) 106 | 107 | preds = model(data).argmax(1) 108 | adv_image = apply_perturbation(data, best_delta) 109 | preds_adv = model(adv_image).argmax(1) 110 | 111 | correct += (preds == target).sum().item() 112 | correct_adv += (preds_adv == target).sum().item() 113 | 114 | correct /= n_datapoints 115 | correct_adv /= n_datapoints 116 | 117 | print(f"Clean accuracy: {100 * correct:.2f}") 118 | print(f"Best attack accuracy {100 * correct_adv:.2f}") 119 | 120 | plt.plot(losses, label='Training loss') 121 | plt.legend() 122 | plt.show() 123 | 124 | 125 | fig, ax = plt.subplots() 126 | matplotlib_imshow(best_delta) 127 | plt.show() 128 | 129 | 130 | fig, ax = plt.subplots() 131 | data = data.to(device) 132 | pert_image = apply_perturbation(data[0], best_delta) 133 | matplotlib_imshow(pert_image) 134 | plt.show() 135 | -------------------------------------------------------------------------------- /examples/adversarial_robustness/plot_visualizing_adversarial_attacks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualizing Adversarial Examples 3 | ================================ 4 | 5 | This example shows how to generate and plot adversarial examples for a batch of datapoints from CIFAR-10, 6 | and compares the examples from different constraint sets, penalizations and solvers. 7 | 8 | """ 9 | 10 | # TODO: REFACTOR DATASETS FROM OUR DATALOADING UTILITIES 11 | 12 | import torch 13 | import torchvision 14 | # from torchvision import transforms 15 | # from robustbench.data import load_cifar10 16 | # from robustbench.utils import load_model 17 | 18 | import matplotlib.pyplot as plt 19 | 20 | import chop 21 | from chop.utils.image import group_patches, matplotlib_imshow_batch 22 | from chop.utils.data import ImageNet, CIFAR10, NormalizingModel 23 | from chop.utils.logging import Trace 24 | 25 | from sklearn.metrics import f1_score 26 | 27 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 28 | device = torch.device('cpu') 29 | 30 | batch_size = 4 31 | 32 | # Note that this example uses load_cifar10 from the robustbench library 33 | # data, target = load_cifar10(n_examples=batch_size, data_dir='~/datasets') 34 | 35 | data_dir = "/scratch/data/imagenet12/" 36 | dataset = ImageNet(data_dir, normalize=False) 37 | 38 | # data_dir = "~/datasets/" 39 | # dataset = CIFAR10() 40 | 41 | normalize = dataset.normalize 42 | unnormalize = dataset.unnormalize 43 | 44 | data, target = dataset.load_k(batch_size, train=True, device=device, 45 | shuffle=True) 46 | classes = dataset.classes 47 | 48 | # ImageNet model 49 | model = torchvision.models.resnet18(pretrained=True) 50 | 51 | # CIFAR10 model 52 | # model = load_model('Standard') # Can be changed to any model from the robustbench model zoo 53 | 54 | model = model.to(device) 55 | # Attack criterion 56 | criterion = torch.nn.CrossEntropyLoss(reduction='none') 57 | 58 | 59 | # Add first layer normalization 60 | # since the data is not normalized on loading. 61 | model = NormalizingModel(model, dataset) 62 | 63 | # Define the constraint set + initial point 64 | print("L2 norm constraint.") 65 | alpha = 3 66 | constraint = chop.constraints.L2Ball(alpha) 67 | 68 | max_str_length = 15 # for plot: max length of class name 69 | 70 | 71 | def image_constraint_prox(delta, step_size=None): 72 | adv_img = torch.clamp(data + delta, 0, 1) 73 | delta = adv_img - data 74 | return delta 75 | 76 | 77 | # TODO: think about using Dykstra instead of 2 alternate projections 78 | def prox(delta, step_size=None): 79 | """This needs to clip the data renormalized to [0, 1]. 80 | The epsilon scale is w/ regard to this unit box constraint.""" 81 | 82 | delta = constraint.prox(delta, step_size) 83 | delta = image_constraint_prox(delta, step_size) 84 | return delta 85 | 86 | 87 | adversary = chop.Adversary(chop.optim.minimize_pgd_madry) 88 | callback_L2 = Trace() 89 | 90 | # perturb's iterates are delta in [-mean / std, (1 - mean)/ std] 91 | _, delta = adversary.perturb(data, target, model, criterion, 92 | prox=prox, 93 | lmo=constraint.lmo, 94 | max_iter=20, 95 | step=2. / 20, 96 | callback=callback_L2) 97 | 98 | # Plot adversarial images 99 | fig, ax = plt.subplots(ncols=8, nrows=batch_size, figsize=(12, 6)) 100 | 101 | # Plot clean data 102 | matplotlib_imshow_batch(data, labels=[classes[int(k)][:max_str_length] for k in target], axes=ax[:, 0], 103 | title="Ground Truth") 104 | 105 | preds = model(data).argmax(dim=-1) 106 | matplotlib_imshow_batch(data, labels=[classes[int(k)][:max_str_length] for k in preds], axes=ax[:, 1], 107 | title="Prediction") 108 | 109 | # Adversarial Lp images 110 | adv_output = model(data + delta) 111 | adv_labels = torch.argmax(adv_output, dim=-1) 112 | matplotlib_imshow_batch(data + delta, labels=[classes[int(k)][:max_str_length] for k in adv_labels], axes=ax[:, 2], 113 | title=f'L{constraint.p}') 114 | 115 | # Perturbation 116 | matplotlib_imshow_batch(abs(delta), axes=ax[:, 5], normalize=True, 117 | title=f'L{constraint.p}', negative=False) 118 | 119 | print(f"F1 score: {f1_score(target.detach().cpu(), adv_labels.detach().cpu(), average='macro'):.3f}" 120 | f" for alpha={alpha:.4f}") 121 | 122 | 123 | print("GroupL1 constraint.") 124 | 125 | # CIFAR-10 126 | # groups = group_patches(x_patch_size=8, y_patch_size=8, x_image_size=32, y_image_size=32) 127 | 128 | # Imagenet 129 | groups = group_patches(x_patch_size=28, y_patch_size=28, x_image_size=224, y_image_size=224) 130 | 131 | for eps in [5e-2]: 132 | alpha = eps * len(groups) 133 | constraint_group = chop.constraints.GroupL1Ball(alpha, groups) 134 | adversary_group = chop.Adversary(chop.optim.minimize_frank_wolfe) 135 | 136 | # callback_group = Trace(callable=lambda kw: criterion(model(data + kw['x']), target)) 137 | callback_group = Trace() 138 | 139 | _, delta_group = adversary_group.perturb(data, target, model, criterion, 140 | lmo=constraint_group.lmo, 141 | max_iter=20, 142 | callback=callback_group) 143 | 144 | 145 | delta_group = image_constraint_prox(delta_group) 146 | 147 | # Show adversarial examples and perturbations 148 | adv_output_group = model(data + delta_group) 149 | adv_labels_group = torch.argmax(adv_output_group, dim=-1) 150 | 151 | matplotlib_imshow_batch(data + delta_group, labels=(classes[int(k)][:max_str_length] for k in adv_labels_group), 152 | axes=ax[:, 3], 153 | title='Group Lasso') 154 | 155 | matplotlib_imshow_batch(abs(delta_group), axes=ax[:, 6], normalize=True, 156 | # title='Group Lasso', negative=True) 157 | title='Group Lasso', negative=False) 158 | 159 | print(f"F1 score: {f1_score(target.detach().cpu(), adv_labels_group.detach().cpu(), average='macro'):.3f}" 160 | f" for alpha={alpha:.4f}") 161 | 162 | print("Nuclear norm ball adv examples") 163 | 164 | for alpha in [5.]: 165 | constraint_nuc = chop.constraints.NuclearNormBall(alpha) 166 | 167 | def prox_nuc(delta, step_size=None): 168 | delta = constraint_nuc.prox(delta, step_size) 169 | delta = image_constraint_prox(delta, step_size) 170 | return delta 171 | 172 | adversary = chop.Adversary(chop.optim.minimize_frank_wolfe) 173 | callback_nuc = Trace() 174 | 175 | _, delta_nuc = adversary.perturb(data, target, model, criterion, 176 | # prox=prox, 177 | lmo=constraint_nuc.lmo, 178 | max_iter=20, 179 | # step=2. / 20, 180 | callback=callback_nuc) 181 | 182 | # Clamp last iterate to image space 183 | delta_nuc = image_constraint_prox(delta_nuc) 184 | 185 | # Add nuclear examples to plot 186 | adv_output_nuc = model(data + delta_nuc) 187 | adv_labels_nuc = torch.argmax(adv_output_nuc, dim=-1) 188 | 189 | matplotlib_imshow_batch(data + delta_nuc, labels=(classes[int(k)][:max_str_length] for k in adv_labels_nuc), 190 | axes=ax[:, 4], 191 | title='Nuclear Norm') 192 | 193 | matplotlib_imshow_batch(abs(delta_nuc), axes=ax[:, 7], normalize=True, 194 | # title='Nuclear Norm', negative=True) 195 | title='Nuclear Norm', negative=False) 196 | 197 | print(f"F1 score: {f1_score(target.detach().cpu(), adv_labels_nuc.detach().cpu(), average='macro'):.3f}" 198 | f" for alpha={alpha:.4f}") 199 | 200 | plt.subplots_adjust(bottom=.06, top=0.5) 201 | 202 | plt.tight_layout() 203 | plt.show() 204 | plt.savefig('adv_attacks.png') 205 | 206 | 207 | # TODO refactor this in functions 208 | 209 | # Plot group lasso loss values 210 | # fig, ax = plt.subplots(figsize=(6, 10), nrows=batch_size, sharex=True) 211 | # for k in range(batch_size): 212 | # ax[k].plot([-trace[k] for trace in callback_group.trace_f]) 213 | # plt.tight_layout() 214 | # plt.show() 215 | 216 | # # Plot loss functions per datapoint 217 | # fig, ax = plt.subplots(figsize=(6, 10), nrows=batch_size, sharex=True) 218 | # for k in range(batch_size): 219 | # ax[k].plot([-trace[k] for trace in callback_nuc.trace_f]) 220 | 221 | # plt.tight_layout() 222 | # plt.show() 223 | 224 | # # Plot loss functions per datapoint 225 | # fig, ax = plt.subplots(figsize=(6, 10), nrows=batch_size, sharex=True) 226 | # for k in range(batch_size): 227 | # ax[k].plot([-trace[k] for trace in callback_L2.trace_f]) 228 | 229 | # plt.tight_layout() 230 | # plt.show() 231 | -------------------------------------------------------------------------------- /examples/plot_bounded_cone.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bounded Cone optimization 3 | ========================== 4 | In this example, we optimize a simple function over the intersection of a second order cone and a norm ball. 5 | """ 6 | 7 | import matplotlib.pyplot as plt 8 | from mpl_toolkits.mplot3d import Axes3D 9 | import torch 10 | import chop 11 | from chop.utils.logging import Trace 12 | 13 | u = torch.tensor([[0, 0, 1.]]) 14 | cos_alpha = .5 15 | cone = chop.constraints.Cone(u, cos_alpha) 16 | 17 | norm_bound = chop.constraints.L2Ball(1.) 18 | 19 | 20 | @chop.utils.closure 21 | def obj_fun(x): 22 | return ((x-torch.tensor([[0, 0, 2.]])) ** 2).sum(dim=-1) 23 | 24 | 25 | trace = Trace() 26 | 27 | x0 = torch.rand(*u.shape) 28 | 29 | 30 | res = chop.optim.minimize_three_split(obj_fun, x0, cone.prox, norm_bound.prox, 31 | max_iter=100, callback=trace) 32 | 33 | fig = plt.figure() 34 | plt.plot([(fval - 1.) for fval in trace.trace_f]) 35 | plt.title("Function values") 36 | 37 | # TODO: Plot the norm ball constraint and the second order cone constraint 38 | fig = plt.figure() 39 | ax = fig.add_subplot(111, projection='3d') 40 | ax.set_xlim(-1, 1) 41 | ax.set_ylim(-1, 1) 42 | ax.set_zlim(-1, 1) 43 | points = [p.squeeze() for p in trace.trace_x] 44 | xs, ys, zs = zip(*points) 45 | ax.plot(xs, ys, zs) 46 | plt.title("Iterates") 47 | 48 | print(f"Final iterate: {res.x}\nFinal value: {res.fval}") 49 | -------------------------------------------------------------------------------- /examples/plot_logistic_regression_L2_penalized.py: -------------------------------------------------------------------------------- 1 | """ 2 | L2-penalized Logistic Regression (full-batch) 3 | ================================== 4 | 5 | L2 penalized (unconstrained) logistic regression on the Covtype dataset. 6 | Uses full-batch gradient descent with line-search. 7 | """ 8 | 9 | import numpy as np 10 | 11 | import torch 12 | from torch import nn 13 | import torch.nn.functional as F 14 | 15 | import chop 16 | 17 | from sklearn.datasets import fetch_covtype 18 | from sklearn.preprocessing import StandardScaler 19 | 20 | import matplotlib.pyplot as plt 21 | 22 | 23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | 25 | # Regularization strength 26 | lmbd = 1. 27 | 28 | max_iter = 200 29 | 30 | # Load and prepare dataset 31 | X, y = fetch_covtype(return_X_y=True) 32 | y[y != 2] = -1 33 | y[y == 2] = 1 34 | 35 | scaler = StandardScaler() 36 | X = scaler.fit_transform(X) 37 | 38 | X = torch.tensor(X, dtype=torch.float32, device=device) 39 | y = torch.tensor(y, dtype=torch.float32, device=device) 40 | 41 | n_datapoints, n_features = X.shape 42 | 43 | # Initialize weights 44 | x0 = torch.zeros(1, n_features, dtype=X.dtype, device=X.device) 45 | 46 | # Binary cross entropy 47 | @chop.utils.closure 48 | def logloss_reg(x, pen=lmbd): 49 | y_X_x = y * (X @ x.flatten()) 50 | l2 = 0.5 * x.pow(2).sum() 51 | logloss = torch.log1p(torch.exp(-y_X_x)).sum() 52 | return (logloss + pen * l2) / X.size(0) 53 | 54 | 55 | @torch.no_grad() 56 | def log_accuracy(kwargs): 57 | out = X @ kwargs['x'].flatten() 58 | acc = (torch.sign(out).detach().cpu().numpy().round() == y.cpu().numpy()).mean() 59 | return acc 60 | 61 | 62 | callback = chop.utils.logging.Trace(callable=lambda kwargs: (log_accuracy(kwargs), 63 | logloss_reg(kwargs['x'], pen=0., 64 | return_jac=False).item())) 65 | 66 | result = chop.optim.minimize_pgd(logloss_reg, x0, callback=callback, max_iter=max_iter, step='backtracking') 67 | 68 | fig = plt.figure() 69 | plt.plot(np.array([val.item() for val in callback.trace_f]).clip(0, 1)) 70 | plt.title("Regularized Loss") 71 | plt.show() 72 | 73 | accuracies, losses = zip(*callback.trace_callable) 74 | 75 | for name, vals in (('Accuracy', accuracies), 76 | ('Loss', losses)): 77 | fig = plt.figure() 78 | plt.title(name) 79 | plt.plot(vals) 80 | plt.show() 81 | -------------------------------------------------------------------------------- /examples/plot_optim_dynamics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full-gradient constrained optimization dynamics. 3 | ================================================ 4 | Sets up simple 2-d problems on Linf balls to visualize dynamics of various 5 | constrained optimization algorithms. 6 | """ 7 | from functools import partial 8 | import matplotlib.pyplot as plt 9 | import torch 10 | 11 | from chop.constraints import LinfBall, Polytope 12 | from chop.optim import minimize_frank_wolfe, minimize_pgd, minimize_pgd_madry, minimize_three_split 13 | from chop.optim import minimize_pairwise_frank_wolfe 14 | from chop import utils 15 | 16 | torch.random.manual_seed(0) 17 | 18 | 19 | def setup_problem(make_nonconvex=False): 20 | alpha = 1. 21 | x_star = torch.tensor([.9, .3]).unsqueeze(0) 22 | x_0 = torch.zeros_like(x_star) 23 | 24 | @utils.closure 25 | def loss_func(x): 26 | val = .5 * ((x - x_star) ** 2).sum() 27 | if make_nonconvex: 28 | val += .01 * torch.sin(50 * torch.norm(x, p=1) + .1) 29 | return val 30 | 31 | constraint = LinfBall(alpha) 32 | 33 | return x_0, x_star, loss_func, constraint 34 | 35 | 36 | def log(kwargs, iterates, losses): 37 | x= kwargs['x'].squeeze().data 38 | iterates.append(x) 39 | val = kwargs['closure'](x, return_jac=False).data 40 | losses.append(val) 41 | 42 | 43 | if __name__ == "__main__": 44 | 45 | x_0, x_star, loss_func, constraint = setup_problem(make_nonconvex=False) 46 | iterations = 300 47 | 48 | iterates_pgd = [x_0.squeeze().data] 49 | iterates_pgd_madry = [x_0.squeeze().data] 50 | iterates_splitting = [x_0.squeeze().data] 51 | iterates_fw = [x_0.squeeze().data] 52 | iterates_pfw = [x_0.squeeze().data] 53 | 54 | losses_pgd = [loss_func(x_0, return_jac=False).data] 55 | losses_pgd_madry = [loss_func(x_0, return_jac=False).data] 56 | losses_splitting = [loss_func(x_0, return_jac=False).data] 57 | losses_fw = [loss_func(x_0, return_jac=False).data] 58 | losses_pfw = [loss_func(x_0, return_jac=False).data] 59 | 60 | log_pgd = partial(log, iterates=iterates_pgd, losses=losses_pgd) 61 | log_pgd_madry = partial(log, iterates=iterates_pgd_madry, losses=losses_pgd_madry) 62 | log_splitting = partial(log, iterates=iterates_splitting, losses=losses_splitting) 63 | log_fw = partial(log, iterates=iterates_fw, losses=losses_fw) 64 | log_pfw = partial(log, iterates=iterates_pfw, losses=losses_pfw) 65 | 66 | sol_pgd = minimize_pgd(loss_func, x_0, constraint.prox, 67 | max_iter=iterations, 68 | callback=log_pgd) 69 | 70 | sol_pgd_madry = minimize_pgd_madry(loss_func, x_0, constraint.prox, 71 | constraint.lmo, 72 | step=2. / iterations, 73 | max_iter=iterations, 74 | callback=log_pgd_madry) 75 | 76 | sol_splitting = minimize_three_split(loss_func, x_0, prox1=constraint.prox, 77 | max_iter=iterations, callback=log_splitting) 78 | 79 | sol_fw = minimize_frank_wolfe(loss_func, x_0, constraint.lmo, callback=log_fw, 80 | max_iter=iterations) 81 | polytope = Polytope(vertices=torch.tensor([ 82 | [0, 0], 83 | [1, 1], 84 | [1, -1], 85 | [-1, 1], 86 | [-1, -1] 87 | ], 88 | dtype=torch.float).unsqueeze(0)) 89 | sol_pfw = minimize_pairwise_frank_wolfe(loss_func, 0, polytope, 90 | callback=log_pfw, 91 | max_iter=iterations, 92 | lipschitz=2.) 93 | 94 | fig, ax = plt.subplots() 95 | ax.plot(losses_pgd, label="PGD") 96 | ax.plot(losses_pgd_madry, label="PGD Madry") 97 | ax.plot(losses_splitting, label="Operator Splitting") 98 | ax.plot(losses_fw, label="Frank-Wolfe") 99 | ax.plot(losses_pfw, label="Pairwise Frank-Wolfe") 100 | fig.legend() 101 | plt.savefig('losses.png') 102 | plt.show() 103 | 104 | fig, ax = plt.subplots(ncols=3, nrows=2, sharex=True, sharey=True) 105 | ax = ax.flatten() 106 | ax[0].plot(*zip(*iterates_pgd), '-o', label="PGD", alpha=.6) 107 | ax[0].set_xlim(-1, 1) 108 | ax[0].set_ylim(-1, 1) 109 | ax[0].legend() 110 | 111 | ax[1].plot(*zip(*iterates_pgd_madry), '-o', label="PGD Madry", alpha=.6) 112 | ax[1].set_xlim(-1, 1) 113 | ax[1].set_ylim(-1, 1) 114 | ax[1].legend() 115 | 116 | ax[2].plot(*zip(*iterates_splitting), '-o', label="Operator Splitting", alpha=.6) 117 | ax[2].set_xlim(-1, 1) 118 | ax[2].set_ylim(-1, 1) 119 | ax[2].legend() 120 | 121 | ax[3].plot(*zip(*iterates_fw), '-o', label="Frank-Wolfe", alpha=.6) 122 | ax[3].set_xlim(-1, 1) 123 | ax[3].set_ylim(-1, 1) 124 | ax[3].legend() 125 | 126 | ax[4].plot(*zip(*iterates_pfw), '-o', label="Pairwise Frank-Wolfe", alpha=.6) 127 | ax[4].set_xlim(-1, 1) 128 | ax[4].set_ylim(-1, 1) 129 | ax[4].legend() 130 | 131 | plt.show() 132 | plt.savefig("Pairwise.png") 133 | -------------------------------------------------------------------------------- /examples/plot_robust_PCA.py: -------------------------------------------------------------------------------- 1 | """ 2 | Robust PCA 3 | =========== 4 | 5 | This example fits a Robust PCA model to data. 6 | It uses a hybrid Frank-Wolfe and proximal method. 7 | See description in :func:`chop.optim.minimize_alternating_fw_prox`. 8 | 9 | 10 | We reproduce the synthetic experimental setting from `[Garber et al. 2018] `_. 11 | We aim to recover :math:`M = L + S + N`, where :math:`L` is rank :math:`p`, 12 | :math:`S` is :math:`p` sparse, and :math:`N` is standard Gaussian elementwise. 13 | """ 14 | 15 | 16 | import matplotlib.pyplot as plt 17 | import torch 18 | import chop 19 | from chop import utils 20 | from chop.utils.logging import Trace 21 | 22 | 23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | 25 | m = 1000 26 | n = 1000 27 | 28 | r_p = [(5, 1e-3), (5, 3e-3), (25, 1e-3), (25, 3e-3), 29 | (25, 3e-2), (130, 1e-2)] 30 | 31 | for r, p in r_p: 32 | print(f'r={r} and p={p}') 33 | U = torch.normal(torch.zeros(1, m, r)) 34 | V = torch.normal(torch.zeros(1, r, n)) 35 | 36 | # Low rank component 37 | L = 10 * utils.bmm(U, V) 38 | 39 | # Sparse component 40 | S = 100 * torch.normal(torch.zeros(1, m, n)) 41 | 42 | S *= (torch.rand_like(S) <= p) 43 | 44 | # Add noise 45 | N = torch.normal(torch.zeros(1, m, n)) 46 | 47 | M = L + S + N 48 | M = M.to(device) 49 | 50 | @utils.closure 51 | def sqloss(Z): 52 | return .5 * torch.linalg.norm((Z - M).squeeze(), ord='fro') ** 2 53 | 54 | rnuc = torch.linalg.norm(L.squeeze(), ord='nuc') 55 | sL1 = abs(S).sum() 56 | 57 | print(f"Initial L1 norm: {sL1}") 58 | print(f"Initial Nuclear norm: {rnuc}") 59 | 60 | rank_constraint = chop.constraints.NuclearNormBall(rnuc) 61 | sparsity_constraint = chop.constraints.L1Ball(sL1) 62 | 63 | lmo = rank_constraint.lmo 64 | prox = sparsity_constraint.prox 65 | 66 | def things_to_log(kwargs): 67 | result = ( 68 | torch.linalg.norm(kwargs['y'].squeeze(), ord='nuc').item(), 69 | abs(kwargs['x']).sum().item(), 70 | sqloss(kwargs['x'] + kwargs['y'])[0].item() 71 | ) 72 | return result 73 | 74 | callback = Trace(log_x=False, callable=things_to_log) 75 | 76 | def line_search(kwargs): 77 | x = kwargs['x'] 78 | y = kwargs['y'] 79 | w = kwargs['w'] 80 | v = kwargs['v'] 81 | q = w + v 82 | z = x + y 83 | B = M - z 84 | A = q - z 85 | 86 | step_size = torch.clamp(utils.bdiv(utils.bdot(A, B), utils.bdot(A, A)), max=1.) 87 | assert (step_size >= 0).all() 88 | return step_size 89 | 90 | result = chop.optim.minimize_alternating_fw_prox(sqloss, 91 | torch.zeros_like(M, device=device), 92 | torch.zeros_like(M, device=device), 93 | prox=prox, lmo=lmo, 94 | L0=1., 95 | line_search=line_search, 96 | max_iter=200, 97 | callback=callback) 98 | 99 | low_rank_nuc, sparse_comp, f_vals = zip(*callback.trace_callable) 100 | 101 | fig, axes = plt.subplots(3, sharex=True, figsize=(6, 12)) 102 | fig.suptitle(f'r={r} and p={p}') 103 | 104 | axes[0].plot(f_vals) 105 | axes[0].set_title("Function values") 106 | 107 | axes[1].plot(sparse_comp) 108 | axes[1].set_title("L1 norm of sparse component") 109 | 110 | axes[2].plot(low_rank_nuc) 111 | axes[2].set_title("Nuclear Norm of low rank component") 112 | 113 | plt.tight_layout() 114 | plt.show() 115 | print(f"L1 norm: {abs(result.x).sum()}") 116 | print(f"Nuc Norm: {torch.linalg.norm(result.y.squeeze(), ord='nuc')}") 117 | -------------------------------------------------------------------------------- /examples/plot_stochastic_dynamics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Stochastic constrained optimization dynamics. 3 | ================================================ 4 | Sets up simple 2-d problems on Linf balls to visualize dynamics of various 5 | stochastic constrained optimization algorithms. 6 | """ 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import torch 11 | 12 | from chop.constraints import LinfBall 13 | from chop.stochastic import PGD, PGDMadry, FrankWolfe, S3CM 14 | 15 | torch.random.manual_seed(0) 16 | 17 | OPTIMIZER_CLASSES = [PGD, PGDMadry, FrankWolfe, S3CM] 18 | 19 | 20 | def setup_problem(make_nonconvex=False): 21 | radius = 1. 22 | x_star = torch.tensor([radius, radius/2]) 23 | x_0 = torch.zeros_like(x_star) 24 | 25 | def loss_func(x): 26 | val = .5 * ((x - x_star) ** 2).sum() 27 | if make_nonconvex: 28 | val += .1 * torch.sin(50 * torch.norm(x, p=1) + .1) 29 | return val 30 | 31 | constraint = LinfBall(radius) 32 | 33 | return x_0, x_star, loss_func, constraint 34 | 35 | 36 | def optimize(x_0, loss_func, constraint, optimizer_class, iterations=10): 37 | x = x_0.detach().clone() 38 | x.requires_grad = True 39 | # Use Madry's heuristic for step size 40 | lr = { 41 | FrankWolfe: 2.5 / iterations, 42 | PGD: 2.5 * constraint.alpha / iterations * 2., 43 | PGDMadry: 2.5 / iterations, 44 | S3CM: 2.5 / iterations 45 | } 46 | 47 | prox, lmo = constraint.prox, constraint.lmo 48 | 49 | constraint_oracles = { 50 | PGD: { 51 | 'prox': [prox] 52 | }, 53 | PGDMadry: { 54 | 'prox': [prox], 55 | 'lmo': [lmo] 56 | }, 57 | FrankWolfe: { 58 | 'lmo': [lmo] 59 | }, 60 | S3CM: { 61 | 'prox2': [prox] 62 | } 63 | } 64 | 65 | optimizer = optimizer_class([x], **constraint_oracles[optimizer_class], lr=lr[optimizer_class]) 66 | iterates = [x.data.numpy().copy()] 67 | losses = [] 68 | 69 | for _ in range(iterations): 70 | optimizer.zero_grad() 71 | loss = loss_func(x) 72 | loss.backward() 73 | optimizer.step() 74 | losses.append(loss.item()) 75 | iterates.append(x.data.numpy().copy()) 76 | 77 | loss = loss_func(x) 78 | losses.append(loss.item()) 79 | return losses, iterates 80 | 81 | 82 | if __name__ == "__main__": 83 | 84 | x_0, x_star, loss_func, constraint = setup_problem(make_nonconvex=False) 85 | iterations = 10 86 | losses_all = {} 87 | iterates_all = {} 88 | for opt_class in OPTIMIZER_CLASSES: 89 | losses_, iterates_ = optimize(x_0, 90 | loss_func, 91 | constraint, 92 | opt_class, 93 | iterations) 94 | losses_all[opt_class.name] = losses_ 95 | iterates_all[opt_class.name] = iterates_ 96 | # print(losses) 97 | fig, ax = plt.subplots() 98 | for opt_class in OPTIMIZER_CLASSES: 99 | ax.plot(np.arange(iterations + 1), losses_all[opt_class.name], 100 | label=opt_class.name) 101 | fig.legend() 102 | plt.show() 103 | 104 | fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True) 105 | for ax, opt_class in zip(axes.reshape(-1), OPTIMIZER_CLASSES): 106 | ax.plot(*zip(*iterates_all[opt_class.name]), '-o', label=opt_class.name, alpha=.6) 107 | ax.set_xlim(-1, 1) 108 | ax.set_ylim(-1, 1) 109 | ax.legend(loc='lower left') 110 | plt.show() 111 | -------------------------------------------------------------------------------- /examples/training_L1_constrained_net_on_CIFAR10.py: -------------------------------------------------------------------------------- 1 | """ 2 | Constrained Neural Network Training. 3 | ====================================== 4 | Trains a ResNet model on CIFAR10 using constraints on the weights. 5 | This example is inspired by the official PyTorch MNIST example, which 6 | can be found [here](https://github.com/pytorch/examples/blob/master/mnist/main.py). 7 | """ 8 | from tqdm import tqdm 9 | 10 | import numpy as np 11 | 12 | import torch 13 | from torch import nn 14 | from torchvision import models 15 | from torch.nn import functional as F 16 | from easydict import EasyDict 17 | 18 | import chop 19 | 20 | # Setup 21 | torch.manual_seed(0) 22 | 23 | use_cuda = torch.cuda.is_available() 24 | device = torch.device("cuda" if use_cuda else "cpu") 25 | 26 | # Data Loaders 27 | print("Loading data...") 28 | dataset = chop.utils.data.CIFAR10("~/datasets/") 29 | loaders = dataset.loaders() 30 | # Model setup 31 | 32 | 33 | print("Initializing model.") 34 | model = models.resnet18() 35 | model.to(device) 36 | 37 | criterion = nn.CrossEntropyLoss() 38 | 39 | # Outer optimization parameters 40 | nb_epochs = 200 41 | momentum = .9 42 | lr = 0.1 43 | 44 | # Make constraints 45 | print("Preparing constraints.") 46 | constraints = chop.constraints.make_Lp_model_constraints(model, p=1, value=10000) 47 | proxes = [constraint.prox if constraint else None for constraint in constraints] 48 | lmos = [constraint.lmo if constraint else None for constraint in constraints] 49 | 50 | print("Projecting model parameters in their associated constraint sets.") 51 | chop.constraints.make_feasible(model, proxes) 52 | 53 | optimizer = chop.stochastic.FrankWolfe(model.parameters(), lmos, 54 | lr=lr, momentum=momentum, 55 | weight_decay=5e-4, 56 | normalization='gradient') 57 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) 58 | 59 | bias_params = [param for name, param in model.named_parameters() if 'bias' in name] 60 | bias_opt = chop.stochastic.PGD(bias_params, lr=1e-1) 61 | bias_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(bias_opt) 62 | 63 | print("Training...") 64 | # Training loop 65 | for epoch in range(nb_epochs): 66 | model.train() 67 | train_loss = 0. 68 | for data, target in tqdm(loaders.train, desc=f'Training epoch {epoch}/{nb_epochs - 1}'): 69 | data, target = data.to(device), target.to(device) 70 | 71 | optimizer.zero_grad() 72 | bias_opt.zero_grad() 73 | loss = criterion(model(data), target) 74 | loss.backward() 75 | optimizer.step() 76 | bias_opt.step() 77 | 78 | train_loss += loss.item() 79 | train_loss /= len(loaders.train) 80 | print(f'Training loss: {train_loss:.3f}') 81 | 82 | # Evaluate on clean and adversarial test data 83 | 84 | model.eval() 85 | report = EasyDict(nb_test=0, correct=0, correct_adv_pgd=0, 86 | correct_adv_pgd_madry=0, 87 | correct_adv_fw=0, correct_adv_mfw=0) 88 | val_loss = 0 89 | with torch.no_grad(): 90 | for data, target in tqdm(loaders.test, desc=f'Val epoch {epoch}/{nb_epochs - 1}'): 91 | data, target = data.to(device), target.to(device) 92 | 93 | # Compute corresponding predictions 94 | logits = model(data) 95 | _, pred = logits.max(1) 96 | val_loss += criterion(logits, target) 97 | # Get clean accuracies 98 | report.nb_test += data.size(0) 99 | report.correct += pred.eq(target).sum().item() 100 | 101 | val_loss /= report.nb_test 102 | print(f'Val acc on clean examples (%): {report.correct / report.nb_test * 100.:.3f}') 103 | 104 | scheduler.step(val_loss) 105 | bias_scheduler.step(val_loss) -------------------------------------------------------------------------------- /examples/training_constrained_net_on_mnist.py: -------------------------------------------------------------------------------- 1 | """ 2 | Constrained neural network training. 3 | ====================================== 4 | Trains a LeNet5 model on MNIST using constraints on the weights. 5 | This example is inspired by the official PyTorch MNIST example, which 6 | can be found [here](https://github.com/pytorch/examples/blob/master/mnist/main.py). 7 | """ 8 | from tqdm import tqdm 9 | 10 | import numpy as np 11 | 12 | import torch 13 | from torch import nn 14 | from torch.nn import functional as F 15 | from easydict import EasyDict 16 | 17 | import chop 18 | 19 | # Setup 20 | torch.manual_seed(0) 21 | 22 | use_cuda = torch.cuda.is_available() 23 | device = torch.device("cuda" if use_cuda else "cpu") 24 | 25 | # Data Loaders 26 | print("Loading data...") 27 | dataset = chop.utils.data.MNIST("~/datasets/") 28 | loaders = dataset.loaders() 29 | # Model setup 30 | 31 | 32 | class Net(nn.Module): 33 | def __init__(self): 34 | super(Net, self).__init__() 35 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 36 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 37 | self.dropout1 = nn.Dropout(0.25) 38 | self.dropout2 = nn.Dropout(0.5) 39 | self.fc1 = nn.Linear(9216, 128) 40 | self.fc2 = nn.Linear(128, 10) 41 | 42 | def forward(self, x): 43 | x = self.conv1(x) 44 | x = F.relu(x) 45 | x = self.conv2(x) 46 | x = F.relu(x) 47 | x = F.max_pool2d(x, 2) 48 | x = self.dropout1(x) 49 | x = torch.flatten(x, 1) 50 | x = self.fc1(x) 51 | x = F.relu(x) 52 | x = self.dropout2(x) 53 | x = self.fc2(x) 54 | output = F.log_softmax(x, dim=1) 55 | return output 56 | 57 | 58 | print("Initializing model.") 59 | model = Net() 60 | model.to(device) 61 | 62 | criterion = nn.CrossEntropyLoss() 63 | 64 | # Outer optimization parameters 65 | nb_epochs = 20 66 | momentum = .9 67 | lr = 0.3 68 | 69 | # Make constraints 70 | print("Preparing constraints.") 71 | constraints = chop.constraints.make_Lp_model_constraints(model, p=1, value=10000) 72 | proxes = [constraint.prox if constraint else None for constraint in constraints] 73 | lmos = [constraint.lmo if constraint else None for constraint in constraints] 74 | 75 | print("Projecting model parameters in their associated constraint sets.") 76 | chop.constraints.make_feasible(model, proxes) 77 | 78 | optimizer = chop.stochastic.FrankWolfe(model.parameters(), lmos, 79 | lr=lr, momentum=momentum, 80 | weight_decay=3e-4, 81 | normalization='gradient') 82 | 83 | bias_params = [param for name, param in model.named_parameters() if 'bias' in name] 84 | bias_opt = chop.stochastic.PGD(bias_params, lr=1e-2) 85 | 86 | print("Training...") 87 | # Training loop 88 | for epoch in range(nb_epochs): 89 | model.train() 90 | train_loss = 0. 91 | for data, target in tqdm(loaders.train, desc=f'Training epoch {epoch}/{nb_epochs - 1}'): 92 | data, target = data.to(device), target.to(device) 93 | 94 | optimizer.zero_grad() 95 | bias_opt.zero_grad() 96 | loss = criterion(model(data), target) 97 | loss.backward() 98 | optimizer.step() 99 | bias_opt.step() 100 | 101 | train_loss += loss.item() 102 | train_loss /= len(loaders.train) 103 | print(f'Training loss: {train_loss:.3f}') 104 | 105 | # Evaluate on clean and adversarial test data 106 | 107 | model.eval() 108 | report = EasyDict(nb_test=0, correct=0, correct_adv_pgd=0, 109 | correct_adv_pgd_madry=0, 110 | correct_adv_fw=0, correct_adv_mfw=0) 111 | 112 | for data, target in tqdm(loaders.test, desc=f'Val epoch {epoch}/{nb_epochs - 1}'): 113 | data, target = data.to(device), target.to(device) 114 | 115 | # Compute corresponding predictions 116 | _, pred = model(data).max(1) 117 | 118 | # Get clean accuracies 119 | report.nb_test += data.size(0) 120 | report.correct += pred.eq(target).sum().item() 121 | 122 | print(f'Val acc on clean examples (%): {report.correct / report.nb_test * 100.:.3f}') 123 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "numpy", "scipy", "tqdm", "torch"] # PEP 518 specifications. 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages 2 | from setuptools import setup 3 | 4 | CLASSIFIERS = """\ 5 | Development Status :: 3 - Alpha 6 | Intended Audience :: Science/Research 7 | Intended Audience :: Developers 8 | License :: OSI Approved :: BSD License 9 | Programming Language :: Python 10 | Programming Language :: Python :: 3 11 | Topic :: Software Development 12 | Operating System :: POSIX 13 | Operating System :: Unix 14 | """ 15 | 16 | with open("README.md", 'r', encoding='utf-8') as f: 17 | README = f.read() 18 | 19 | setup( 20 | name="chop-pytorch", 21 | description="Continuous and constrained optimization with PyTorch", 22 | long_description=README, 23 | long_description_content_type='text/markdown', 24 | version="0.0.3", 25 | author="Geoffrey Negiar", 26 | author_email="geoffrey_negiar@berkeley.edu", 27 | url="http://pypi.python.org/pypi/chop-pytorch", 28 | packages=find_packages(), 29 | install_requires=["numpy", "scipy", "torch", "torchvision", 30 | "easydict", "matplotlib", "tqdm"], 31 | setup_requires=['wheel'], 32 | classifiers=[_f for _f in CLASSIFIERS.split("\n") if _f], 33 | ) 34 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openopt/chop/27b1e4ff091b842ea1c0c539ac7c4f929e1c7a0d/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_adversary.py: -------------------------------------------------------------------------------- 1 | """Testing our adversarial attacks""" 2 | import pytest 3 | import torch 4 | from torch import nn 5 | 6 | 7 | import numpy as np 8 | 9 | import chop 10 | from chop import optim 11 | from chop.adversary import Adversary 12 | 13 | 14 | class LinearModel(nn.Module): 15 | def __init__(self): 16 | super(LinearModel, self).__init__() 17 | self.linear = nn.Linear(25 * 25, 2) 18 | 19 | def forward(self, x): 20 | batch_size = x.size(0) 21 | x = x.view(-1) 22 | return self.linear(x).view(batch_size, -1) 23 | 24 | 25 | @pytest.mark.parametrize('algorithm', [optim.minimize_pgd, optim.minimize_pgd_madry, 26 | optim.minimize_frank_wolfe]) 27 | @pytest.mark.parametrize('step_size', [1, .5, .1, .05, .001, 0.]) 28 | @pytest.mark.parametrize('p', [1, 2, np.inf]) 29 | def test_adversary_synthetic_data(algorithm, step_size, p): 30 | # Setup 31 | torch.manual_seed(0) 32 | use_cuda = torch.cuda.is_available() 33 | device = torch.device("cuda" if use_cuda else "cpu") 34 | 35 | data = torch.rand((1, 25, 25)) 36 | target = torch.zeros(1).long() 37 | 38 | data = data.to(device) 39 | target = target.to(device) 40 | 41 | model = LinearModel() 42 | model.to(device) 43 | criterion = nn.CrossEntropyLoss() 44 | constraint = chop.constraints.make_LpBall(alpha=1., p=p) 45 | 46 | adv = Adversary(algorithm) 47 | 48 | # Get nominal loss 49 | output = model(data) 50 | loss = criterion(output, target) 51 | 52 | # Algorithm arguments: 53 | if algorithm == optim.minimize_pgd: 54 | alg_kwargs = { 55 | 'prox': constraint.prox, 56 | 'max_iter': 50 57 | } 58 | elif algorithm == optim.minimize_pgd_madry: 59 | alg_kwargs = { 60 | 'prox': constraint.prox, 61 | 'lmo': constraint.lmo, 62 | 'max_iter': 50, 63 | 'step': 2. * constraint.alpha / 50 64 | } 65 | 66 | elif algorithm == optim.minimize_frank_wolfe: 67 | alg_kwargs = { 68 | 'lmo': constraint.lmo, 69 | 'step': 'sublinear', 70 | 'max_iter': 50 71 | } 72 | 73 | # Run perturbation 74 | adv_loss, delta = adv.perturb(data, target, model, criterion, **alg_kwargs) 75 | 76 | -------------------------------------------------------------------------------- /tests/test_constraints.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | 5 | import torch 6 | from torch import nn 7 | from torchvision import models 8 | import torch.nn.functional as F 9 | 10 | import chop 11 | from chop import utils 12 | from chop.utils.image import group_patches 13 | import chop.constraints 14 | 15 | 16 | def test_nuclear_norm(): 17 | 18 | batch_size = 8 19 | channels = 3 20 | m = 32 21 | n = 35 22 | alpha = 1. 23 | constraint = chop.constraints.NuclearNormBall(alpha) 24 | 25 | grad = torch.rand(batch_size, channels, m, n) 26 | iterate = torch.rand(batch_size, channels, m, n) 27 | constraint.lmo(-grad, iterate) 28 | constraint.prox(iterate - .1 * grad) 29 | 30 | 31 | @pytest.mark.parametrize('constraint', [chop.constraints.L1Ball, 32 | chop.constraints.L2Ball, 33 | chop.constraints.LinfBall, 34 | chop.constraints.Simplex, 35 | chop.constraints.NuclearNormBall, 36 | chop.constraints.GroupL1Ball, 37 | chop.constraints.Cone]) 38 | @pytest.mark.parametrize('alpha', [1., 10., .5]) 39 | def test_projections(constraint, alpha): 40 | """Tests that projections are true projections: 41 | ..math:: 42 | p\circ p = p 43 | """ 44 | batch_size = 8 45 | if constraint == chop.constraints.GroupL1Ball: 46 | groups = group_patches() 47 | prox = constraint(alpha, groups).prox 48 | elif constraint == chop.constraints.Cone: 49 | directions = torch.rand(batch_size, 3, 32, 32) 50 | prox = constraint(directions, cos_angle=.2).prox 51 | else: 52 | prox = constraint(alpha).prox 53 | 54 | for _ in range(10): 55 | data = torch.rand(batch_size, 3, 32, 32) 56 | 57 | proj_data = prox(data) 58 | # SVD reconstruction doesn't do better than 1e-5 59 | double_proj = prox(proj_data) 60 | assert double_proj.allclose(proj_data, atol=1e-5), (double_proj, proj_data) 61 | 62 | 63 | def test_GroupL1LMO(): 64 | batch_size = 2 65 | alpha = 1. 66 | groups = group_patches(x_patch_size=2, y_patch_size=2, x_image_size=6, y_image_size=6) 67 | constraint = chop.constraints.GroupL1Ball(alpha, groups) 68 | data = torch.rand(batch_size, 3, 6, 6) 69 | grad = torch.rand(batch_size, 3, 6, 6) 70 | 71 | constraint.lmo(-grad, data) 72 | 73 | 74 | def test_groupL1Prox(): 75 | batch_size = 2 76 | alpha = 10 77 | groups = group_patches(x_patch_size=2, y_patch_size=2, x_image_size=6, y_image_size=6) 78 | constraint = chop.constraints.GroupL1Ball(alpha, groups) 79 | data = torch.rand(batch_size, 3, 6, 6) 80 | 81 | constraint.prox(-data, step_size=.3) 82 | 83 | 84 | def test_cone_constraint(): 85 | # Standard second order cone 86 | u = torch.tensor([[0., 0., 1.]]) 87 | cos_alpha = .5 88 | 89 | cone = chop.constraints.Cone(u, cos_alpha) 90 | 91 | for inp, correct_prox in [(torch.tensor([[1., 0, 0]]), torch.tensor([[.5, 0, .5]])), 92 | (torch.tensor([[0, 1., 0]]), torch.tensor([[0, .5, .5]])), 93 | (u, u), 94 | (-u, torch.zeros_like(u)) 95 | ]: 96 | assert cone.prox(inp).eq(correct_prox).all() 97 | 98 | # Moreau decomposition: x = proj_x + (x - proj_x) where 99 | # the two vectors are orthogonal 100 | for _ in range(10): 101 | x = torch.rand(*u.shape) 102 | proj_x = cone.prox(x) 103 | assert utils.bdot(x - proj_x, proj_x).allclose(torch.zeros_like(x), atol=4e-7) 104 | 105 | 106 | @pytest.mark.parametrize('Constraint', [chop.constraints.L1Ball, 107 | chop.constraints.L2Ball, 108 | chop.constraints.LinfBall, 109 | chop.constraints.Simplex, 110 | chop.constraints.NuclearNormBall, 111 | chop.constraints.GroupL1Ball, 112 | chop.constraints.Box, 113 | chop.constraints.Cone]) 114 | @pytest.mark.parametrize('alpha', [.1, 1., 20.]) 115 | def test_feasible(Constraint, alpha): 116 | """Tests if prox and LMO yield feasible points""" 117 | # TODO: implement feasibility check method in each constraint. 118 | 119 | if Constraint == chop.constraints.GroupL1Ball: 120 | groups = group_patches(x_patch_size=2, y_patch_size=2, x_image_size=6, y_image_size=6) 121 | constraint = Constraint(alpha, groups) 122 | elif Constraint == chop.constraints.Cone: 123 | directions = torch.rand(2, 3, 6, 6) 124 | cos_alpha = .2 125 | constraint = Constraint(directions, cos_alpha) 126 | elif Constraint == chop.constraints.Box: 127 | constraint = Constraint(-1., 10.) 128 | else: 129 | constraint = Constraint(alpha) 130 | for _ in range(10): 131 | try: 132 | data = (alpha + 1) * torch.rand(2, 3, 6, 6) 133 | assert constraint.is_feasible(constraint.prox(data)).all() 134 | except AttributeError: # Constraint doesn't have a prox operator 135 | pass 136 | try: 137 | grad = (alpha + 1) * torch.rand(2, 3, 6, 6) 138 | update_dir, _ = constraint.lmo(-grad, data) 139 | s = update_dir + data 140 | assert constraint.is_feasible(s).all() 141 | except AttributeError: # Constraint doesn't have an LMO 142 | pass 143 | 144 | 145 | class Net(nn.Module): 146 | def __init__(self): 147 | super(Net, self).__init__() 148 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 149 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 150 | self.dropout1 = nn.Dropout(0.25) 151 | self.dropout2 = nn.Dropout(0.5) 152 | self.fc1 = nn.Linear(9216, 128) 153 | self.fc2 = nn.Linear(128, 10) 154 | 155 | def forward(self, x): 156 | x = self.conv1(x) 157 | x = F.relu(x) 158 | x = self.conv2(x) 159 | x = F.relu(x) 160 | x = F.max_pool2d(x, 2) 161 | x = self.dropout1(x) 162 | x = torch.flatten(x, 1) 163 | x = self.fc1(x) 164 | x = F.relu(x) 165 | x = self.dropout2(x) 166 | x = self.fc2(x) 167 | output = F.log_softmax(x, dim=1) 168 | return output 169 | 170 | 171 | @pytest.mark.parametrize('ord', [1, 2, np.inf, 'nuc']) 172 | @pytest.mark.parametrize('constrain_bias', [True, False]) 173 | def test_model_constraint_maker(ord, constrain_bias): 174 | 175 | model = Net() 176 | constraints = chop.constraints.make_model_constraints(model, ord, constrain_bias=constrain_bias) 177 | 178 | assert len(constraints) == len(list(model.parameters())) 179 | 180 | proxes = [constraint.prox if constraint else None for constraint in constraints] 181 | 182 | chop.constraints.make_feasible(model, proxes) 183 | 184 | for (name, param), constraint in zip(model.named_parameters(), constraints): 185 | if chop.constraints.is_bias(name, param) and ord == 'nuc': 186 | continue 187 | if constraint: 188 | assert torch.allclose(param, constraint.prox(param.unsqueeze(0)).squeeze(0), atol=1e-5) 189 | 190 | 191 | @pytest.mark.parametrize('d', [2, 10]) 192 | def test_polytope_lmo(d): 193 | vertices = torch.zeros(1, 2 * d, d) 194 | 195 | for i in range(d): 196 | vertices[0, i, i] = 1. 197 | vertices[0, d + i, i] = -1. 198 | 199 | constraint = chop.constraints.Polytope(vertices) 200 | 201 | iterate = torch.zeros(d) 202 | grad = torch.zeros(d) 203 | grad[0] = 1. 204 | 205 | assert torch.allclose(constraint.lmo(grad.unsqueeze(0), iterate.unsqueeze(0))[0], grad) 206 | 207 | @pytest.mark.parametrize('d', [2, 10]) 208 | def test_polytope_lmo_pairwise(d): 209 | 210 | vertices = torch.zeros(1, 2 * d, d) 211 | 212 | for i in range(d): 213 | vertices[0, i, i] = 1. 214 | vertices[0, d + i, i] = -1. 215 | 216 | constraint = chop.constraints.Polytope(vertices) 217 | 218 | iterate = torch.zeros(d) 219 | grad = torch.zeros(d) 220 | grad[0] = 1. 221 | 222 | constraint.lmo_pairwise(grad.unsqueeze(0), iterate.unsqueeze(0), active_set={0: 1.}) 223 | -------------------------------------------------------------------------------- /tests/test_optim.py: -------------------------------------------------------------------------------- 1 | """Tests full gradient batch-wise optimization algorithms from chop.optim""" 2 | 3 | 4 | import torch 5 | from chop import optim 6 | from chop import utils 7 | from chop import constraints 8 | from chop.utils import logging 9 | 10 | import pytest 11 | 12 | # Set up a batch of toy constrained optimization problems 13 | batch_size = 20 14 | d = 2 15 | xstar = torch.rand(batch_size, d) 16 | alpha = .5 17 | constraint = constraints.LinfBall(alpha) 18 | xstar = constraint.prox(xstar) 19 | 20 | # Minimize quadratics for each datapoint in the batch 21 | @utils.closure 22 | def loss_fun(x): 23 | return .5 * ((x - xstar) ** 2).view(batch_size, -1).sum(-1) 24 | 25 | 26 | @pytest.mark.parametrize('step', [1., 'backtracking']) 27 | def test_minimize_pgd(step): 28 | max_iter = 2000 29 | x0 = torch.zeros_like(xstar) 30 | trace_cb = logging.Trace(closure=loss_fun) 31 | 32 | sol = optim.minimize_pgd(loss_fun, x0, constraint.prox, 33 | step=step, 34 | max_iter=max_iter, callback=trace_cb) 35 | 36 | assert sol.certificate.allclose(torch.zeros(batch_size, dtype=torch.float)), sol.certificate 37 | 38 | 39 | def test_minimize_frank_wolfe(): 40 | max_iter = 2000 41 | x0 = torch.zeros_like(xstar) 42 | sol = optim.minimize_frank_wolfe(loss_fun, x0, constraint.lmo, 43 | max_iter=max_iter) 44 | assert sol.certificate.allclose(torch.zeros(batch_size, dtype=torch.float), atol=1e-3), sol.certificate 45 | 46 | 47 | def test_minimize_three_split(): 48 | max_iter = 200 49 | x0 = torch.zeros_like(xstar) 50 | batch_size = x0.size(0) 51 | sol = optim.minimize_three_split(loss_fun, x0, constraint.prox, 52 | max_iter=max_iter) 53 | 54 | cert = sol.certificate 55 | assert cert.allclose(torch.zeros(batch_size, dtype=torch.float), atol=1e-5) -------------------------------------------------------------------------------- /tests/test_penalties.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | 5 | import chop 6 | 7 | 8 | @pytest.mark.parametrize('alpha', [0., 1e-3, 1.]) 9 | def test_groupL1_unit_groups_against_lasso_1d(alpha): 10 | n_features = 4 11 | groups = torch.arange(n_features) 12 | lasso = chop.penalties.L1(alpha) 13 | groupLasso = chop.penalties.GroupL1(alpha, groups) 14 | 15 | batch_size = 3 16 | data = torch.rand(batch_size, n_features) 17 | assert torch.allclose(lasso(data), groupLasso(data)), '__call__' 18 | 19 | assert torch.allclose(lasso.prox(data, 1.), groupLasso.prox(data, 1.)), 'prox' 20 | 21 | if alpha == 0.: 22 | assert torch.allclose(data, groupLasso.prox(data, 1.)) 23 | 24 | 25 | def test_zero_data_L1_groupL1_prox(): 26 | batch_size, n_features = 3, 4 27 | data = torch.zeros(batch_size, n_features) 28 | alpha = 1 29 | groups = torch.arange(n_features) 30 | lasso = chop.penalties.L1(alpha) 31 | groupLasso = chop.penalties.GroupL1(alpha, groups) 32 | 33 | assert (lasso.prox(data, 1.) == data).all() 34 | assert (groupLasso.prox(data, 1.) == data).all() 35 | 36 | 37 | def test_groupL1_1d(): 38 | groups = [(0, 1), 39 | (2, 3)] 40 | alpha = 1. 41 | penalty = chop.penalties.GroupL1(alpha, groups) 42 | 43 | batch_size = 3 44 | data = torch.rand(batch_size, 4) 45 | 46 | correct_result = alpha * (torch.sqrt(data[:, 0] ** 2 + data[:, 1] ** 2) 47 | + torch.sqrt(data[:, 2] ** 2 + data[:, 3] ** 2)) 48 | assert torch.allclose(penalty(data), correct_result) 49 | 50 | 51 | def test_groupL1_2d(): 52 | groups = [((0, 0), (0, 1), (1, 0), (1, 1)), 53 | ((2, 0), (2, 1), (3, 0), (3, 1))] 54 | alpha = 1. 55 | penalty = chop.penalties.GroupL1(alpha, groups) 56 | batch_size = 3 57 | 58 | data = torch.rand(batch_size, 4, 2) 59 | 60 | correct_result = alpha * (torch.sqrt(data[:, 0, 0] ** 2 + data[:, 1, 0] ** 2 61 | + data[:, 0, 1] ** 2 + data[:, 1, 1] ** 2) 62 | + torch.sqrt(data[:, 2, 0] ** 2 + data[:, 3, 0] ** 2 63 | + data[:, 2, 1] ** 2 + data[:, 3, 1] ** 2)) 64 | assert torch.allclose(penalty(data), correct_result), '__call__' 65 | 66 | 67 | 68 | @pytest.mark.parametrize("penalty", [chop.penalties.GroupL1(1., np.array_split(np.arange(16), 5)), 69 | chop.penalties.L1(1.)]) 70 | def test_three_inequality(penalty): 71 | """Test the L1 prox using the three point inequality 72 | The three-point inequality is described e.g., in Lemma 1.4 73 | in "Gradient-Based Algorithms with Applications to Signal 74 | Recovery Problems", Amir Beck and Marc Teboulle 75 | """ 76 | n_features = 16 77 | batch_size = 3 78 | 79 | for _ in range(10): 80 | z = torch.rand(batch_size, n_features) 81 | u = torch.rand(batch_size, n_features) 82 | xi = penalty.prox(z, 1.) 83 | 84 | lhs = 2 * (penalty(xi) - penalty(u)) 85 | rhs = ( 86 | torch.linalg.norm(u - z, dim=-1) ** 2 87 | - torch.linalg.norm(u - xi, dim=-1) ** 2 88 | - torch.linalg.norm(xi - z, dim=-1) ** 2 89 | ) 90 | assert (lhs <= rhs).all(), penalty 91 | -------------------------------------------------------------------------------- /tests/test_stochastic.py: -------------------------------------------------------------------------------- 1 | """Tests for stochastic optimizers""" 2 | 3 | import numpy as np 4 | import torch 5 | import pytest 6 | import shutil 7 | 8 | # TODO: remove cox dependency 9 | from cox.store import Store 10 | 11 | import chop 12 | from chop import stochastic 13 | 14 | 15 | OUT_DIR = "logging/tests/test_optim" 16 | shutil.rmtree(OUT_DIR, ignore_errors=True) 17 | MAX_ITER = 300 18 | 19 | torch.manual_seed(0) 20 | 21 | # TODO: put this in a setup function 22 | # Set up random regression problem 23 | alpha = 1.0 24 | n_samples, n_features = 20, 15 25 | X = torch.rand((n_samples, n_features)) 26 | w = torch.rand(n_features) 27 | w = alpha * w / sum(abs(w)) 28 | y = X.mv(w) 29 | # Logistic regression: \|y\|_\infty <= 1 30 | y = abs(y / y.max()) 31 | 32 | tol = 4e-3 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "algorithm", 37 | [ 38 | stochastic.PGD, 39 | stochastic.PGDMadry, 40 | stochastic.FrankWolfe, 41 | stochastic.S3CM, 42 | ], 43 | ) 44 | @pytest.mark.parametrize("lr", [1.0, 0.5, 0.1, 0.05, 0.001]) 45 | def test_L1Ball(algorithm, lr): 46 | # Setup 47 | constraint = chop.constraints.L1Ball(alpha) 48 | prox = constraint.prox 49 | lmo = constraint.lmo 50 | assert (constraint.prox(w) == w).all() 51 | w_t = torch.zeros_like(w) 52 | w_t.requires_grad = True 53 | 54 | constraint_oracles = { 55 | stochastic.PGD.name: {"prox": [prox]}, 56 | stochastic.PGDMadry.name: {"prox": [prox], "lmo": [lmo]}, 57 | stochastic.FrankWolfe.name: {"lmo": [lmo]}, 58 | stochastic.S3CM.name: {"prox1": [prox], "prox2": [prox]}, 59 | } 60 | 61 | optimizer = algorithm([w_t], **(constraint_oracles[algorithm.name]), lr=lr) 62 | criterion = torch.nn.MSELoss(reduction="mean") 63 | 64 | # Logging 65 | store = Store(OUT_DIR) 66 | store.add_table("metadata", {"algorithm": str, "lr": float}) 67 | 68 | store["metadata"].append_row({"algorithm": optimizer.name, "lr": lr}) 69 | store.add_table( 70 | optimizer.name, 71 | {"func_val": float, "certificate": float, "norm(w_t)": float}, 72 | ) 73 | cert = torch.tensor(np.inf) 74 | for ii in range(MAX_ITER): 75 | optimizer.zero_grad() 76 | # TODO: make this stochastic, use a dataloader 77 | loss = criterion(X.mv(w_t), y) 78 | loss.backward() 79 | 80 | optimizer.step() 81 | 82 | try: 83 | cert = next(optimizer.certificate) # only one parameter here 84 | except AttributeError: 85 | cert = torch.tensor(np.nan) 86 | 87 | store.log_table_and_tb( 88 | optimizer.name, 89 | { 90 | "func_val": loss.item(), 91 | "certificate": cert.item(), 92 | "norm(w_t)": sum(abs(w_t)).item(), 93 | }, 94 | ) 95 | store[optimizer.name].flush_row() 96 | 97 | store.close() 98 | 99 | 100 | def test_FW_active_set(): 101 | """Tests active set capabilities. 102 | Weight vector should be in the simplex.""" 103 | 104 | constraint = chop.constraints.L1Ball(alpha) 105 | lmo = constraint.lmo 106 | assert (constraint.prox(w) == w).all() 107 | w_t = torch.zeros_like(w) 108 | w_t.requires_grad = True 109 | 110 | optimizer = stochastic.FrankWolfe( 111 | [w_t], [lmo], momentum=0.0, track_active_set=True 112 | ) 113 | criterion = torch.nn.MSELoss(reduction="mean") 114 | for ii in range(MAX_ITER): 115 | optimizer.zero_grad() 116 | # TODO: make this stochastic, use a dataloader 117 | loss = criterion(X.mv(w_t), y) 118 | loss.backward() 119 | 120 | optimizer.step() 121 | 122 | active_set = optimizer.state[w_t]["active_set"] 123 | 124 | for weight in active_set.values(): 125 | assert 0.0 <= weight <= 1.0 126 | 127 | assert np.allclose(sum(active_set.values()), 1.0) 128 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """Tests for utility functions""" 2 | 3 | from chop.utils import closure 4 | import torch 5 | from torch import nn 6 | from chop import utils 7 | 8 | 9 | device = torch.device("cuda" if torch.cuda.is_available() else 'cpu') 10 | 11 | # Set up random regression problem 12 | alpha = 1. 13 | n_samples, n_features = 20, 15 14 | X = torch.rand((n_samples, n_features)) 15 | w = torch.rand(n_features) 16 | w = alpha * w / sum(abs(w)) 17 | y = X.mv(w) 18 | # Logistic regression: \|y\|_\infty <= 1 19 | y = abs(y / y.max()) 20 | 21 | tol = 4e-3 22 | 23 | batch_size = 20 24 | d1 = 10 25 | d2 = 5 26 | x0 = torch.ones(batch_size, d1, d2) 27 | 28 | 29 | def test_jacobian_batch(): 30 | def loss(x): 31 | return (x.view(x.size(0), -1) ** 2).sum(-1) 32 | 33 | val, jac = utils.get_func_and_jac(loss, x0) 34 | 35 | assert jac.eq(2 * x0).all() 36 | 37 | 38 | def test_jacobian_single_sample(): 39 | def loss(x): 40 | return (x ** 2).sum() 41 | 42 | x0 = torch.rand(1, d1, d2) 43 | val, jac = utils.get_func_and_jac(loss, x0) 44 | 45 | def test_closure(): 46 | 47 | @utils.closure 48 | def loss(x): 49 | return (x.view(x.size(0), -1) ** 2).sum(-1) 50 | 51 | val, grad = loss(x0) 52 | assert val.eq(torch.ones(batch_size) * (d1 * d2)).all() 53 | assert grad.eq(2 * x0).all() 54 | 55 | 56 | def test_init_lipschitz(): 57 | criterion = nn.MSELoss(reduction='none') 58 | 59 | @closure 60 | def loss_fun(X): 61 | return criterion(X.mv(w), y) 62 | 63 | L = utils.init_lipschitz(loss_fun, X.detach().clone().requires_grad_(True)) 64 | print(L) 65 | 66 | 67 | def test_bmm(): 68 | """ 69 | Check shape returned by batch matmul 70 | """ 71 | for _ in range(10): 72 | t1 = torch.rand(4, 3, 32, 35) 73 | t2 = torch.rand(4, 3, 35, 32) 74 | 75 | res = utils.bmm(t1, t2) 76 | assert res.shape == (4, 3, 32, 32) 77 | 78 | 79 | def test_bmv(): 80 | """ 81 | Check shape returns of batch mat-vec multiply 82 | """ 83 | for _ in range(10): 84 | mat = torch.rand(4, 3, 32, 35) 85 | vec = torch.rand(4, 3, 35) 86 | 87 | res = utils.bmv(mat, vec) 88 | assert res.shape == (4, 3, 32) 89 | 90 | mat = torch.rand(1, 200, 500) 91 | vec = torch.rand(1, 500) 92 | 93 | res = utils.bmv(mat, vec) 94 | assert res.shape == (1, 200) 95 | 96 | 97 | def test_power_iteration(): 98 | """ 99 | Checks our power iteration method against torch.svd 100 | """ 101 | mat = torch.rand(4, 3, 32, 35) 102 | mat.to(device) 103 | # Ground truth 104 | U, S, V = torch.svd(mat) 105 | u, s, v = utils.power_iteration(mat, n_iter=10) 106 | 107 | # First singular value should be the same 108 | assert torch.allclose(S[..., 0], s, atol=1e-5), (S[..., 0] - s) 109 | 110 | outer = U[..., 0].unsqueeze(-1) * V[..., 0].unsqueeze(-2) 111 | outer_pi = u.unsqueeze(-1) * v.unsqueeze(-2) 112 | 113 | # Rank 1 approx should be the same 114 | assert torch.allclose(outer, outer_pi), (outer - outer_pi) 115 | --------------------------------------------------------------------------------