├── .gitignore
├── .travis.yml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── api.rst
│   ├── conf.py
│   ├── index.rst
│   ├── requirements.txt
│   └── sdp_verify.rst
├── examples
│   ├── run_boundprop.py
│   ├── run_examples.sh
│   ├── run_lp_solver.py
│   └── run_sdp_verify.py
├── jax_verify
│   ├── __init__.py
│   ├── extensions
│   │   ├── functional_lagrangian
│   │   │   ├── README.md
│   │   │   ├── attacks.py
│   │   │   ├── bounding.py
│   │   │   ├── data.py
│   │   │   ├── dual_build.py
│   │   │   ├── dual_solve.py
│   │   │   ├── inner_solvers
│   │   │   │   ├── __init__.py
│   │   │   │   ├── exact_opt_softmax.py
│   │   │   │   ├── get_strategy.py
│   │   │   │   ├── input_uncertainty_spec.py
│   │   │   │   ├── lp.py
│   │   │   │   ├── mixed.py
│   │   │   │   ├── pga
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── optimizer.py
│   │   │   │   │   ├── pga_strategy.py
│   │   │   │   │   ├── square.py
│   │   │   │   │   └── utils.py
│   │   │   │   └── uncertainty_spec.py
│   │   │   ├── lagrangian_form.py
│   │   │   ├── model.py
│   │   │   ├── run
│   │   │   │   ├── configs
│   │   │   │   │   ├── config_adv_stochastic_model.py
│   │   │   │   │   ├── config_ood_stochastic_input.py
│   │   │   │   │   └── config_ood_stochastic_model.py
│   │   │   │   └── run_functional_lagrangian.py
│   │   │   ├── specification.py
│   │   │   └── verify_utils.py
│   │   └── sdp_verify
│   │       ├── README.md
│   │       ├── boundprop_utils.py
│   │       ├── cvxpy_verify.py
│   │       ├── eigenvector_utils.py
│   │       ├── problem.py
│   │       ├── problem_from_graph.py
│   │       ├── sdp_verify.py
│   │       └── utils.py
│   ├── sdp_verify.py
│   ├── src
│   │   ├── activation_relaxation.py
│   │   ├── bound_propagation.py
│   │   ├── bound_utils.py
│   │   ├── branching
│   │   │   ├── backpropagation.py
│   │   │   ├── branch_algorithm.py
│   │   │   ├── branch_selection.py
│   │   │   └── branch_utils.py
│   │   ├── concretization.py
│   │   ├── graph_traversal.py
│   │   ├── ibp.py
│   │   ├── intersection.py
│   │   ├── linear
│   │   │   ├── backward_crown.py
│   │   │   ├── backward_linearbounds_with_branching.py
│   │   │   ├── forward_linear_bounds.py
│   │   │   ├── linear_bound_utils.py
│   │   │   └── linear_relaxations.py
│   │   ├── mccormick.py
│   │   ├── mip_solver
│   │   │   ├── cvxpy_relaxation_solver.py
│   │   │   ├── relaxation.py
│   │   │   └── solve_relaxation.py
│   │   ├── nonconvex
│   │   │   ├── duals.py
│   │   │   ├── methods.py
│   │   │   ├── nonconvex.py
│   │   │   └── optimizers.py
│   │   ├── opt_utils.py
│   │   ├── optimizers.py
│   │   ├── simplex_bound.py
│   │   ├── synthetic_primitives.py
│   │   ├── types.py
│   │   └── utils.py
│   └── tests
│       ├── activation_relaxation_test.py
│       ├── backpropagation_test.py
│       ├── backward_crown_test.py
│       ├── backward_linearbounds_with_branching_test.py
│       ├── bound_propagation_test.py
│       ├── branch_algorithm_test.py
│       ├── branch_selection_test.py
│       ├── branch_utils_test.py
│       ├── crownibp_test.py
│       ├── cvxpy_relaxation_test.py
│       ├── forward_linear_bounds_test.py
│       ├── functional_lagrangian
│       │   ├── attacks_test.py
│       │   ├── lagrangian_form_test.py
│       │   ├── lp_test.py
│       │   ├── pga_test.py
│       │   └── uncertainty_spec_test.py
│       ├── ibp_test.py
│       ├── linear_relaxations_test.py
│       ├── model_zoo.py
│       ├── model_zoo_test.py
│       ├── nonconvex_test.py
│       ├── opt_utils_test.py
│       ├── sdp_verify
│       │   ├── boundprop_utils_test.py
│       │   ├── cvxpy_verify_test.py
│       │   ├── problem_from_graph_test.py
│       │   ├── sdp_verify_test.py
│       │   ├── test_utilfuns.py
│       │   └── test_utils.py
│       ├── simplex_bound_test.py
│       ├── synthetic_primitives_test.py
│       └── test_utils.py
├── requirements.txt
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | **/*.pyc
2 | docs/_build
3 | 
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | # For context, jax_verify is developed day-to-day using internal continuous
2 | # integration software.
3 | #
4 | # The current Travis CI setup is unpolished and verifies that open-source
5 | # jax_verify builds correctly. This is done on a best-effort basis; we are not
6 | # attached to Travis CI.
7 | #
8 | # If you use jax_verify, continuous integration improvements are welcome.
9 | 
10 | language: python
11 | python:
12 |   - "3.6"
13 |   - "3.7"
14 |   - "3.8"
15 | addons:
16 |   apt:
17 |     packages:
18 |       - libblas-dev
19 |       - liblapack-dev
20 | env:
21 |   - TRAVIS=true
22 | install:
23 |   - pip install .
24 | script:
25 |   - pwd
26 |   - python3 --version
27 |   - cd jax_verify/tests
28 |   - pytest
29 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing guidelines
2 | 
3 | ## How to become a contributor and submit your own code
4 | 
5 | ### Contributor License Agreements
6 | 
7 | We'd love to accept your patches! Before we can take them, we have to jump a
8 | couple of legal hurdles.
9 | 
10 | Please fill out either the individual or corporate Contributor License Agreement
11 | (CLA).
12 | 
13 | * If you are an individual writing original source code and you're sure you
14 |   own the intellectual property, then you'll need to sign an [individual
15 |   CLA](http://code.google.com/legal/individual-cla-v1.0.html).
16 | * If you work for a company that wants to allow you to contribute your work,
17 |   then you'll need to sign a [corporate
18 |   CLA](http://code.google.com/legal/corporate-cla-v1.0.html).
19 | 
20 | Follow either of the two links above to access the appropriate CLA and
21 | instructions for how to sign and return it. Once we receive it, we'll be able to
22 | accept your pull requests.
23 | 
24 | ***NOTE***: Only original source code from you and other people that have signed
25 | the CLA can be accepted into the main repository.
26 | 
27 | ### Contributing code
28 | 
29 | If you have improvements to jax_verify, send us your pull requests! For those just
30 | getting started, GitHub has a
31 | [howto](https://help.github.com/articles/using-pull-requests/).
32 | 
33 | If you want to contribute but you're not sure where to start, take a look at the
34 | [issues with the "contributions welcome"
35 | label](https://github.com/deepmind/jax_verify/labels/stat%3Acontributions%20welcome).
36 | These are issues that we believe are particularly well suited for outside
37 | contributions, often because we probably won't get to them right now. If you
38 | decide to start on an issue, leave a comment so that other people know that
39 | you're working on it. If you want to help out but would rather not work alone,
40 | use the issue comment thread to coordinate.
41 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # jax_verify: Neural Network Verification in JAX
2 | 
3 | [![tests status](https://travis-ci.com/deepmind/jax_verify.svg?branch=master)](https://travis-ci.com/deepmind/jax_verify)
4 | [![docs: latest](https://img.shields.io/badge/docs-stable-blue.svg)](https://jax-verify.readthedocs.io)
5 | 
6 | Jax_verify is a library containing JAX implementations of many widely-used neural network verification techniques.
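As a quick taste of the interface, here is a minimal, self-contained sketch (the two-layer network and its random weights are invented purely for illustration; the `jax_verify` calls are the library's real API):

```python
import jax
import jax.numpy as jnp
import numpy as np

import jax_verify

# A tiny fully-connected ReLU network with fixed random weights.
rng = np.random.RandomState(0)
w1, b1 = rng.randn(8, 16), rng.randn(16)
w2, b2 = rng.randn(16, 2), rng.randn(2)

def network_fn(x):
  return jax.nn.relu(x @ w1 + b1) @ w2 + b2

# Inputs may deviate by up to 0.1 (in L-infinity norm) from a nominal point.
x = jnp.zeros((1, 8))
input_bounds = jax_verify.IntervalBound(x - 0.1, x + 0.1)

# Bounds on the network outputs, valid for every input within input_bounds.
output_bounds = jax_verify.interval_bound_propagation(network_fn, input_bounds)
print(output_bounds.lower, output_bounds.upper)
```

Because all techniques share the same interface, a tighter method such as `backward_crown_bound_propagation` can be swapped in without changing anything else.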
7 | 8 | ## Overview 9 | 10 | If you just want to get started with using jax_verify to verify your neural 11 | networks, the main thing to know is we provide a simple, consistent interface 12 | for a variety of verification algorithms: 13 | 14 | ```python 15 | output_bounds = jax_verify.verification_technique(network_fn, input_bounds) 16 | ``` 17 | 18 | Here, `network_fn` is any JAX function, `input_bounds` define bounds over 19 | possible inputs to `network_fn`, and `output_bounds` will be the computed bounds 20 | over possible outputs of `network_fn`. `verification_technique` can be one of 21 | many algorithms implemented in `jax_verify`, such as `interval_bound_propagation` 22 | or `crown_bound_propagation`. 23 | 24 | The overall approach is to use JAX’s powerful [program transformation system](https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html), 25 | which allows us to analyze general network structures defined by `network_fn` 26 | and then to define corresponding functions for calculating 27 | verified bounds for these networks. 28 | 29 | ## Verification Techniques 30 | 31 | The methods currently provided by `jax_verify` include: 32 | 33 | * Functional Lagrangian [Berrada et al 2021](https://arxiv.org/abs/2102.09479) 34 | * SDP-FO (first-order SDP verification, [Dathathri et al 2020](https://arxiv.org/abs/2010.11645)) 35 | * Non-convex ([Bunel et al 2020](https://arxiv.org/abs/2010.14322)) 36 | * Interval Bound Propagation ([Gowal et al 2018](https://arxiv.org/pdf/1810.12715.pdf), [Mirman et al 2018](http://proceedings.mlr.press/v80/mirman18b/mirman18b.pdf)) 37 | * Backward Lirpa bounds such as CAP ([Wong and Kolter 2017](https://arxiv.org/pdf/1711.00851.pdf)), FastLin([Weng et al 2018](https://arxiv.org/pdf/1804.09699.pdf)) or CROWN ([Zhang et al 2018](https://arxiv.org/pdf/1811.00866.pdf)) 38 | * Forward Lirpa bounds ([Xu et al 2020](https://arxiv.org/pdf/2002.12920.pdf)) 39 | * CROWN-IBP ([Zhang et al 2019](https://arxiv.org/abs/1906.06316)) 40 | * Planet (also known as the "LP" or "triangle" relaxation, [Ehlers 2017](https://arxiv.org/abs/1705.01320)), currently using [CVXPY](https://github.com/cvxgrp/cvxpy) as the LP solver 41 | * MIP encoding ([Cheng et al 2017](https://arxiv.org/pdf/1705.01040.pdf), [Tjeng et al 2019](https://arxiv.org/pdf/1711.07356.pdf)) 42 | 43 | ## Installation 44 | 45 | **Stable**: Just run `pip install jax_verify` and you can `import jax_verify` from any of your Python code. 46 | 47 | **Latest**: Clone this directory and run `pip install .` from the directory root. 48 | 49 | ## Getting Started 50 | 51 | We suggest starting by looking at the minimal examples in the `examples/` directory. 52 | For example, all the bound propagation techniques can be run with the `run_boundprop.py` script: 53 | 54 | ```bash 55 | cd examples/ 56 | python3 run_boundprop.py --boundprop_method=interval_bound_propagation 57 | ``` 58 | 59 | For documentation, please refer to the [API reference page](https://jax-verify.readthedocs.io/en/latest/api.html). 60 | 61 | ## Notes 62 | 63 | Contributions of additional verification techniques are very welcome. Please open 64 | an issue first to let us know. 65 | 66 | ## License 67 | 68 | All code is made available under the Apache 2.0 License. 69 | Model parameters are made available under the Creative Commons Attribution 4.0 70 | International (CC BY 4.0) License. 71 | See https://creativecommons.org/licenses/by/4.0/legalcode for more details. 
72 | 73 | ## Disclaimer 74 | 75 | This is not an official Google product. 76 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | ############# 2 | API Reference 3 | ############# 4 | 5 | 6 | .. currentmodule:: jax_verify 7 | 8 | Verification methods 9 | ==================== 10 | 11 | .. autofunction:: crown_bound_propagation 12 | 13 | .. autofunction:: crownibp_bound_propagation 14 | 15 | .. autofunction:: fastlin_bound_propagation 16 | 17 | .. autofunction:: ibpfastlin_bound_propagation 18 | 19 | .. autofunction:: interval_bound_propagation 20 | 21 | .. autofunction:: solve_planet_relaxation 22 | 23 | Bound objects 24 | ============= 25 | 26 | .. autoclass:: LinearBound 27 | 28 | .. autoclass:: IntervalBound 29 | 30 | 31 | Utility methods 32 | =============== 33 | 34 | .. autofunction:: open_file 35 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 17 | # 18 | # Licensed under the Apache License, Version 2.0 (the "License"); 19 | # you may not use this file except in compliance with the License. 20 | # You may obtain a copy of the License at 21 | # 22 | # http://www.apache.org/licenses/LICENSE-2.0 23 | # 24 | # Unless required by applicable law or agreed to in writing, software 25 | # distributed under the License is distributed on an "AS IS" BASIS, 26 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 27 | # See the License for the specific language governing permissions and 28 | # limitations under the License. 29 | # ============================================================================ 30 | """Configuration file for the Sphinx documentation builder.""" 31 | 32 | # This file only contains a selection of the most common options. 
For a full 33 | # list see the documentation: 34 | # http://www.sphinx-doc.org/en/master/config 35 | 36 | # -- Path setup -------------------------------------------------------------- 37 | 38 | # If extensions (or modules to document with autodoc) are in another directory, 39 | # add these directories to sys.path here. If the directory is relative to the 40 | # documentation root, use os.path.abspath to make it absolute, like shown here. 41 | 42 | # pylint: disable=g-bad-import-order 43 | # pylint: disable=g-import-not-at-top 44 | import inspect 45 | import os 46 | import sys 47 | 48 | sys.path.insert(0, os.path.abspath('../')) 49 | 50 | import jax_verify 51 | 52 | # -- Project information ----------------------------------------------------- 53 | 54 | project = 'jax_verify' 55 | copyright = '2020, DeepMind' # pylint: disable=redefined-builtin 56 | author = 'DeepMind' 57 | 58 | # -- General configuration --------------------------------------------------- 59 | 60 | master_doc = 'index' 61 | 62 | # Add any Sphinx extension module names here, as strings. They can be 63 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 64 | # ones. 65 | extensions = [ 66 | 'sphinx.ext.autodoc', 67 | 'sphinx.ext.autosummary', 68 | 'sphinx.ext.linkcode', 69 | 'sphinx.ext.napoleon', 70 | ] 71 | 72 | # Add any paths that contain templates here, relative to this directory. 73 | templates_path = ['_templates'] 74 | 75 | # List of patterns, relative to source directory, that match files and 76 | # directories to ignore when looking for source files. 77 | # This pattern also affects html_static_path and html_extra_path. 78 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 79 | 80 | # -- Options for autodoc ----------------------------------------------------- 81 | 82 | autodoc_default_options = { 83 | 'member-order': 'bysource', 84 | 'special-members': True, 85 | 'exclude-members': '__repr__, __str__, __weakref__', 86 | } 87 | 88 | # -- Options for HTML output ------------------------------------------------- 89 | 90 | # The theme to use for HTML and HTML Help pages. See the documentation for 91 | # a list of builtin themes. 92 | # 93 | html_theme = 'sphinx_rtd_theme' 94 | 95 | html_theme_options = { 96 | # 'collapse_navigation': False, 97 | # 'sticky_navigation': False, 98 | } 99 | 100 | # -- Source code links ------------------------------------------------------- 101 | 102 | 103 | def linkcode_resolve(domain, info): 104 | """Resolve a GitHub URL corresponding to Python object.""" 105 | if domain != 'py': 106 | return None 107 | 108 | try: 109 | mod = sys.modules[info['module']] 110 | except ImportError: 111 | return None 112 | 113 | obj = mod 114 | try: 115 | for attr in info['fullname'].split('.'): 116 | obj = getattr(obj, attr) 117 | except AttributeError: 118 | return None 119 | else: 120 | obj = inspect.unwrap(obj) 121 | 122 | try: 123 | filename = inspect.getsourcefile(obj) 124 | except TypeError: 125 | return None 126 | 127 | try: 128 | source, lineno = inspect.getsourcelines(obj) 129 | except OSError: 130 | return None 131 | 132 | # TODO: support tags after we release an initial version. 
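  # (The generated URL anchors the object's full source span, from its first
  # line to its last, rather than just the top of the file.)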
133 |   return 'https://github.com/deepmind/jax_verify/blob/master/jax_verify/%s#L%d-L%d' % (
134 |       os.path.relpath(filename, start=os.path.dirname(
135 |           jax_verify.__file__)), lineno, lineno + len(source) - 1)
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | ########################
2 | jax_verify Documentation
3 | ########################
4 | 
5 | .. toctree::
6 |    :maxdepth: 2
7 |    :hidden:
8 | 
9 |    api
10 |    sdp_verify
11 | 
12 | ``jax_verify`` is a library for verification of neural network specifications.
13 | 
14 | Installation
15 | ============
16 | 
17 | Install ``jax_verify`` by running::
18 | 
19 |    $ pip install jax_verify
20 | 
21 | Support
22 | =======
23 | 
24 | If you are having issues, please let us know by filing an issue on our
25 | `issue tracker <https://github.com/deepmind/jax_verify/issues>`_.
26 | 
27 | License
28 | =======
29 | 
30 | jax_verify is licensed under the Apache 2.0 License.
31 | 
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx>=2.0.1
2 | sphinx_rtd_theme>=0.4.3
3 | absl-py
4 | cvxpy
5 | dm-tree
6 | jax>=0.1.71
7 | jaxlib>=0.1.49
8 | numpy
9 | optax
10 | dm-haiku
--------------------------------------------------------------------------------
/docs/sdp_verify.rst:
--------------------------------------------------------------------------------
1 | ################
2 | SDP Verification
3 | ################
4 | 
5 | The ``sdp_verify`` directory contains a largely self-contained implementation of
6 | the SDP-FO (first-order SDP verification) algorithm described in Dathathri et al
7 | 2020. We *encourage* projects building off this code to fork this directory,
8 | though contributions are also welcome!
9 | 
10 | The core solver is contained in ``sdp_verify.py``. The main function is
11 | ``dual_fun(verif_instance, dual_vars)``, which defines the dual upper bound from
12 | Equation (5). For any feasible ``dual_vars`` this provides a valid bound. It is
13 | written to be amenable to autodiff, such that ``jax.grad`` with respect to
14 | ``dual_vars`` yields a valid subgradient.
15 | 
16 | We also provide ``solve_sdp_dual_simple(verif_instance)``, which implements the
17 | optimization loop (SDP-FO). This initializes the dual variables using our
18 | proposed scheme, and performs projected subgradient steps.
19 | 
20 | Both methods accept a ``SdpDualVerifInstance`` which specifies (1) the
21 | Lagrangian, (2) interval bounds on the primal variables, and (3) dual variable
22 | shapes.
23 | 
24 | As described in the paper, the solver can easily be applied to other
25 | input/output specifications or network architectures for any QCQP. This involves
26 | defining the corresponding QCQP Lagrangian and creating a
27 | ``SdpDualVerifInstance``. In ``examples/run_sdp_verify.py`` we include an
28 | example for certifying adversarial L_inf robustness of a ReLU convolutional
29 | network image classifier.
30 | 
31 | API Reference
32 | =============
33 | 
34 | .. currentmodule:: jax_verify.sdp_verify
35 | 
36 | .. autofunction:: dual_fun
37 | 
38 | .. autofunction:: solve_sdp_dual
39 | 
40 | .. autofunction:: solve_sdp_dual_simple
41 | 
42 | ..
autoclass:: SdpDualVerifInstance 43 | -------------------------------------------------------------------------------- /examples/run_boundprop.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Bound propagation example usage: IBP, Fastlin, CROWN, CROWN-IBP. 17 | 18 | Examples: 19 | python3 run_boundprop.py 20 | python3 run_boundprop.py --model=cnn 21 | python3 run_boundprop.py --boundprop_method=fastlin_bound_propagation 22 | """ 23 | import functools 24 | import pickle 25 | from absl import app 26 | from absl import flags 27 | from absl import logging 28 | import jax.numpy as jnp 29 | import jax_verify 30 | from jax_verify.extensions.sdp_verify import utils 31 | import numpy as np 32 | 33 | MLP_PATH = 'models/raghunathan18_pgdnn.pkl' 34 | CNN_PATH = 'models/mnist_wongsmall_eps_10_adv.pkl' 35 | ALL_BOUNDPROP_METHODS = ( 36 | jax_verify.interval_bound_propagation, 37 | jax_verify.forward_fastlin_bound_propagation, 38 | jax_verify.backward_fastlin_bound_propagation, 39 | jax_verify.ibpforwardfastlin_bound_propagation, 40 | jax_verify.forward_crown_bound_propagation, 41 | jax_verify.backward_crown_bound_propagation, 42 | jax_verify.crownibp_bound_propagation, 43 | ) 44 | 45 | flags.DEFINE_string('model', 'mlp', 'mlp or cnn') 46 | flags.DEFINE_string('boundprop_method', '', 47 | 'Any boundprop method, such as `interval_bound_propagation`' 48 | ' `forward_fastlin_bound_propagation` or ' 49 | ' `crown_bound_propagation`.' 50 | 'Empty string defaults to IBP.') 51 | FLAGS = flags.FLAGS 52 | 53 | 54 | def load_model(model_name): 55 | """Load model parameters and prediction function.""" 56 | # Choose appropriate prediction function 57 | if model_name == 'mlp': 58 | model_path = MLP_PATH 59 | def model_fn(params, inputs): 60 | inputs = np.reshape(inputs, (inputs.shape[0], -1)) 61 | return utils.predict_mlp(params, inputs) 62 | elif model_name == 'cnn': 63 | model_path = CNN_PATH 64 | model_fn = utils.predict_cnn 65 | else: 66 | raise ValueError('') 67 | 68 | # Load parameters from file 69 | with jax_verify.open_file(model_path, 'rb') as f: 70 | params = pickle.load(f) 71 | return model_fn, params 72 | 73 | 74 | def main(unused_args): 75 | # Load some test samples 76 | with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f: 77 | inputs = np.load(f) 78 | 79 | # Load the parameters of an existing model. 80 | model_pred, params = load_model(FLAGS.model) 81 | 82 | # Evaluation of the model on unperturbed images. 
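  # (These clean predictions double as a sanity check at the end of main():
  # any sound bound propagation method must produce bounds that bracket them.)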
83 | clean_preds = model_pred(params, inputs) 84 | 85 | # Define initial bound 86 | eps = 0.1 87 | initial_bound = jax_verify.IntervalBound( 88 | jnp.minimum(jnp.maximum(inputs - eps, 0.0), 1.0), 89 | jnp.minimum(jnp.maximum(inputs + eps, 0.0), 1.0)) 90 | 91 | # Because our function `model_pred` takes as inputs both the parameters 92 | # `params` and the `inputs`, we need to wrap it such that it only takes 93 | # `inputs` as parameters. 94 | logits_fn = functools.partial(model_pred, params) 95 | 96 | # Apply bound propagation. All boundprop methods take as an input the model 97 | # `function`, and the inital bounds, and return final bounds with the same 98 | # structure as the output of `function`. Internally, these methods work by 99 | # replacing each operation with its boundprop equivalent - see 100 | # bound_propagation.py for details. 101 | boundprop_method = ( 102 | jax_verify.interval_bound_propagation if not FLAGS.boundprop_method else 103 | getattr(jax_verify, FLAGS.boundprop_method)) 104 | assert boundprop_method in ALL_BOUNDPROP_METHODS, 'unsupported method' 105 | final_bound = boundprop_method(logits_fn, initial_bound) 106 | 107 | logging.info('Lower bound: %s', final_bound.lower) 108 | logging.info('Upper bound: %s', final_bound.upper) 109 | logging.info('Clean predictions: %s', clean_preds) 110 | 111 | assert jnp.all(final_bound.lower <= clean_preds), 'Invalid lower bounds' 112 | assert jnp.all(final_bound.upper >= clean_preds), 'Invalid upper bounds' 113 | 114 | 115 | if __name__ == '__main__': 116 | app.run(main) 117 | -------------------------------------------------------------------------------- /examples/run_examples.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | # Run examples with various flags and make sure they don't crash. 17 | # This script is used for continuous integration testing. 
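# Every command below runs one example configuration end-to-end; together with
# `set -e` just below, the script aborts as soon as any invocation fails.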
18 | 19 | set -e # Exit on any error 20 | 21 | echo "Running run_boundprop.py" 22 | python3 run_boundprop.py 23 | python3 run_boundprop.py --model=cnn 24 | python3 run_boundprop.py --boundprop_method=interval_bound_propagation 25 | python3 run_boundprop.py --boundprop_method=ibpforwardfastlin_bound_propagation 26 | python3 run_boundprop.py --boundprop_method=backward_fastlin_bound_propagation 27 | python3 run_boundprop.py --boundprop_method=backward_crown_bound_propagation 28 | python3 run_boundprop.py --boundprop_method=crownibp_bound_propagation 29 | 30 | echo "Running run_sdp_verify.py" 31 | python3 run_sdp_verify.py --model_name=models/cifar10_wongsmall_eps2_mix.pkl \ 32 | --anneal_lengths="3,3" 33 | python3 run_sdp_verify.py --epsilon=0.1 --dataset=mnist \ 34 | --model_name=models/raghunathan18_pgdnn.pkl --use_exact_eig_train=True \ 35 | --use_exact_eig_eval=True --opt_name=adam --lam_coeff=0.1 --nu_coeff=0.03 \ 36 | --anneal_lengths="3,3" --custom_kappa_coeff=10000 --kappa_zero_after=2 37 | 38 | echo "Running run_lp_solver.py" 39 | python3 run_lp_solver.py --model=toy 40 | -------------------------------------------------------------------------------- /examples/run_lp_solver.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Run verification with out-of-the-box LP solver. 17 | 18 | This example uses jax_verify to generate Linear Program (LP) constraints 19 | expressed in CVXPY, which is then solved with a generic LP solver. 20 | 21 | Note that this CVXPY example is purely illustrative - it incurs a large overhead 22 | for defining the problem, since CVXPY struggles with the large number of 23 | constraints, particularly with convolutional layers. We will release more 24 | performant implementations with other LP solvers in the future. We also welcome 25 | contributions. 
26 | """ 27 | import functools 28 | import pickle 29 | from absl import app 30 | from absl import flags 31 | from absl import logging 32 | import jax.numpy as jnp 33 | import jax_verify 34 | from jax_verify.extensions.sdp_verify import utils 35 | from jax_verify.src.linear import forward_linear_bounds 36 | from jax_verify.src.mip_solver.solve_relaxation import solve_planet_relaxation 37 | import numpy as np 38 | 39 | MLP_PATH = 'models/raghunathan18_pgdnn.pkl' 40 | CNN_PATH = 'models/mnist_wongsmall_eps_10_adv.pkl' 41 | 42 | flags.DEFINE_string('model', 'mlp', 'mlp or cnn or toy') 43 | flags.DEFINE_string('boundprop_method', 'ibp', 'ibp or fastlin') 44 | FLAGS = flags.FLAGS 45 | 46 | 47 | def load_model(model_name): 48 | """Load model parameters and prediction function.""" 49 | # Choose appropriate prediction function 50 | if model_name in ('mlp', 'toy'): 51 | model_path = MLP_PATH 52 | def model_fn(params, inputs): 53 | inputs = np.reshape(inputs, (inputs.shape[0], -1)) 54 | return utils.predict_mlp(params, inputs) 55 | elif model_name == 'cnn': 56 | model_path = CNN_PATH 57 | model_fn = utils.predict_cnn 58 | else: 59 | raise ValueError('') 60 | 61 | # Get parameters 62 | if model_name == 'toy': 63 | params = [ 64 | (np.random.normal(size=(784, 2)), np.random.normal(size=(2,))), 65 | (np.random.normal(size=(2, 10)), np.random.normal(size=(10,))), 66 | ] 67 | else: 68 | with jax_verify.open_file(model_path, 'rb') as f: 69 | params = pickle.load(f) 70 | return model_fn, params 71 | 72 | 73 | def main(unused_args): 74 | 75 | # Load the parameters of an existing model. 76 | model_pred, params = load_model(FLAGS.model) 77 | logits_fn = functools.partial(model_pred, params) 78 | 79 | # Load some test samples 80 | with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f: 81 | inputs = np.load(f) 82 | 83 | # Compute boundprop bounds 84 | eps = 0.1 85 | lower_bound = jnp.minimum(jnp.maximum(inputs[:2, ...] - eps, 0.0), 1.0) 86 | upper_bound = jnp.minimum(jnp.maximum(inputs[:2, ...] + eps, 0.0), 1.0) 87 | init_bound = jax_verify.IntervalBound(lower_bound, upper_bound) 88 | 89 | if FLAGS.boundprop_method == 'forwardfastlin': 90 | final_bound = jax_verify.forward_fastlin_bound_propagation(logits_fn, 91 | init_bound) 92 | boundprop_transform = forward_linear_bounds.forward_fastlin_transform 93 | elif FLAGS.boundprop_method == 'ibp': 94 | final_bound = jax_verify.interval_bound_propagation(logits_fn, init_bound) 95 | boundprop_transform = jax_verify.ibp_transform 96 | else: 97 | raise NotImplementedError('Only ibp/fastlin boundprop are' 98 | 'currently supported') 99 | 100 | dummy_output = model_pred(params, inputs) 101 | 102 | # Run LP solver 103 | objective = jnp.where(jnp.arange(dummy_output[0, ...].size) == 0, 104 | jnp.ones_like(dummy_output[0, ...]), 105 | jnp.zeros_like(dummy_output[0, ...])) 106 | objective_bias = 0. 
107 |   value, _, status = solve_planet_relaxation(
108 |       logits_fn, init_bound, boundprop_transform, objective,
109 |       objective_bias, index=0)
110 |   logging.info('Relaxation LB is : %f, Status is %s', value, status)
111 |   value, _, status = solve_planet_relaxation(
112 |       logits_fn, init_bound, boundprop_transform, -objective,
113 |       objective_bias, index=0)
114 |   logging.info('Relaxation UB is : %f, Status is %s', -value, status)
115 | 
116 |   logging.info('Boundprop LB is : %f', final_bound.lower[0, 0])
117 |   logging.info('Boundprop UB is : %f', final_bound.upper[0, 0])
118 | 
119 | 
120 | if __name__ == '__main__':
121 |   app.run(main)
122 | 
--------------------------------------------------------------------------------
/jax_verify/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Library to perform verification on Neural Networks."""
17 | 
18 | from jax_verify.src.bound_propagation import IntervalBound
19 | from jax_verify.src.ibp import bound_transform as ibp_transform
20 | from jax_verify.src.ibp import interval_bound_propagation
21 | from jax_verify.src.intersection import IntersectionBoundTransform
22 | from jax_verify.src.linear.backward_crown import backward_crown_bound_propagation
23 | from jax_verify.src.linear.backward_crown import backward_fastlin_bound_propagation
24 | from jax_verify.src.linear.backward_crown import crownibp_bound_propagation
25 | from jax_verify.src.linear.forward_linear_bounds import forward_crown_bound_propagation
26 | from jax_verify.src.linear.forward_linear_bounds import forward_fastlin_bound_propagation
27 | from jax_verify.src.linear.forward_linear_bounds import ibpforwardfastlin_bound_propagation
28 | from jax_verify.src.nonconvex.methods import nonconvex_constopt_bound_propagation
29 | from jax_verify.src.nonconvex.methods import nonconvex_ibp_bound_propagation
30 | from jax_verify.src.utils import open_file
31 | 
--------------------------------------------------------------------------------
/jax_verify/extensions/functional_lagrangian/README.md:
--------------------------------------------------------------------------------
1 | # Functional Lagrangian Neural Network Verification
2 | 
3 | This directory provides an implementation of the Functional Lagrangian framework from [Berrada et al 2021](https://arxiv.org/abs/2102.09479).
4 | 
5 | The `run` sub-directory contains the necessary code to reproduce the results of our paper, namely configuration files (in `run/configs/`) that specify the verification problem at hand and its various parameters, and a script (`run_functional_lagrangian.py`) that solves that problem (importing code from the rest of the codebase).
6 | 
7 | ## Running the Code
8 | 
9 | First make sure that:
10 | 
11 | 1. you have installed the `jax_verify` package.
12 | 2. your current directory is `extensions/functional_lagrangian/run`.
13 | 
14 | Then the results of our paper can be reproduced using the commands provided below. Note that each command verifies a single sample for a single label; the full paper results can be obtained by iterating over the samples and labels.
15 | 
16 | **Note:** for each experiment, the required data and model parameters are downloaded to `/tmp/jax_verify` by default. This can be changed by modifying `config.assets_dir` in the config files.
17 | 
18 | ### Robust OOD Detection on Stochastic Neural Networks
19 | 
20 | #### MLP on MNIST
21 | 
22 | ```bash
23 | python3 run_functional_lagrangian.py --config=configs/config_ood_stochastic_model.py:mnist_mlp_2_256
24 | ```
25 | 
26 | #### LeNet on MNIST
27 | 
28 | ```bash
29 | python3 run_functional_lagrangian.py --config=configs/config_ood_stochastic_model.py:mnist_cnn
30 | ```
31 | 
32 | #### VGG on CIFAR
33 | 
34 | Example for a VGG-32 (other variants also implemented):
35 | 
36 | ```bash
37 | python3 run_functional_lagrangian.py --config=configs/config_ood_stochastic_model.py:cifar_vgg_32
38 | ```
39 | 
40 | ### Adversarial Robustness for Stochastic Neural Networks
41 | 
42 | Example for an MLP with 2 layers and 256 neurons (other variants also implemented):
43 | 
44 | ```bash
45 | python3 run_functional_lagrangian.py --config=configs/config_adv_stochastic_model.py:mnist_mlp_2_256
46 | ```
47 | 
48 | ### Distributionally Robust OOD Detection
49 | 
50 | ```bash
51 | python3 run_functional_lagrangian.py --config=configs/config_ood_stochastic_input.py
52 | ```
53 | 
54 | ## Citing
55 | 
56 | If you find this code useful, we would appreciate it if you cite our paper:
57 | 
58 | ```
59 | @article{berrada2021funclag,
60 |   title={Make Sure You're Unsure: A Framework for Verifying Probabilistic Specifications},
61 |   author={Berrada, Leonard and Dathathri, Sumanth and Dvijotham, Krishnamurthy and Stanforth, Robert and Bunel, Rudy and Uesato, Jonathan and Gowal, Sven and Kumar, M. Pawan},
62 |   journal={NeurIPS},
63 |   year={2021}
64 | }
65 | ```
66 | 
--------------------------------------------------------------------------------
/jax_verify/extensions/functional_lagrangian/data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 16 | """Data util functions.""" 17 | 18 | import os 19 | import pickle 20 | from typing import Sequence, Tuple 21 | 22 | import jax.numpy as jnp 23 | import jax_verify 24 | from jax_verify.extensions.functional_lagrangian import verify_utils 25 | from jax_verify.extensions.sdp_verify import utils as sdp_utils 26 | from jax_verify.src import utils as jv_utils 27 | import ml_collections 28 | import numpy as np 29 | 30 | 31 | ConfigDict = ml_collections.ConfigDict 32 | DataSpec = verify_utils.DataSpec 33 | IntervalBound = jax_verify.IntervalBound 34 | SpecType = verify_utils.SpecType 35 | Tensor = jnp.array 36 | LayerParams = verify_utils.LayerParams 37 | ModelParams = verify_utils.ModelParams 38 | ModelParamsElided = verify_utils.ModelParamsElided 39 | 40 | 41 | DATA_PATH = ml_collections.ConfigDict({ 42 | 'emnist_CEDA': 'emnist_CEDA.pkl', 43 | 'mnist': 'mnist', 44 | 'cifar10': 'cifar10', 45 | 'emnist': 'emnist', 46 | 'cifar100': 'cifar100', 47 | }) 48 | 49 | 50 | def load_dataset( 51 | root_dir: str, 52 | dataset: str, 53 | ) -> Tuple[Sequence[np.ndarray], Sequence[np.ndarray]]: 54 | """Loads the MNIST/CIFAR/EMNIST test set examples, saved as numpy arrays.""" 55 | 56 | data_path = DATA_PATH.get(dataset) 57 | 58 | if dataset == 'emnist_CEDA': 59 | with jv_utils.open_file(data_path, 'rb', root_dir=root_dir) as f: 60 | ds = pickle.load(f) 61 | xs, ys = ds[0], ds[1] 62 | xs = np.reshape(xs, [-1, 28, 28, 1]) 63 | ys = np.reshape(ys, [-1]) 64 | return xs, ys 65 | else: 66 | x_filename = os.path.join(data_path, 'x_test.npy') 67 | y_filename = os.path.join(data_path, 'y_test.npy') 68 | with jv_utils.open_file(x_filename, 'rb', root_dir=root_dir) as f: 69 | xs = np.load(f) 70 | with jv_utils.open_file(y_filename, 'rb', root_dir=root_dir) as f: 71 | ys = np.load(f) 72 | return xs, ys 73 | 74 | 75 | def make_data_spec(config_problem: ConfigDict, root_dir: str) -> DataSpec: 76 | """Create data specification from config_problem.""" 77 | xs, ys = load_dataset(root_dir, config_problem.dataset) 78 | if config_problem.dataset in ('cifar10', 'cifar100'): 79 | x = sdp_utils.preprocess_cifar(xs[config_problem.dataset_idx]) 80 | epsilon, input_bounds = sdp_utils.preprocessed_cifar_eps_and_input_bounds( 81 | shape=x.shape, 82 | epsilon=config_problem.epsilon_unprocessed, 83 | inception_preprocess=config_problem.scale_center) 84 | else: 85 | x = xs[config_problem.dataset_idx] 86 | epsilon = config_problem.epsilon_unprocessed 87 | input_bounds = (jnp.zeros_like(x), jnp.ones_like(x)) 88 | true_label = ys[config_problem.dataset_idx] 89 | target_label = config_problem.target_label_idx 90 | return DataSpec( 91 | input=x, 92 | true_label=true_label, 93 | target_label=target_label, 94 | epsilon=epsilon, 95 | input_bounds=input_bounds) 96 | -------------------------------------------------------------------------------- /jax_verify/extensions/functional_lagrangian/inner_solvers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Inner solvers.""" 17 | 18 | from jax_verify.extensions.functional_lagrangian.inner_solvers.get_strategy import get_strategy 19 | -------------------------------------------------------------------------------- /jax_verify/extensions/functional_lagrangian/inner_solvers/get_strategy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Inner solvers.""" 17 | 18 | from jax_verify.extensions.functional_lagrangian.inner_solvers import input_uncertainty_spec 19 | from jax_verify.extensions.functional_lagrangian.inner_solvers import lp 20 | from jax_verify.extensions.functional_lagrangian.inner_solvers import mixed 21 | from jax_verify.extensions.functional_lagrangian.inner_solvers import pga 22 | from jax_verify.extensions.functional_lagrangian.inner_solvers import uncertainty_spec 23 | 24 | 25 | def get_strategy(config, params, mode): 26 | """Returns configured strategy for inner maximisation.""" 27 | 28 | return _build_strategy_recursively(config.inner_opt.get(mode), params) 29 | 30 | 31 | def _build_strategy_recursively(config_inner_opt, params): 32 | """Create inner solver strategy (potentially recursively).""" 33 | 34 | optim_type = config_inner_opt['optim_type'] 35 | 36 | if optim_type == 'pga': 37 | strategy = pga.PgaStrategy( 38 | n_iter=config_inner_opt['n_iter'], 39 | lr=config_inner_opt['lr'], 40 | n_restarts=config_inner_opt['n_restarts'], 41 | method=config_inner_opt['method'], 42 | finetune_n_iter=config_inner_opt['finetune_n_iter'], 43 | finetune_lr=config_inner_opt['finetune_lr'], 44 | finetune_method=config_inner_opt['finetune_method'], 45 | normalize=config_inner_opt['normalize']) 46 | elif optim_type == 'lp': 47 | strategy = lp.LpStrategy() 48 | elif optim_type == 'probability_threshold': 49 | strategy = input_uncertainty_spec.ProbabilityThresholdSpecStrategy() 50 | elif optim_type == 'uncertainty': 51 | solve_max = {f.value: f for f in uncertainty_spec.MaxType 52 | }[config_inner_opt.get('solve_max')] 53 | strategy = uncertainty_spec.UncertaintySpecStrategy( 54 | n_iter=config_inner_opt.get('n_iter'), 55 | n_pieces=config_inner_opt.get('n_pieces'), 56 | solve_max=solve_max, 57 | learning_rate=config_inner_opt.get('learning_rate'), 58 | ) 59 | elif optim_type == 'uncertainty_input': 60 | layer_type = {f.value: f for f in input_uncertainty_spec.LayerType 61 | }[config_inner_opt.get('layer_type')] 62 | sig_max = config_inner_opt.get('sig_max') 63 | strategy = input_uncertainty_spec.InputUncertaintySpecStrategy( 64 | layer_type=layer_type, sig_max=sig_max) 65 | elif optim_type == 'mixed': 66 | solvers = [[ 67 | _build_strategy_recursively(strat, params) for strat in strats_for_layer 68 | ] for strats_for_layer in config_inner_opt['mixed_strat']] 69 | strategy 
= mixed.MixedStrategy( 70 | solvers=solvers, solver_weights=config_inner_opt['solver_weights']) 71 | else: 72 | raise NotImplementedError( 73 | f'Unsupported optim type {config_inner_opt["optim_type"]}') 74 | 75 | if (any(p.has_bounds for p in params) and 76 | not strategy.supports_stochastic_parameters()): 77 | # this is a conservative check: we fail if *any* parameter is 78 | # stochastic, although it might not actually be used by strategy 79 | raise ValueError('Inner opt cannot handle stochastic parameters.') 80 | 81 | return strategy 82 | -------------------------------------------------------------------------------- /jax_verify/extensions/functional_lagrangian/inner_solvers/lp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Solving linear problems.""" 17 | 18 | from typing import Any 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | from jax_verify.extensions.functional_lagrangian import dual_build 23 | from jax_verify.extensions.functional_lagrangian import lagrangian_form as lag_form 24 | from jax_verify.extensions.functional_lagrangian import verify_utils 25 | from jax_verify.extensions.sdp_verify import utils as sdp_utils 26 | 27 | InnerVerifInstance = verify_utils.InnerVerifInstance 28 | 29 | 30 | class LpStrategy(dual_build.InnerMaxStrategy): 31 | """Solves inner maximisations (for linear Lagrangian) in closed form.""" 32 | 33 | def supports_stochastic_parameters(self): 34 | # can use expectations of parameters instead of deterministic parameters 35 | return True 36 | 37 | def solve_max( 38 | self, 39 | inner_dual_vars: Any, 40 | opt_instance: InnerVerifInstance, 41 | key: jnp.ndarray, 42 | step: int, 43 | ) -> jnp.ndarray: 44 | """Solve maximization problem of opt_instance in closed form. 45 | 46 | Args: 47 | inner_dual_vars: Dual variables for the inner maximisation. 48 | opt_instance: Verification instance that defines optimization problem to 49 | be solved. 50 | key: Jax PRNG key. 51 | step: outer optimization iteration number 52 | 53 | Returns: 54 | max_value: final value of the objective function found. 
55 |     """
56 |     if opt_instance.affine_before_relu:
57 |       raise ValueError('LpStrategy requires affine_before_relu to be False.')
58 | 
59 |     if not opt_instance.same_lagrangian_form_pre_post:
60 |       raise ValueError('Different lagrangian forms on inputs and outputs not '
61 |                        'supported')
62 | 
63 |     if (isinstance(opt_instance.lagrangian_form_pre, lag_form.Linear) or
64 |         isinstance(opt_instance.lagrangian_form_post, lag_form.Linear)):
65 |       pass
66 |     else:
67 |       raise ValueError('LpStrategy cannot use Lagrangian form of type '
68 |                        f'{type(opt_instance.lagrangian_form_pre)}.')
69 | 
70 |     # some renaming to simplify variable names
71 |     affine_fn, = opt_instance.affine_fns
72 |     bounds = opt_instance.bounds
73 |     duals_pre = opt_instance.lagrange_params_pre
74 |     if (opt_instance.is_last and
75 |         opt_instance.spec_type == verify_utils.SpecType.ADVERSARIAL):
76 |       # No duals_post for last layer, and objective folded in.
77 |       batch_size = bounds[0].lb.shape[0]
78 |       duals_post = jnp.ones([batch_size])
79 |     else:
80 |       duals_post = opt_instance.lagrange_params_post
81 | 
82 |     if opt_instance.is_first:
83 |       # no "pre-activation" for input of first layer
84 |       lb = bounds[0].lb
85 |       ub = bounds[0].ub
86 |     else:
87 |       lb = bounds[0].lb_pre
88 |       ub = bounds[0].ub_pre
89 | 
90 |     zero_inputs = jnp.zeros_like(lb)
91 |     affine_constant = affine_fn(zero_inputs)
92 |     duals_post = jnp.reshape(duals_post, affine_constant.shape)
93 | 
94 |     post_slope_x = jax.grad(lambda x: jnp.sum(affine_fn(x) * duals_post))(
95 |         zero_inputs)
96 | 
97 |     if opt_instance.is_first:
98 |       # find max element-wise (separable problem): either at lower bound or
99 |       # upper bound -- no duals_pre for first layer
100 |       max_per_element = jnp.maximum(
101 |           post_slope_x * lb,
102 |           post_slope_x * ub,
103 |       )
104 |     else:
105 |       # find max element-wise (separable problem): either at lower bound, 0 or
106 |       # upper bound
107 |       duals_pre = jnp.reshape(duals_pre, lb.shape)
108 |       max_per_element_bounds = jnp.maximum(
109 |           post_slope_x * jax.nn.relu(lb) - duals_pre * lb,
110 |           post_slope_x * jax.nn.relu(ub) - duals_pre * ub
111 |       )
112 |       max_per_element = jnp.where(
113 |           jnp.logical_and(lb <= 0, ub >= 0),
114 |           jax.nn.relu(max_per_element_bounds),  # include zero where feasible
115 |           max_per_element_bounds)  # otherwise only at boundaries
116 |     # sum over coordinates and add constant term (does not change max choice)
117 |     max_value = jnp.sum(max_per_element,
118 |                         axis=tuple(range(1, max_per_element.ndim)))
119 |     constant_per_element = affine_constant * duals_post
120 |     constant = jnp.sum(constant_per_element,
121 |                        axis=tuple(range(1, constant_per_element.ndim)))
122 |     return max_value + constant
123 | 
124 |   def init_layer_inner_params(self, opt_instance):
125 |     """Returns initial inner maximisation duals and their types."""
126 |     # no need for auxiliary variables
127 |     return None, sdp_utils.DualVarTypes.EQUALITY
128 | 
--------------------------------------------------------------------------------
/jax_verify/extensions/functional_lagrangian/inner_solvers/mixed.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Mixture of strategies for solving the inner maximization.""" 17 | from typing import Any 18 | 19 | import jax.numpy as jnp 20 | 21 | from jax_verify.extensions.functional_lagrangian import dual_build 22 | from jax_verify.extensions.functional_lagrangian import verify_utils 23 | 24 | InnerVerifInstance = verify_utils.InnerVerifInstance 25 | 26 | 27 | class MixedStrategy(dual_build.InnerMaxStrategy): 28 | """Solves inner maximisations with a combination of solvers.""" 29 | 30 | def __init__(self, solvers, solver_weights): 31 | self._solvers = solvers 32 | self._solver_weights = solver_weights 33 | 34 | def solve_max( 35 | self, 36 | inner_dual_vars: Any, 37 | opt_instance: InnerVerifInstance, 38 | key: jnp.ndarray, 39 | step: int, 40 | ) -> jnp.ndarray: 41 | """Solve maximization problem of opt_instance with a combination of solvers. 42 | 43 | Args: 44 | inner_dual_vars: Dual variables for the inner maximisation. 45 | opt_instance: Verification instance that defines optimization problem to 46 | be solved. 47 | key: Jax PRNG key. 48 | step: outer optimization iteration number. 49 | 50 | Returns: 51 | final_value: final value of the objective function found by PGA. 52 | """ 53 | # some renaming to simplify variable names 54 | layer_idx = opt_instance.idx 55 | solver_weights_for_layer = self._solver_weights[layer_idx] 56 | solvers_for_layer = self._solvers[layer_idx] 57 | final_value = 0. 58 | for solver, solver_weight, inner_var in zip(solvers_for_layer, 59 | solver_weights_for_layer, 60 | inner_dual_vars): 61 | final_value += solver_weight * solver.solve_max(inner_var, opt_instance, 62 | key, step) 63 | return final_value # pytype: disable=bad-return-type # jnp-array 64 | 65 | def init_layer_inner_params(self, opt_instance): 66 | """Returns initial inner maximisation duals and their types.""" 67 | 68 | dual_vars_types = [ 69 | solver.init_layer_inner_params(opt_instance) 70 | for solver in self._solvers[opt_instance.idx] 71 | ] 72 | return zip(*dual_vars_types) 73 | 74 | def supports_stochastic_parameters(self): 75 | for solvers_for_layer in self._solvers: 76 | for solver in solvers_for_layer: 77 | if not solver.supports_stochastic_parameters(): 78 | return False 79 | return True 80 | -------------------------------------------------------------------------------- /jax_verify/extensions/functional_lagrangian/inner_solvers/pga/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Projected gradient ascent.""" 17 | 18 | from jax_verify.extensions.functional_lagrangian.inner_solvers.pga import pga_strategy 19 | 20 | PgaStrategy = pga_strategy.PgaStrategy 21 | -------------------------------------------------------------------------------- /jax_verify/extensions/functional_lagrangian/inner_solvers/pga/optimizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Optimizers used in the PGA strategy.""" 17 | 18 | import collections 19 | from typing import Callable, Optional, Tuple 20 | 21 | import chex 22 | import jax 23 | import jax.numpy as jnp 24 | 25 | from jax_verify.extensions.functional_lagrangian.inner_solvers.pga import utils 26 | 27 | _State = collections.namedtuple('State', ['iteration', 'rng', 'state']) # pylint: disable=invalid-name 28 | 29 | 30 | def grad_fn( 31 | loss_fn: utils.LossFn, 32 | ) -> Callable[[chex.Array], Tuple[chex.Array, chex.Array]]: 33 | """Returns the analytical gradient as computed by `jax.grad`.""" 34 | 35 | def reduced_loss_fn(x): 36 | loss = loss_fn(x) 37 | return jnp.sum(loss), loss 38 | 39 | return jax.grad(reduced_loss_fn, has_aux=True) 40 | 41 | 42 | class IteratedFGSM: 43 | """L-infinity normalized steps.""" 44 | 45 | def __init__(self, learning_rate: chex.Numeric): 46 | self._learning_rate = learning_rate 47 | 48 | def init(self, loss_fn: utils.LossFn, rng: chex.PRNGKey, 49 | x: chex.Array) -> _State: 50 | del x 51 | self._loss_fn = loss_fn 52 | return _State(jnp.array(0, dtype=jnp.int32), rng, ()) 53 | 54 | def minimize(self, x: chex.Array, 55 | state: _State) -> Tuple[chex.Array, chex.Array, _State]: 56 | """Performs a single minimization step.""" 57 | lr = jnp.array(self._learning_rate) 58 | g, loss = grad_fn(self._loss_fn)(x) 59 | if g is None: 60 | raise ValueError('loss_fn does not depend on input.') 61 | g = jnp.sign(g) 62 | g, s = self._update(lr, g, state.state) 63 | new_state = _State(state.iteration + 1, state.rng, s) 64 | return x - g, loss, new_state 65 | 66 | def _update( 67 | self, 68 | learning_rate: chex.Numeric, 69 | gradients: chex.Array, 70 | state: chex.Array, 71 | ) -> Tuple[chex.Array, chex.Array]: 72 | return learning_rate.astype(gradients.dtype) * gradients, state # pytype: disable=attribute-error # numpy-scalars 73 | 74 | 75 | class PGD: 76 | """Uses the above defined optimizers to minimize and loss function.""" 77 | 78 | def __init__( 79 | self, 80 | optimizer, 81 | num_steps: int, 82 | initialize_fn: Optional[utils.InitializeFn] = None, 83 | project_fn: Optional[utils.ProjectFn] = None, 84 | ): 85 | self._optimizer = optimizer 86 | if initialize_fn is None: 87 | initialize_fn = lambda rng, x: x 88 | self._initialize_fn = initialize_fn 89 | if project_fn is None: 90 | project_fn 
= lambda x, origin_x: x 91 | self._project_fn = project_fn 92 | self._num_steps = num_steps 93 | 94 | def __call__( 95 | self, 96 | loss_fn: utils.LossFn, 97 | rng: chex.PRNGKey, 98 | x: chex.Array, 99 | ) -> chex.Array: 100 | 101 | def _optimize(rng, x): 102 | """Optimizes loss_fn.""" 103 | 104 | def body_fn(_, inputs): 105 | opt_state, current_x = inputs 106 | current_x, _, opt_state = self._optimizer.minimize(current_x, opt_state) 107 | current_x = self._project_fn(current_x, x) 108 | return opt_state, current_x 109 | 110 | rng, next_rng = jax.random.split(rng) 111 | opt_state = self._optimizer.init(loss_fn, next_rng, x) 112 | current_x = self._project_fn(self._initialize_fn(rng, x), x) 113 | _, current_x = jax.lax.fori_loop(0, self._num_steps, body_fn, 114 | (opt_state, current_x)) 115 | return current_x 116 | 117 | x = _optimize(rng, x) 118 | return jax.lax.stop_gradient(x) 119 | 120 | 121 | class Restarted: 122 | """Repeats an optimization multiple times.""" 123 | 124 | def __init__( 125 | self, 126 | optimizer, 127 | restarts_using_tiling: int = 1, 128 | has_batch_dim: bool = True, 129 | ): 130 | self._wrapped_optimizer = optimizer 131 | if (isinstance(restarts_using_tiling, int) and restarts_using_tiling > 1 and 132 | not has_batch_dim): 133 | raise ValueError('Cannot use tiling when `has_batch_dim` is False.') 134 | self._has_batch_dim = has_batch_dim 135 | if (isinstance(restarts_using_tiling, int) and restarts_using_tiling < 1): 136 | raise ValueError('Fewer than one restart requested.') 137 | self._restarts_using_tiling = restarts_using_tiling 138 | 139 | def __call__( 140 | self, 141 | loss_fn: utils.LossFn, 142 | rng: chex.PRNGKey, 143 | inputs: chex.Array, 144 | ) -> chex.Array: 145 | """Performs an optimization multiple times by tiling the inputs.""" 146 | if not self._has_batch_dim: 147 | opt_inputs = self._wrapped_optimizer(loss_fn, rng, inputs) 148 | opt_losses = loss_fn(opt_inputs) 149 | return opt_inputs, opt_losses # pytype: disable=bad-return-type # numpy-scalars 150 | 151 | # Tile the inputs and labels. 152 | batch_size = inputs.shape[0] 153 | 154 | # Tile inputs. 155 | shape = inputs.shape[1:] 156 | # Shape is [num_restarts * batch_size, ...]. 157 | inputs = jnp.tile(inputs, [self._restarts_using_tiling] + [1] * len(shape)) 158 | 159 | # Optimize. 160 | opt_inputs = self._wrapped_optimizer(loss_fn, rng, inputs) 161 | opt_losses = loss_fn(opt_inputs) 162 | opt_losses = jnp.reshape(opt_losses, 163 | [self._restarts_using_tiling, batch_size]) 164 | 165 | # Extract best. 166 | i = jnp.argmin(opt_losses, axis=0) 167 | j = jnp.arange(batch_size) 168 | 169 | shape = opt_inputs.shape[1:] 170 | return jnp.reshape(opt_inputs, 171 | (self._restarts_using_tiling, batch_size) + shape)[i, j] 172 | -------------------------------------------------------------------------------- /jax_verify/extensions/functional_lagrangian/inner_solvers/pga/square.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Implementation of Square (https://arxiv.org/pdf/1912.00049).""" 17 | 18 | from typing import Callable, List, Tuple 19 | 20 | import chex 21 | import jax 22 | import jax.numpy as jnp 23 | 24 | from jax_verify.extensions.functional_lagrangian.inner_solvers.pga import utils 25 | 26 | 27 | def _schedule(values: List[float], 28 | boundaries: List[int], 29 | dtype=jnp.float32) -> Callable[[chex.Array], chex.Numeric]: 30 | """Schedule the value of p, the proportion of elements to be modified.""" 31 | large_step = max(boundaries) + 1 32 | boundaries = boundaries + [large_step, large_step + 1] 33 | num_values = len(values) 34 | values = jnp.array(values, dtype=jnp.float32) 35 | large_step = jnp.array([large_step] * len(boundaries), dtype=jnp.int32) 36 | boundaries = jnp.array(boundaries, dtype=jnp.int32) 37 | 38 | def _get(step): 39 | """Returns the value according to the current step and schedule.""" 40 | b = boundaries - jnp.minimum(step + 1, large_step + 1) 41 | b = jnp.where(b < 0, large_step, b) 42 | idx = jnp.minimum(jnp.argmin(b), num_values - 1) 43 | return values[idx].astype(dtype) 44 | 45 | return _get 46 | 47 | 48 | class Square: 49 | """Performs a blackbox optimization as in https://arxiv.org/pdf/1912.00049.""" 50 | 51 | def __init__( 52 | self, 53 | num_steps: int, 54 | epsilon: chex.Numeric, 55 | initialize_fn: utils.InitializeFn, 56 | bounds: Tuple[chex.ArrayTree, chex.ArrayTree], 57 | ): 58 | """Creates a Square attack.""" 59 | self._num_steps = num_steps 60 | self._initialize_fn = initialize_fn 61 | self._project_fn = utils.linf_project_fn(epsilon=epsilon, bounds=bounds) 62 | self._epsilon = epsilon 63 | self._p_init = p = .8 64 | self._p_schedule = _schedule([ 65 | p, p / 2, p / 4, p / 4, p / 8, p / 16, p / 32, p / 64, p / 128, p / 256, 66 | p / 512 67 | ], [10, 50, 200, 500, 1000, 2000, 4000, 6000, 8000]) 68 | 69 | def __call__( 70 | self, 71 | loss_fn: utils.LossFn, 72 | rng: chex.PRNGKey, 73 | x: chex.Array, 74 | ) -> chex.Array: 75 | if len(x.shape) != 4: 76 | raise ValueError(f'Unsupported tensor shape: {x.shape}') 77 | h, w, c = x.shape[1:] 78 | batch_size = x.shape[0] 79 | broadcast_shape = [batch_size] + [1] * (len(x.shape) - 1) 80 | min_size = 1 81 | 82 | def init_fn(rng): 83 | init_x = self._project_fn(self._initialize_fn(rng, x), x) 84 | init_loss = loss_fn(init_x) 85 | return init_x, init_loss 86 | 87 | def random_window_mask(rng, size, dtype): 88 | height_rng, width_rng = jax.random.split(rng) 89 | height_offset = jax.random.randint( 90 | height_rng, 91 | shape=(batch_size, 1, 1, 1), 92 | minval=0, 93 | maxval=h - size, 94 | dtype=jnp.int32) 95 | width_offset = jax.random.randint( 96 | width_rng, 97 | shape=(batch_size, 1, 1, 1), 98 | minval=0, 99 | maxval=w - size, 100 | dtype=jnp.int32) 101 | h_range = jnp.reshape(jnp.arange(h), [1, h, 1, 1]) 102 | w_range = jnp.reshape(jnp.arange(w), [1, 1, w, 1]) 103 | return jnp.logical_and( 104 | jnp.logical_and(height_offset <= h_range, 105 | h_range < height_offset + size), 106 | jnp.logical_and(width_offset <= w_range, 107 | w_range < width_offset + size)).astype(dtype) 108 | 109 | def random_linf_perturbation(rng, x, size): 110 | rng, perturbation_rng = jax.random.split(rng) 111 | perturbation = jax.random.randint( 112 | perturbation_rng, shape=(batch_size, 1, 1, c), minval=0, 113 | maxval=2) * 2 - 1 114 | return random_window_mask(rng, size, x.dtype) * perturbation 115 | 116 | def body_fn(i, 
loop_inputs): 117 | best_x, best_loss, rng = loop_inputs 118 | 119 | p = self._get_p(i) 120 | size = jnp.maximum( 121 | jnp.round(jnp.sqrt(p * h * w / c)).astype(jnp.int32), min_size) 122 | rng, next_rng = jax.random.split(rng) 123 | 124 | perturbation = random_linf_perturbation(next_rng, best_x, size) 125 | current_x = best_x + perturbation * self._epsilon 126 | 127 | current_x = self._project_fn(current_x, x) 128 | loss = loss_fn(current_x) 129 | 130 | cond = loss < best_loss 131 | best_x = jnp.where(jnp.reshape(cond, broadcast_shape), current_x, best_x) 132 | best_loss = jnp.where(cond, loss, best_loss) 133 | return best_x, best_loss, rng 134 | 135 | rng, next_rng = jax.random.split(rng) 136 | best_x, best_loss = init_fn(next_rng) 137 | loop_inputs = (best_x, best_loss, rng) 138 | return jax.lax.fori_loop(0, self._num_steps, body_fn, loop_inputs)[0] 139 | 140 | def _get_p(self, step): 141 | """Schedule on `p`.""" 142 | step = step / self._num_steps * 10000. 143 | return self._p_schedule(jnp.array(step)) 144 | -------------------------------------------------------------------------------- /jax_verify/extensions/functional_lagrangian/inner_solvers/pga/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utilities.""" 17 | 18 | from typing import Callable, Optional, Tuple 19 | 20 | import chex 21 | import jax 22 | import jax.numpy as jnp 23 | 24 | InitializeFn = Callable[[chex.Array, chex.Array], chex.Array] 25 | ProjectFn = Callable[[chex.Array, chex.Array], chex.Array] 26 | LossFn = Callable[[chex.Array], chex.Array] 27 | 28 | 29 | def linf_project_fn(epsilon: float, bounds: Tuple[float, float]) -> ProjectFn: 30 | 31 | def project_fn(x, origin_x): 32 | dx = jnp.clip(x - origin_x, -epsilon, epsilon) 33 | return jnp.clip(origin_x + dx, bounds[0], bounds[1]) 34 | 35 | return project_fn 36 | 37 | 38 | def bounded_initialize_fn( 39 | bounds: Optional[Tuple[chex.Array, chex.Array]] = None,) -> InitializeFn: 40 | """Returns an initialization function.""" 41 | if bounds is None: 42 | return noop_initialize_fn() 43 | else: 44 | lower_bound, upper_bound = bounds 45 | 46 | def _initialize_fn(rng, x): 47 | a = jax.random.uniform(rng, x.shape, minval=0., maxval=1.) 48 | x = a * lower_bound + (1. - a) * upper_bound 49 | return x 50 | 51 | return _initialize_fn 52 | 53 | 54 | def noop_initialize_fn() -> InitializeFn: 55 | 56 | def _initialize_fn(rng, x): 57 | del rng 58 | return x 59 | 60 | return _initialize_fn 61 | -------------------------------------------------------------------------------- /jax_verify/extensions/functional_lagrangian/lagrangian_form.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Lagrangian penalty functions."""
17 |
18 | import abc
19 |
20 | from typing import Sequence, Union
21 |
22 | import jax
23 | import jax.numpy as jnp
24 | import jax.random as random
25 | import ml_collections
26 |
27 | PRNGKey = jnp.array
28 | Tensor = jnp.array
29 | Params = Union[Tensor, Sequence[Tensor]]
30 | Shape = Union[int, Sequence[int]]
31 | ConfigDict = ml_collections.ConfigDict
32 |
33 |
34 | def _flatten_spatial_dims(x: Tensor) -> Tensor:
35 |   """Flatten spatial dimensions (assumed batched)."""
36 |   return jnp.reshape(x, [x.shape[0], -1])
37 |
38 |
39 | def size_from_shape(shape: Shape) -> int:
40 |   return int(jnp.prod(jnp.array(shape)))
41 |
42 |
43 | class LagrangianForm(metaclass=abc.ABCMeta):
44 |   """Abstract class for Lagrangian form."""
45 |
46 |   def __init__(self, name):
47 |     self._name = name
48 |
49 |   @abc.abstractmethod
50 |   def _init_params_per_sample(self, key: PRNGKey, *args) -> Params:
51 |     """Initialize the parameters of the Lagrangian form."""
52 |
53 |   def init_params(self, key, *args, **kwargs):
54 |     params = self._init_params_per_sample(key, *args, **kwargs)
55 |     # expansion below currently assumes batch-size of 1
56 |     return jax.tree_map(lambda p: jnp.expand_dims(p, 0), params)
57 |
58 |   @abc.abstractmethod
59 |   def _apply(self, x: Tensor, lagrange_params: Params, step: int) -> Tensor:
60 |     """Apply the Lagrangian form to the input x given lagrange_params."""
61 |
62 |   def apply(self, x: Tensor, lagrange_params: Params, step: int) -> Tensor:
63 |     """Apply the Lagrangian form to the input x given lagrange_params.
64 |
65 |     Args:
66 |       x: layer inputs, assumed batched (in leading dimension). Note that the
67 |         spatial dimensions of x are flattened.
68 |       lagrange_params: parameters of the Lagrangian form, assumed to have
69 |         the same batch-size as x. If provided as None, this function returns 0.
70 |       step: outer optimization iteration number (unused).
71 |
72 |     Returns:
73 |       value_per_sample: Lagrangian penalty per element of the mini-batch.
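    Example (an illustrative sketch using the `Linear` form defined below;
    with its default zero initialization the penalty evaluates to 0):

      form = Linear()
      params = form.init_params(jax.random.PRNGKey(0), l_shape=(4,))
      x = jnp.ones((1, 4))  # mini-batch containing a single flattened sample
      value_per_sample = form.apply(x, params, step=0)  # shape [1]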
74 | """ 75 | if lagrange_params is None: 76 | return jnp.array(0.0) 77 | x = _flatten_spatial_dims(x) 78 | value_per_sample = self._apply(x, lagrange_params, step) 79 | return value_per_sample 80 | 81 | def process_params(self, lagrange_params: Params): 82 | return lagrange_params 83 | 84 | @property 85 | def name(self): 86 | """Return name.""" 87 | return self._name 88 | 89 | 90 | class Linear(LagrangianForm): 91 | """Linear LagrangianForm (equivalent to DeepVerify formulation).""" 92 | 93 | def __init__(self): 94 | super().__init__('Linear') 95 | 96 | def _init_params_per_sample(self, 97 | key: PRNGKey, 98 | l_shape: Shape, 99 | init_zeros: bool = True) -> Params: 100 | size = size_from_shape(l_shape) 101 | if init_zeros: 102 | return jnp.zeros([size]) 103 | else: 104 | return random.normal(key, [size]) 105 | 106 | def _apply_per_sample(self, x: Tensor, lagrange_params: Params, 107 | step: int) -> Tensor: 108 | del step 109 | return jnp.dot(x, lagrange_params) 110 | 111 | def _apply(self, x: Tensor, lagrange_params: Params, step: int) -> Tensor: 112 | apply_per_sample = lambda a, b: self._apply_per_sample(a, b, step) 113 | return jax.vmap(apply_per_sample)(x, lagrange_params) 114 | 115 | 116 | class LinearExp(LagrangianForm): 117 | """LinearExp LagrangianForm.""" 118 | 119 | def __init__(self): 120 | super().__init__('LinearExp') 121 | 122 | def _init_params_per_sample(self, 123 | key: PRNGKey, 124 | l_shape: Shape, 125 | init_zeros: bool = False) -> Params: 126 | size = size_from_shape(l_shape) 127 | if init_zeros: 128 | return jnp.zeros([size]), jnp.ones(()), jnp.zeros([size]) 129 | else: 130 | return (1e-4 * random.normal(key, [size]), 1e-2 * random.normal(key, ()), 131 | 1e-2 * random.normal(key, [size])) 132 | 133 | def _apply_per_sample(self, x: Tensor, lagrange_params: Params, 134 | step: int) -> Tensor: 135 | del step 136 | linear_term = jnp.dot(x, lagrange_params[0]) 137 | lagrange_params = self.process_params(lagrange_params) 138 | exp_term = lagrange_params[1] * jnp.exp(jnp.dot(x, lagrange_params[2])) 139 | 140 | return linear_term + exp_term 141 | 142 | def _apply(self, x: Tensor, lagrange_params: Params, step: int) -> Tensor: 143 | apply_per_sample = lambda a, b: self._apply_per_sample(a, b, step) 144 | return jax.vmap(apply_per_sample)(x, lagrange_params) 145 | 146 | 147 | def get_lagrangian_form(config_lagrangian_form: ConfigDict) -> LagrangianForm: 148 | """Create the Lagrangian form.""" 149 | name = config_lagrangian_form['name'] 150 | kwargs = config_lagrangian_form['kwargs'] 151 | if name == 'linear': 152 | return Linear(**kwargs) 153 | elif name == 'linear_exp': 154 | return LinearExp(**kwargs) 155 | else: 156 | raise NotImplementedError(f'Unrecognized lagrangian functional: {name}') 157 | -------------------------------------------------------------------------------- /jax_verify/extensions/functional_lagrangian/model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """API to load model parameters.""" 17 | 18 | import dataclasses 19 | import os 20 | import pickle 21 | from typing import Any, Optional 22 | import urllib 23 | 24 | import jax.numpy as jnp 25 | import jax_verify 26 | from jax_verify.extensions.functional_lagrangian import verify_utils 27 | from jax_verify.src import utils as jv_utils 28 | import ml_collections 29 | import numpy as np 30 | 31 | ConfigDict = ml_collections.ConfigDict 32 | ModelParams = verify_utils.ModelParams 33 | 34 | INTERNAL_MODEL_PATHS = ml_collections.ConfigDict({ 35 | 'mnist_ceda': 'models/mnist_ceda.pkl', 36 | 'mnist_cnn': 'models/mnist_lenet_dropout.pkl', 37 | 'cifar_vgg_16': 'models/cifar_vgg_16_dropout.pkl', 38 | 'cifar_vgg_32': 'models/cifar_vgg_32_dropout.pkl', 39 | 'cifar_vgg_64': 'models/cifar_vgg_64_dropout.pkl', 40 | }) 41 | 42 | PROBA_SAFETY_URL = ( 43 | 'https://github.com/matthewwicker/ProbabilisticSafetyforBNNs/raw/master' 44 | '/MNIST/concurMNIST2/MNIST_Networks') 45 | 46 | PROBA_SAFETY_MODEL_PATHS = ml_collections.ConfigDict({ 47 | 'mnist_mlp_1_1024': 'VIMODEL_MNIST_1_1024_relu.net.npz', 48 | 'mnist_mlp_1_128': 'VIMODEL_MNIST_1_128_relu.net.npz', 49 | 'mnist_mlp_1_2048': 'VIMODEL_MNIST_1_2048_relu.net.npz', 50 | 'mnist_mlp_1_256': 'VIMODEL_MNIST_1_256_relu.net.npz', 51 | 'mnist_mlp_1_4096': 'VIMODEL_MNIST_1_4096_relu.net.npz', 52 | 'mnist_mlp_1_512': 'VIMODEL_MNIST_1_512_relu.net.npz', 53 | 'mnist_mlp_1_64': 'VIMODEL_MNIST_1_64_relu.net.npz', 54 | 'mnist_mlp_2_1024': 'VIMODEL_MNIST_2_1024_relu.net.npz', 55 | 'mnist_mlp_2_128': 'VIMODEL_MNIST_2_128_relu.net.npz', 56 | 'mnist_mlp_2_256': 'VIMODEL_MNIST_2_256_relu.net.npz', 57 | 'mnist_mlp_2_512': 'VIMODEL_MNIST_2_512_relu.net.npz', 58 | 'mnist_mlp_2_64': 'VIMODEL_MNIST_2_64_relu.net.npz', 59 | }) 60 | 61 | 62 | def _load_pickled_model(root_dir: str, model_name: str) -> ModelParams: 63 | model_path = getattr(INTERNAL_MODEL_PATHS, model_name.lower()) 64 | if model_path.endswith('mnist_ceda.pkl'): 65 | with jv_utils.open_file(model_path, 'rb', root_dir=root_dir) as f: 66 | params_iterables = pickle.load(f, encoding='bytes') 67 | else: 68 | with jv_utils.open_file(model_path, 'rb', root_dir=root_dir) as f: 69 | params_iterables = list(np.load(f, allow_pickle=True).item().values()) 70 | return make_model_params_from_iterables(params_iterables) 71 | 72 | 73 | def make_model_params_from_iterables(raw_params: Any) -> ModelParams: 74 | """Make list of LayerParams from list of iterables.""" 75 | conv_field_names = [ 76 | f.name for f in dataclasses.fields(verify_utils.ConvParams) 77 | ] 78 | fc_field_names = [ 79 | f.name for f in dataclasses.fields(verify_utils.FCParams) 80 | ] 81 | 82 | net = [] 83 | for layer_params in raw_params: 84 | if isinstance(layer_params, tuple): 85 | w, b = layer_params 86 | layer = verify_utils.FCParams(w=w, b=b) 87 | elif (isinstance(layer_params, dict) 88 | and layer_params.get('type') == 'linear'): 89 | fc_params = dict( 90 | (k, v) for k, v in layer_params.items() if k in fc_field_names) 91 | if fc_params.get('dropout_rate', 0) > 0: 92 | w = fc_params['w'] 93 | # adapt expected value of 'w' 94 | fc_params['w'] = w * (1.0 - fc_params['dropout_rate']) 95 | fc_params['w_bound'] = jax_verify.IntervalBound( 96 | lower_bound=jnp.minimum(w, 0.0), upper_bound=jnp.maximum(w, 0.0)) 97 | layer = verify_utils.FCParams(**fc_params) 98 | elif isinstance(layer_params, dict): 99 | conv_params = dict( 100 | (k, v) for k, v in 
layer_params.items() if k in conv_field_names) 101 | # deal with 'W' vs 'w' 102 | if 'W' in layer_params: 103 | conv_params['w'] = layer_params['W'] 104 | layer = verify_utils.ConvParams(**conv_params) 105 | else: 106 | raise TypeError( 107 | f'layer_params type not recognized: {type(layer_params)}.') 108 | net += [layer] 109 | return net 110 | 111 | 112 | def _load_proba_safety_model( 113 | root_dir: str, 114 | model_name: str, 115 | num_std_for_bound: float, 116 | ) -> ModelParams: 117 | """Load model trained in Probabilistic Safety for BNNs paper.""" 118 | model_path = getattr(PROBA_SAFETY_MODEL_PATHS, model_name.lower()) 119 | local_path = os.path.join(root_dir, model_path) 120 | if not os.path.exists(local_path): 121 | download_url = os.path.join(PROBA_SAFETY_URL, model_path) 122 | urllib.request.urlretrieve(download_url, local_path) 123 | with open(local_path, 'rb') as f: 124 | data = np.load(f, allow_pickle=True, encoding='bytes') 125 | if not isinstance(data, np.ndarray): 126 | data = data['arr_0'] 127 | 128 | assert len(data) % 4 == 0 129 | 130 | net = [] 131 | for layer_idx in range(0, len(data) // 2, 2): 132 | # data: [w_0, b_0, w_1, b_1, ..., w_0_std, b_0_std, w_1_std, b_1_std, ...] 133 | w = jnp.array(data[layer_idx]) 134 | b = jnp.array(data[layer_idx + 1]) 135 | 136 | w_std = jnp.array(data[layer_idx + len(data) // 2]) 137 | b_std = jnp.array(data[layer_idx + len(data) // 2 + 1]) 138 | 139 | w_bound = jax_verify.IntervalBound(w - num_std_for_bound * w_std, 140 | w + num_std_for_bound * w_std) 141 | b_bound = jax_verify.IntervalBound(b - num_std_for_bound * b_std, 142 | b + num_std_for_bound * b_std) 143 | 144 | net += [ 145 | verify_utils.FCParams( 146 | w=w, 147 | b=b, 148 | w_std=w_std, 149 | b_std=b_std, 150 | w_bound=w_bound, 151 | b_bound=b_bound) 152 | ] 153 | 154 | return net 155 | 156 | 157 | def load_model( 158 | root_dir: str, 159 | model_name: str, 160 | num_std_for_bound: Optional[float], 161 | ) -> ModelParams: 162 | """Load and process model parameters.""" 163 | if model_name.startswith('mnist_mlp'): 164 | return _load_proba_safety_model(root_dir, model_name, num_std_for_bound) 165 | else: 166 | return _load_pickled_model(root_dir, model_name) 167 | -------------------------------------------------------------------------------- /jax_verify/extensions/functional_lagrangian/run/configs/config_adv_stochastic_model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
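# Illustrative usage sketch (paths assumed relative to the `run` directory):
# configs like this one are consumed through ml_collections config flags by
# run_functional_lagrangian.py, e.g.
#   python3 run_functional_lagrangian.py --config=configs/config_adv_stochastic_model.py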
15 | 16 | """Verification configuration.""" 17 | 18 | import ml_collections 19 | 20 | 21 | def get_pga_params_config_dict(): 22 | """Create config dict with for running PGA.""" 23 | pga_config = ml_collections.ConfigDict() 24 | pga_config.optim_type = 'pga' 25 | pga_config.n_iter = 10000 26 | pga_config.lr = 0.1 27 | pga_config.n_restarts = 300 28 | pga_config.method = 'square' 29 | pga_config.finetune_n_iter = 50 30 | pga_config.finetune_lr = 0.1 31 | pga_config.finetune_method = 'pgd' 32 | pga_config.normalize = False 33 | return pga_config 34 | 35 | 36 | def get_adv_softmax_config(num_layers): 37 | """Running mixed solver strategy.""" 38 | # Train config 39 | config = ml_collections.ConfigDict() 40 | config.train = ml_collections.ConfigDict() 41 | config.train.optim_type = 'mixed' 42 | 43 | u_config_train = get_pga_params_config_dict() 44 | 45 | solver_config = {'optim_type': 'lp'} 46 | config.train.mixed_strat = [[solver_config]] * num_layers + [[u_config_train]] 47 | config.train.solver_weights = [[1.0]] * (num_layers + 1) 48 | 49 | # Eval config 50 | config.eval = ml_collections.ConfigDict() 51 | u_config_eval = { 52 | 'optim_type': 'uncertainty', 53 | 'solve_max': 'exp_bound', 54 | } 55 | config.eval.optim_type = 'mixed' 56 | config.eval.mixed_strat = [[solver_config]] * num_layers + [[u_config_eval]] 57 | config.eval.solver_weights = [[1.0]] * (num_layers + 1) 58 | return config 59 | 60 | 61 | def get_attack_config(): 62 | """Attack config.""" 63 | # Config to use for adversarial attak lower bound 64 | config = ml_collections.ConfigDict() 65 | config.num_steps = 200 66 | config.learning_rate = 1. 67 | config.num_samples = 50 68 | 69 | return config 70 | 71 | 72 | def get_dual_config(): 73 | """Dual config.""" 74 | # type of lagrangian functional: e.g. dense_quad, mlp 75 | config = ml_collections.ConfigDict() 76 | config.lagrangian_form = ml_collections.ConfigDict({ 77 | 'name': 'linear', 78 | 'kwargs': {}, 79 | }) 80 | 81 | config.affine_before_relu = False 82 | 83 | return config 84 | 85 | 86 | def get_config(model_name='mnist_mlp_1_128'): 87 | """Main configdict.""" 88 | 89 | if model_name.startswith('mnist_mlp_1'): 90 | dataset = 'mnist' 91 | num_layers = 2 92 | num_std_for_bound = 3.0 93 | elif model_name.startswith('mnist_mlp_2'): 94 | dataset = 'mnist' 95 | num_layers = 3 96 | num_std_for_bound = 3.0 97 | elif model_name.startswith('mnist_cnn'): 98 | dataset = 'mnist' 99 | num_layers = 5 100 | num_std_for_bound = None 101 | elif model_name.startswith('cifar_vgg'): 102 | dataset = 'cifar10' 103 | num_layers = 6 104 | num_std_for_bound = None 105 | 106 | config = ml_collections.ConfigDict() 107 | 108 | config.assets_dir = '/tmp/jax_verify' # directory to download data and models 109 | 110 | config.seed = 23 111 | config.use_gpu = True 112 | config.spec_type = 'adversarial_softmax' 113 | config.labels_in_distribution = [] 114 | config.use_best = False # PGA may be overly optimistic 115 | 116 | config.problem = ml_collections.ConfigDict() 117 | config.problem.dataset = dataset 118 | config.problem.dataset_idx = 0 # which example from dataset to verify? 119 | config.problem.target_label_idx = 0 # which class to target? 
120 |   config.problem.epsilon_unprocessed = 0.001  # radius before preprocessing
121 |   config.problem.scale_center = False
122 |   config.problem.num_std_for_bound = num_std_for_bound
123 |
124 |   # check adversary cannot bring loss below feasibility_margin
125 |   config.problem.feasibility_margin = 0.0
126 |
127 |   config.problem.model_name = model_name
128 |
129 |   config.dual = get_dual_config()
130 |   config.attack = get_attack_config()
131 |
132 |   # whether to block asynchronous dispatch at each iteration for precise timing
133 |   config.block_to_time = False
134 |
135 |   # Choose boundprop method: e.g. 'nonconvex', 'ibp', 'crown_ibp'
136 |   config.boundprop_type = 'nonconvex'
137 |   config.bilinear_boundprop_type = 'ibp'
138 |
139 |   # nonconvex boundprop params, only used if config.boundprop_type = 'nonconvex'
140 |   config.nonconvex_boundprop_steps = 0
141 |   config.nonconvex_boundprop_nodes = 128
142 |
143 |   config.outer_opt = ml_collections.ConfigDict()
144 |   config.outer_opt.lr_init = 1e-3  # initial learning rate
145 |   config.outer_opt.steps_per_anneal = 1000  # steps between each anneal
146 |   config.outer_opt.anneal_lengths = ''  # steps per epoch
147 |   config.outer_opt.anneal_factor = 0.1  # learning rate anneal factor
148 |   config.outer_opt.num_anneals = 3  # # of times to anneal learning rate
149 |   config.outer_opt.opt_name = 'adam'  # Optix class: "adam", "sgd", "rmsprop"
150 |   config.outer_opt.opt_kwargs = {}  # e.g. momentum for gradient descent
151 |
152 |   config.inner_opt = get_adv_softmax_config(num_layers)
153 |   return config
154 |
--------------------------------------------------------------------------------
/jax_verify/extensions/functional_lagrangian/run/configs/config_ood_stochastic_input.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Verification configuration."""
17 |
18 | import ml_collections
19 |
20 |
21 | def get_input_uncertainty_config(num_layers=2):
22 |   """Running mixed solver strategy."""
23 |   config = ml_collections.ConfigDict()
24 |   config.train = ml_collections.ConfigDict()
25 |   config.train.optim_type = 'mixed'
26 |   u_config = {
27 |       'optim_type': 'uncertainty',
28 |       'solve_max': 'exp',
29 |       'n_iter': 20,
30 |       'n_pieces': 30,
31 |       'learning_rate': 1.
32 |   }
33 |   solver_config_input_init = {
34 |       'optim_type': 'uncertainty_input',
35 |       'layer_type': 'input',
36 |       'sig_max': .1
37 |   }
38 |   solver_config_input_first = {
39 |       'optim_type': 'uncertainty_input',
40 |       'layer_type': 'first',
41 |       'sig_max': .1
42 |   }
43 |   solver_config = {'optim_type': 'lp'}
44 |   config.train.mixed_strat = (
45 |       [[solver_config_input_init], [solver_config_input_first]] +
46 |       [[solver_config]] * num_layers + [[u_config]])
47 |   config.train.solver_weights = [[1.0]] * (num_layers + 3)
48 |   u_config_eval = {
49 |       'optim_type': 'uncertainty',
50 |       'n_iter': 0,
51 |       'n_pieces': 100,
52 |       'solve_max': 'exp_bound',
53 |   }
54 |
55 |   config.eval = ml_collections.ConfigDict()
56 |   config.eval.optim_type = 'mixed'
57 |   config.eval.mixed_strat = (
58 |       [[solver_config_input_init], [solver_config_input_first]] +
59 |       [[solver_config]] * num_layers + [[u_config_eval]])
60 |   config.eval.solver_weights = [[1.0]] * (num_layers + 3)
61 |   return config
62 |
63 |
64 | def get_dual_config():
65 |   """Dual config."""
66 |   config = ml_collections.ConfigDict()
67 |   names = ['linear_exp', 'linear', 'linear', 'linear', 'linear']
68 |   config.lagrangian_form = []
69 |   for name in names:
70 |     config.lagrangian_form.append(
71 |         ml_collections.ConfigDict({
72 |             'name': name,
73 |             'kwargs': {},
74 |         }))
75 |
76 |   config.affine_before_relu = False
77 |
78 |   return config
79 |
80 |
81 | def get_attack_config():
82 |   """Attack config."""
83 |   # Config to use for adversarial attack lower bound
84 |   config = ml_collections.ConfigDict()
85 |   config.num_steps = 200
86 |   config.learning_rate = 1.
87 |
88 |   return config
89 |
90 |
91 | def get_config():
92 |   """Main configdict."""
93 |
94 |   config = ml_collections.ConfigDict()
95 |
96 |   config.assets_dir = '/tmp/jax_verify'  # directory to download data and models
97 |
98 |   config.seed = 23
99 |   config.use_gpu = True
100 |   config.spec_type = 'uncertainty'
101 |   config.labels_in_distribution = []
102 |   config.use_best = True
103 |
104 |   config.problem = ml_collections.ConfigDict()
105 |   config.problem.dataset = 'emnist_CEDA'
106 |   config.problem.dataset_idx = 0  # which example from dataset to verify?
107 |   config.problem.target_label_idx = 4  # which class to target?
108 |   config.problem.epsilon_unprocessed = 0.04  # radius before preprocessing
109 |   config.problem.probability_threshold = .97
110 |   config.problem.input_shape = (28, 28, 1)
111 |   # Use inception_preprocessing i.e. [-1,1]-scaled inputs
112 |   config.problem.scale_center = False
113 |   config.problem.model_name = 'mnist_ceda'
114 |
115 |   # check adversary cannot bring loss below feasibility_margin
116 |   config.problem.feasibility_margin = 0.0
117 |
118 |   config.add_input_noise = True
119 |   config.dual = get_dual_config()
120 |   config.attack = get_attack_config()
121 |
122 |   # whether to block asynchronous dispatch at each iteration for precise timing
123 |   config.block_to_time = False
124 |
125 |   # Choose boundprop method: e.g. 'nonconvex', 'ibp', 'crown_ibp'
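  # (Illustrative note on the trade-off, stated only informally: 'ibp' is
  # typically the cheapest and loosest option, while 'nonconvex' runs an inner
  # optimization per intermediate bound, governed by the
  # nonconvex_boundprop_steps / nonconvex_boundprop_nodes settings below.)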
126 |   config.boundprop_type = 'nonconvex'
127 |   config.bilinear_boundprop_type = 'ibp'
128 |
129 |   # nonconvex boundprop params, only used if config.boundprop_type = 'nonconvex'
130 |   config.nonconvex_boundprop_steps = 100
131 |   config.nonconvex_boundprop_nodes = 128
132 |
133 |   config.outer_opt = ml_collections.ConfigDict()
134 |   config.outer_opt.lr_init = 1e-4  # initial learning rate
135 |   config.outer_opt.steps_per_anneal = 10  # steps between each anneal
136 |   config.outer_opt.anneal_lengths = '60000, 20000, 20000'  # steps per epoch
137 |   config.outer_opt.anneal_factor = 0.1  # learning rate anneal factor
138 |   config.outer_opt.num_anneals = 2  # # of times to anneal learning rate
139 |   config.outer_opt.opt_name = 'adam'  # Optix class: "adam", "sgd", "rmsprop"
140 |   config.outer_opt.opt_kwargs = {}  # e.g. momentum for gradient descent
141 |
142 |   config.inner_opt = get_input_uncertainty_config()
143 |   return config
144 |
--------------------------------------------------------------------------------
/jax_verify/extensions/functional_lagrangian/run/configs/config_ood_stochastic_model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Config for experiments on uncertainty spec with stochastic neural networks."""
17 |
18 | import ml_collections
19 |
20 |
21 | def get_uncertainty_config(num_layers):
22 |   """Running mixed solver strategy."""
23 |   config = ml_collections.ConfigDict()
24 |   config.train = ml_collections.ConfigDict()
25 |   config.train.optim_type = 'mixed'
26 |
27 |   u_config_train = {
28 |       'optim_type': 'uncertainty',
29 |       'n_iter': 1_000,
30 |       'n_pieces': 0,
31 |       'solve_max': 'exp',
32 |       'learning_rate': 1.0,
33 |   }
34 |   solver_config = {'optim_type': 'lp'}
35 |   config.train.mixed_strat = [[solver_config]] * num_layers + [[u_config_train]]
36 |   config.train.solver_weights = [[1.0]] * (num_layers + 1)
37 |
38 |   u_config_eval = {
39 |       'optim_type': 'uncertainty',
40 |       'n_iter': 0,
41 |       'n_pieces': 100,
42 |       'solve_max': 'exp_bound',
43 |   }
44 |
45 |   config.eval = ml_collections.ConfigDict()
46 |   config.eval.optim_type = 'mixed'
47 |   config.eval.mixed_strat = [[solver_config]] * num_layers + [[u_config_eval]]
48 |   config.eval.solver_weights = [[1.0]] * (num_layers + 1)
49 |
50 |   return config
51 |
52 |
53 | def get_attack_config():
54 |   """Attack config."""
55 |   # Config to use for adversarial attack lower bound
56 |   config = ml_collections.ConfigDict()
57 |   config.num_steps = 200
58 |   config.learning_rate = 1.
59 |   config.num_samples = 50
60 |
61 |   return config
62 |
63 |
64 | def get_dual_config():
65 |   """Dual config."""
66 |   # type of lagrangian functional: e.g. dense_quad, mlp
67 |   config = ml_collections.ConfigDict()
68 |   config.lagrangian_form = ml_collections.ConfigDict({
69 |       'name': 'linear',
70 |       'kwargs': {},
71 |   })
72 |
73 |   config.affine_before_relu = False
74 |
75 |   return config
76 |
77 |
78 | def get_config(model_name='mnist_mlp_2_128'):
79 |   """Main configdict."""
80 |
81 |   config = ml_collections.ConfigDict()
82 |
83 |   config.assets_dir = '/tmp/jax_verify'  # directory to download data and models
84 |
85 |   if model_name.startswith('mnist_mlp'):
86 |     dataset = 'emnist'
87 |     num_layers = 3
88 |     num_std_for_bound = 3.0
89 |     epsilon = 0.01
90 |   elif model_name.startswith('mnist_cnn'):
91 |     dataset = 'emnist'
92 |     num_layers = 5
93 |     num_std_for_bound = None
94 |     epsilon = 0.01
95 |   elif model_name.startswith('cifar_vgg'):
96 |     dataset = 'cifar100'
97 |     num_layers = 6
98 |     num_std_for_bound = None
99 |     epsilon = 0.001
100 |
101 |   config.seed = 23
102 |   config.use_gpu = False
103 |   config.spec_type = 'uncertainty'
104 |   config.labels_in_distribution = []
105 |   config.use_best = False  # PGA may be overly optimistic
106 |
107 |   config.problem = ml_collections.ConfigDict()
108 |   config.problem.model_name = model_name
109 |   config.problem.dataset = dataset
110 |   config.problem.dataset_idx = 0  # which example from dataset to verify?
111 |   config.problem.target_label_idx = 0  # which class to target?
112 |   config.problem.epsilon_unprocessed = epsilon  # radius before preprocessing
113 |   config.problem.scale_center = False
114 |   config.problem.num_std_for_bound = num_std_for_bound
115 |
116 |   # check adversary cannot bring loss below feasibility_margin
117 |   config.problem.feasibility_margin = 0.0
118 |
119 |   config.dual = get_dual_config()
120 |   config.attack = get_attack_config()
121 |
122 |   # whether to block asynchronous dispatch at each iteration for precise timing
123 |   config.block_to_time = False
124 |
125 |   # Choose boundprop method: e.g. 'nonconvex', 'ibp', 'crown_ibp'
126 |   config.boundprop_type = 'nonconvex'
127 |   # Choose bilinear boundprop method: e.g. 'ibp', 'crown'
128 |   config.bilinear_boundprop_type = 'nonconvex'
129 |
130 |   # nonconvex boundprop params, only used if config.boundprop_type = 'nonconvex'
131 |   config.nonconvex_boundprop_steps = 100
132 |   config.nonconvex_boundprop_nodes = 128
133 |
134 |   config.outer_opt = ml_collections.ConfigDict()
135 |   config.outer_opt.lr_init = 1e-3  # initial learning rate
136 |   config.outer_opt.steps_per_anneal = 250  # steps between each anneal
137 |   config.outer_opt.anneal_lengths = ''  # steps per epoch
138 |   config.outer_opt.anneal_factor = 0.1  # learning rate anneal factor
139 |   config.outer_opt.num_anneals = 3  # # of times to anneal learning rate
140 |   config.outer_opt.opt_name = 'adam'  # Optix class: "adam", "sgd", "rmsprop"
141 |   config.outer_opt.opt_kwargs = {}  # e.g. momentum for gradient descent
142 |
143 |   config.inner_opt = get_uncertainty_config(num_layers)
144 |   return config
145 |
--------------------------------------------------------------------------------
/jax_verify/extensions/functional_lagrangian/run/run_functional_lagrangian.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Run verification for feedforward ReLU networks.""" 17 | 18 | import os 19 | import time 20 | from typing import Any, Callable, Mapping 21 | 22 | from absl import app 23 | from absl import flags 24 | from absl import logging 25 | import jax.numpy as jnp 26 | from jax_verify.extensions.functional_lagrangian import attacks 27 | from jax_verify.extensions.functional_lagrangian import bounding 28 | from jax_verify.extensions.functional_lagrangian import data 29 | from jax_verify.extensions.functional_lagrangian import dual_solve 30 | from jax_verify.extensions.functional_lagrangian import model 31 | from jax_verify.extensions.functional_lagrangian import verify_utils 32 | from jax_verify.extensions.sdp_verify import utils as sdp_utils 33 | import ml_collections 34 | from ml_collections import config_flags 35 | 36 | PROJECT_PATH = os.getcwd() 37 | 38 | config_flags.DEFINE_config_file( 39 | 'config', f'{PROJECT_PATH}/configs/config_ood_stochastic_model.py', 40 | 'ConfigDict for the experiment.') 41 | 42 | FLAGS = flags.FLAGS 43 | 44 | 45 | def make_logger(log_message: str) -> Callable[[int, Mapping[str, Any]], None]: 46 | """Creates a logger. 47 | 48 | Args: 49 | log_message: description message for the logs. 50 | 51 | Returns: 52 | Function that accepts a step counter and measurements, and logs them. 53 | """ 54 | 55 | def log_fn(step, measures): 56 | msg = f'[{log_message}] step={step}' 57 | for k, v in measures.items(): 58 | msg += f', {k}={v}' 59 | logging.info(msg) 60 | 61 | return log_fn 62 | 63 | 64 | def main(unused_argv): 65 | 66 | config = FLAGS.config 67 | 68 | logging.info('Config: \n %s', config) 69 | 70 | data_spec = data.make_data_spec(config.problem, config.assets_dir) 71 | spec_type = {e.value: e for e in verify_utils.SpecType}[config.spec_type] 72 | 73 | if spec_type == verify_utils.SpecType.UNCERTAINTY: 74 | if data_spec.true_label in config.labels_in_distribution: 75 | return 76 | else: 77 | if data_spec.true_label == data_spec.target_label: 78 | return 79 | params = model.load_model( 80 | root_dir=config.assets_dir, 81 | model_name=config.problem.model_name, 82 | num_std_for_bound=config.problem.get('num_std_for_bound'), 83 | ) 84 | 85 | params_elided, bounds, bp_bound, bp_time = ( 86 | bounding.make_elided_params_and_bounds(config, data_spec, spec_type, 87 | params)) 88 | 89 | dual_state = ml_collections.ConfigDict(type_safe=False) 90 | 91 | def spec_fn(inputs): 92 | # params_elided is a list of network parameters, with the final 93 | # layer elided with the objective (output size is 1, and not num classes) 94 | return jnp.squeeze(sdp_utils.predict_cnn(params_elided, inputs), axis=-1) 95 | 96 | def run(mode: str): 97 | 98 | logger = make_logger(log_message=mode.title()) 99 | 100 | start_time = time.time() 101 | prng_key = dual_solve.solve_dual( 102 | dual_state=dual_state, 103 | config=config, 104 | bounds=bounds, 105 | spec_type=spec_type, 106 | spec_fn=spec_fn, 107 | params=params_elided, 108 | mode=mode, 109 | logger=logger) 110 | elapsed_time = time.time() - start_time 111 | 112 | adv_objective = 
attacks.adversarial_attack( # pytype: disable=wrong-arg-types # jax-devicearray 113 | params, data_spec, spec_type, prng_key, config.attack.num_steps, 114 | config.attack.learning_rate, config.attack.get('num_samples', 1)) 115 | 116 | output_dict = { 117 | 'dataset_idx': config.problem.dataset_idx, 118 | 'true_label': data_spec.true_label, 119 | 'target_label': data_spec.target_label, 120 | 'epsilon': config.problem.epsilon_unprocessed, 121 | 'verified_ub': dual_state.loss, 122 | 'verification_time': elapsed_time, 123 | 'adv_lb': adv_objective, 124 | 'adv_success': adv_objective > config.problem.feasibility_margin, 125 | 'bp_bound': bp_bound, 126 | 'bp_time': bp_time, 127 | } 128 | logger = make_logger(log_message=mode.title()) 129 | logger(0, output_dict) 130 | 131 | run('train') 132 | run('eval') 133 | 134 | 135 | if __name__ == '__main__': 136 | app.run(main) 137 | -------------------------------------------------------------------------------- /jax_verify/extensions/functional_lagrangian/specification.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Functions to elide the specification objective with the model.""" 17 | 18 | import dataclasses 19 | 20 | import jax.numpy as jnp 21 | import jax_verify 22 | from jax_verify.extensions.functional_lagrangian import verify_utils 23 | import ml_collections 24 | import numpy as np 25 | 26 | ConfigDict = ml_collections.ConfigDict 27 | DataSpec = verify_utils.DataSpec 28 | IntervalBound = jax_verify.IntervalBound 29 | SpecType = verify_utils.SpecType 30 | Tensor = jnp.array 31 | LayerParams = verify_utils.LayerParams 32 | ModelParams = verify_utils.ModelParams 33 | ModelParamsElided = verify_utils.ModelParamsElided 34 | 35 | 36 | def elide_adversarial_spec( 37 | params: ModelParams, 38 | data_spec: DataSpec, 39 | ) -> ModelParamsElided: 40 | """Elide params to have last layer merged with the adversarial objective. 41 | 42 | Args: 43 | params: parameters of the model under verification. 44 | data_spec: data specification. 45 | 46 | Returns: 47 | params_elided: elided parameters with the adversarial objective folded in 48 | the last layer (and bounds adapted accordingly). 
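
    Illustrative sketch of the elision in `elide_fn` below, writing e_target
    and e_true for the one-hot vectors it builds, and assuming a final layer
    with weights w_fin of shape [d, num_classes] and biases b_fin:

      obj = w_fin @ (e_target - e_true)    # becomes the elided weights [d, 1]
      const = (e_target - e_true) . b_fin  # becomes the elided bias

    so that, for penultimate activations z, z @ obj + const equals
    logit_target - logit_true.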
49 | """ 50 | 51 | def elide_fn(w_fin, b_fin): 52 | label_onehot = jnp.eye(w_fin.shape[-1])[data_spec.true_label] 53 | target_onehot = jnp.eye(w_fin.shape[-1])[data_spec.target_label] 54 | obj_orig = target_onehot - label_onehot 55 | obj_bp = jnp.matmul(w_fin, obj_orig) 56 | const = jnp.expand_dims(jnp.vdot(obj_orig, b_fin), axis=-1) 57 | obj = jnp.reshape(obj_bp, (obj_bp.size, 1)) 58 | return obj, const 59 | 60 | last_params = params[-1] 61 | w_elided, b_elided = elide_fn(last_params.w, last_params.b) 62 | last_params_elided = verify_utils.FCParams(w_elided, b_elided) 63 | 64 | if last_params.has_bounds: 65 | w_bound_elided, b_bound_elided = jax_verify.interval_bound_propagation( 66 | elide_fn, last_params.w_bound, last_params.b_bound) 67 | last_params_elided = dataclasses.replace( 68 | last_params_elided, w_bound=w_bound_elided, b_bound=b_bound_elided) 69 | 70 | params_elided = params[:-1] + [last_params_elided] 71 | return params_elided 72 | 73 | 74 | def elide_adversarial_softmax_spec( 75 | params: ModelParams, 76 | data_spec: DataSpec, 77 | ) -> ModelParamsElided: 78 | """Elide params to have uncertainty objective appended as a new last layer. 79 | 80 | Args: 81 | params: parameters of the model under verification. 82 | data_spec: data specification. 83 | 84 | Returns: 85 | params_elided: parameters with the uncertainty objective appended as 86 | the last 'layer'. 87 | """ 88 | op_size = params[-1].w.shape[-1] 89 | e = np.zeros((op_size, 1)) 90 | e[data_spec.target_label] = 1. 91 | e[data_spec.true_label] = -1. 92 | params_elided = params + [verify_utils.FCParams(jnp.array(e), jnp.zeros(()))] 93 | 94 | return params_elided 95 | 96 | 97 | def elide_uncertainty_spec( 98 | params: ModelParams, 99 | data_spec: DataSpec, 100 | probability_threshold: float, 101 | ) -> ModelParamsElided: 102 | """Elide params to have uncertainty objective appended as a new last layer. 103 | 104 | Args: 105 | params: parameters of the model under verification. 106 | data_spec: data specification. 107 | probability_threshold: Maximum probability threshold for OOD detection. 108 | 109 | Returns: 110 | params_elided: parameters with the uncertainty objective appended as 111 | the last 'layer'. 112 | """ 113 | op_size = params[-1].w.shape[-1] 114 | e = np.zeros((op_size, 1)) 115 | e[data_spec.target_label] = 1. 116 | e -= probability_threshold 117 | params_elided = params + [verify_utils.FCParams(jnp.array(e), jnp.zeros(()))] 118 | 119 | return params_elided 120 | -------------------------------------------------------------------------------- /jax_verify/extensions/functional_lagrangian/verify_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Small helper functions.""" 17 | 18 | import abc 19 | import collections 20 | import dataclasses 21 | import enum 22 | from typing import Callable, List, Optional, Union 23 | 24 | import chex 25 | import jax.numpy as jnp 26 | import jax_verify 27 | from jax_verify.extensions.functional_lagrangian import lagrangian_form 28 | from jax_verify.extensions.sdp_verify import utils as sdp_utils 29 | import ml_collections 30 | 31 | 32 | Params = collections.namedtuple('Params', ['inner', 'outer']) 33 | ParamsTypes = collections.namedtuple('ParamsTypes', 34 | ['inner', 'outer', 'lagrangian_form']) 35 | 36 | DataSpec = collections.namedtuple( 37 | 'DataSpec', 38 | ['input', 'true_label', 'target_label', 'epsilon', 'input_bounds']) 39 | 40 | Array = chex.Array 41 | ArrayTree = chex.ArrayTree 42 | ConfigDict = ml_collections.ConfigDict 43 | IntervalBound = jax_verify.IntervalBound 44 | Tensor = jnp.array 45 | LayerParams = Union['FCParams', 'ConvParams'] 46 | LagrangianForm = lagrangian_form.LagrangianForm 47 | ModelParams = List[LayerParams] 48 | ModelParamsElided = ModelParams 49 | 50 | 51 | class AbstractParams(abc.ABC): 52 | """AbstractParams.""" 53 | 54 | def __call__(self, inputs: Tensor) -> Tensor: 55 | """Forward pass on layer.""" 56 | return sdp_utils.fwd(inputs, self.params) 57 | 58 | @property 59 | @abc.abstractmethod 60 | def params(self): 61 | """Representation of params with sdp_utils.fwd convention.""" 62 | 63 | @property 64 | def has_bounds(self): 65 | return self.w_bound is not None or self.b_bound is not None # pytype: disable=attribute-error # bind-properties 66 | 67 | 68 | @dataclasses.dataclass 69 | class FCParams(AbstractParams): 70 | """Params of fully connected layer.""" 71 | 72 | w: Tensor 73 | b: Tensor 74 | 75 | w_bound: Optional[IntervalBound] = None 76 | b_bound: Optional[IntervalBound] = None 77 | 78 | w_std: Optional[Tensor] = None 79 | b_std: Optional[Tensor] = None 80 | 81 | dropout_rate: float = 0.0 82 | 83 | @property 84 | def params(self): 85 | return (self.w, self.b) 86 | 87 | 88 | @dataclasses.dataclass 89 | class ConvParams(AbstractParams): 90 | """Params of convolutional layer.""" 91 | 92 | w: Tensor 93 | b: Tensor 94 | 95 | stride: int 96 | padding: str 97 | 98 | n_cin: Optional[int] = None 99 | 100 | w_bound: Optional[IntervalBound] = None 101 | b_bound: Optional[IntervalBound] = None 102 | 103 | w_std: Optional[Tensor] = None 104 | b_std: Optional[Tensor] = None 105 | 106 | dropout_rate: float = 0.0 107 | 108 | @property 109 | def params(self): 110 | return { 111 | 'W': self.w, 112 | 'b': self.b, 113 | 'n_cin': self.n_cin, 114 | 'stride': self.stride, 115 | 'padding': self.padding, 116 | } 117 | 118 | 119 | class SpecType(enum.Enum): 120 | # `params` represent a network of repeated relu(Wx+b) 121 | # The final output also includes a relu activation, and `obj` composes 122 | # the final layer weights with the original objective 123 | UNCERTAINTY = 'uncertainty' 124 | ADVERSARIAL = 'adversarial' 125 | ADVERSARIAL_SOFTMAX = 'adversarial_softmax' 126 | PROBABILITY_THRESHOLD = 'probability_threshold' 127 | 128 | 129 | class Distribution(enum.Enum): 130 | """Distribution of the weights and biases.""" 131 | GAUSSIAN = 'gaussian' 132 | BERNOULLI = 'bernoulli' 133 | 134 | 135 | class NetworkType(enum.Enum): 136 | """Distribution of the weights and biases.""" 137 | DETERMINISTIC = 'deterministic' 138 | STOCHASTIC = 'stochastic' 139 | 140 | 141 | @dataclasses.dataclass(frozen=True) 142 | class InnerVerifInstance: 143 | """Specification of inner problems.""" 
144 |
145 |   affine_fns: List[Callable[[Array], Array]]
146 |   bounds: List[sdp_utils.IntervalBound]
147 |
148 |   lagrangian_form_pre: Optional[LagrangianForm]
149 |   lagrangian_form_post: Optional[LagrangianForm]
150 |
151 |   lagrange_params_pre: Optional[ArrayTree]
152 |   lagrange_params_post: Optional[ArrayTree]
153 |
154 |   is_first: bool
155 |   is_last: bool
156 |
157 |   idx: int
158 |   spec_type: SpecType
159 |   affine_before_relu: bool
160 |
161 |   @property
162 |   def same_lagrangian_form_pre_post(self) -> bool:
163 |     if self.is_first:
164 |       return True
165 |     elif self.is_last:
166 |       return True
167 |     else:
168 |       name_pre = self.lagrangian_form_pre.name
169 |       name_post = self.lagrangian_form_post.name
170 |       return name_pre == name_post
171 |
--------------------------------------------------------------------------------
/jax_verify/extensions/sdp_verify/README.md:
--------------------------------------------------------------------------------
1 | # SDP Neural Network Verification
2 |
3 | This directory provides a standalone implementation of the SDP-FO algorithm from [Dathathri et al 2020](https://arxiv.org/abs/2010.11645).
4 |
5 | If you find this code useful, we would appreciate if you cite our paper:
6 |
7 | ```
8 | @article{dathathri2020sdpfo,
9 |   title={Enabling certification of verification-agnostic networks via memory-efficient semidefinite programming},
10 |   author={Dathathri, Sumanth and Dvijotham, Krishnamurthy and Kurakin, Alexey and Raghunathan, Aditi and Uesato, Jonathan and Bunel, Rudy and Shankar, Shreya and Steinhardt, Jacob and Goodfellow, Ian and Liang, Percy and Kohli, Pushmeet},
11 |   journal={NeurIPS},
12 |   year={2020}
13 | }
14 | ```
15 |
--------------------------------------------------------------------------------
/jax_verify/extensions/sdp_verify/boundprop_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # pylint: disable=invalid-name
17 | """Crown bound propagation used in SDP verification."""
18 |
19 | import functools
20 | import jax.numpy as jnp
21 | import jax_verify
22 | from jax_verify.extensions.sdp_verify import utils
23 | from jax_verify.src import bound_propagation
24 | from jax_verify.src.nonconvex import duals
25 | from jax_verify.src.nonconvex import nonconvex
26 | from jax_verify.src.nonconvex import optimizers
27 | from jax_verify.src.nonconvex.optimizers import LinesearchFistaOptimizer as FistaOptimizer
28 | IntBound = utils.IntBound
29 |
30 |
31 | def boundprop(params, x, epsilon, input_bounds, boundprop_type,
32 |               **extra_boundprop_kwargs):
33 |   """Computes interval bounds for NN intermediate activations.
34 |
35 |   Args:
36 |     params: Parameters for the NN.
37 |     x: Batch of inputs to NN (dimension 2 for MLP or 4 for CNN)
38 |     epsilon: l-inf perturbation to the input.
39 |     input_bounds: Valid lower and upper bounds for the NN input, as a tuple -- e.g. (0., 1.)
40 | boundprop_type: string, indicating method used for bound propagation, e.g. 41 | 'crown_ibp' or 'nonconvex' 42 | **extra_boundprop_kwargs: any additional kwargs, passed directly to 43 | underlying boundprop method 44 | 45 | Returns: 46 | layer_bounds: upper and lower bounds across the layers of the NN as a list 47 | of IntBound-s. 48 | """ 49 | boundprop_type_to_method = { 50 | 'crown_ibp': _crown_ibp_boundprop, 51 | 'nonconvex': _nonconvex_boundprop, 52 | } 53 | assert boundprop_type in boundprop_type_to_method, 'invalid boundprop_type' 54 | boundprop_method = boundprop_type_to_method[boundprop_type] 55 | return boundprop_method(params, x, epsilon, input_bounds, 56 | **extra_boundprop_kwargs) 57 | 58 | 59 | def _crown_ibp_boundprop(params, x, epsilon, input_bounds): 60 | """Runs CROWN-IBP for each layer separately.""" 61 | def get_layer_act(layer_idx, inputs): 62 | act = utils.predict_cnn(params[:layer_idx], inputs) 63 | return act 64 | 65 | initial_bound = jax_verify.IntervalBound( 66 | jnp.maximum(x - epsilon, input_bounds[0]), 67 | jnp.minimum(x + epsilon, input_bounds[1])) 68 | 69 | out_bounds = [IntBound( 70 | lb_pre=None, ub_pre=None, lb=initial_bound.lower, ub=initial_bound.upper)] 71 | for i in range(1, len(params) + 1): 72 | fwd = functools.partial(get_layer_act, i) 73 | bound = jax_verify.crownibp_bound_propagation(fwd, initial_bound) 74 | out_bounds.append( 75 | IntBound(lb_pre=bound.lower, ub_pre=bound.upper, 76 | lb=jnp.maximum(0, bound.lower), 77 | ub=jnp.maximum(0, bound.upper))) 78 | return out_bounds 79 | 80 | 81 | def _nonconvex_boundprop(params, x, epsilon, input_bounds, 82 | nonconvex_boundprop_steps=100, 83 | nonconvex_boundprop_nodes=128): 84 | """Wrapper for nonconvex bound propagation.""" 85 | # Get initial bounds for boundprop 86 | init_bounds = utils.init_bound(x, epsilon, input_bounds=input_bounds, 87 | add_batch_dim=False) 88 | 89 | # Build fn to boundprop through 90 | all_act_fun = functools.partial(utils.predict_cnn, params, 91 | include_preactivations=True) 92 | 93 | # Collect the intermediate bounds. 94 | input_bound = jax_verify.IntervalBound(init_bounds.lb, init_bounds.ub) 95 | 96 | optimizer = optimizers.OptimizingConcretizer( 97 | FistaOptimizer(num_steps=nonconvex_boundprop_steps), 98 | max_parallel_nodes=nonconvex_boundprop_nodes) 99 | nonconvex_algorithm = nonconvex.nonconvex_algorithm( 100 | duals.WolfeNonConvexBound, optimizer) 101 | 102 | all_outputs, _ = bound_propagation.bound_propagation( 103 | nonconvex_algorithm, all_act_fun, input_bound) 104 | _, intermediate_nonconvex_bounds = all_outputs 105 | 106 | bounds = [init_bounds] 107 | for nncvx_bound in intermediate_nonconvex_bounds: 108 | bounds.append(utils.IntBound(lb_pre=nncvx_bound.lower, 109 | ub_pre=nncvx_bound.upper, 110 | lb=jnp.maximum(nncvx_bound.lower, 0), 111 | ub=jnp.maximum(nncvx_bound.upper, 0))) 112 | return bounds 113 | -------------------------------------------------------------------------------- /jax_verify/sdp_verify.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Module for SDP verification of neural networks."""
17 | 
18 | from jax_verify.extensions.sdp_verify.sdp_verify import dual_fun
19 | from jax_verify.extensions.sdp_verify.sdp_verify import solve_sdp_dual
20 | from jax_verify.extensions.sdp_verify.sdp_verify import solve_sdp_dual_simple
21 | from jax_verify.extensions.sdp_verify.utils import SdpDualVerifInstance
22 | 
23 | __all__ = (
24 |     "dual_fun",
25 |     "SdpDualVerifInstance",
26 |     "solve_sdp_dual",
27 |     "solve_sdp_dual_simple",
28 | )
29 | 
-------------------------------------------------------------------------------- /jax_verify/src/branching/backpropagation.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Backpropagation of sensitivity values.
17 | 
18 | This computes a measure of dependence of the output objective upon each
19 | intermediate value of a ReLU-based neural network. This sensitivity measure
20 | is constructed by back-propagating through the network and linearising
21 | each ReLU.
22 | 
23 | This sensitivity measure is derived in "Provable Defenses via the Convex Outer
24 | Adversarial Polytope", https://arxiv.org/pdf/1711.00851.pdf.
25 | """
26 | 
27 | import functools
28 | from typing import Mapping, Optional, Sequence, Union
29 | 
30 | import jax
31 | from jax import lax
32 | import jax.numpy as jnp
33 | from jax_verify.src import bound_propagation
34 | from jax_verify.src import graph_traversal
35 | from jax_verify.src import synthetic_primitives
36 | from jax_verify.src.types import Index, Nest, Primitive, Tensor  # pylint: disable=g-multiple-import
37 | 
38 | 
39 | def _sensitivity_linear_op(
40 |     primitive: Primitive,
41 |     outval: Tensor,
42 |     *args: Union[bound_propagation.Bound, Tensor],
43 |     **kwargs) -> Sequence[Optional[Tensor]]:
44 |   """Back-propagates sensitivity through a linear Jax operation.
45 | 
46 |   For linear ops, sensitivity of the inputs is computed by applying the
47 |   transpose of the op to the sensitivity of the outputs.
48 | 
49 |   Args:
50 |     primitive: Linear (or affine) primitive op through which to backprop.
51 |     outval: Sensitivity value for the op's output.
52 |     *args: Inputs to the linear op, in the form of interval bounds (for
53 |       variables) and primal values (for constants).
54 |     **kwargs: Additional parameters to the linear op.
55 | 
56 |   Returns:
57 |     Sensitivity values for the op's variable inputs.
Entries will be `None` 58 | for constant inputs. 59 | """ 60 | # Use auto-diff to perform the transpose-product. 61 | primal_args = [ 62 | jnp.zeros_like(arg.lower) 63 | if isinstance(arg, bound_propagation.Bound) else arg 64 | for arg in args] 65 | 66 | _, vjp = jax.vjp(functools.partial(primitive.bind, **kwargs), *primal_args) 67 | return vjp(outval) 68 | 69 | 70 | def _sensitivity_relu( 71 | outval: Tensor, 72 | inp: bound_propagation.Bound 73 | ) -> Tensor: 74 | """Back-propagates sensitivity through a ReLU. 75 | 76 | For the purposes of back-propagating sensitivity, 77 | the ReLU uses a linear approximation given by the chord 78 | from (lower_bound, ReLU(lower_bound)) to (upper_bound, ReLU(upper_bound)) 79 | 80 | Args: 81 | outval: Sensitivity values for the ReLU outputs. 82 | inp: Interval bounds on the ReLU input 83 | 84 | Returns: 85 | Sensitivity values for the ReLU input. 86 | """ 87 | # Arrange for always-blocking and always-passing ReLUs to give a slope 88 | # of zero and one respectively. 89 | lower_bound = jnp.minimum(inp.lower, 0.) 90 | upper_bound = jnp.maximum(inp.upper, 0.) 91 | 92 | chord_slope = upper_bound / jnp.maximum( 93 | upper_bound - lower_bound, jnp.finfo(jnp.float32).eps) 94 | return chord_slope * outval, # pytype: disable=bad-return-type # jax-ndarray 95 | 96 | 97 | _LINEAR_PRIMITIVES: Sequence[Primitive] = [ 98 | *bound_propagation.AFFINE_PRIMITIVES, 99 | *bound_propagation.RESHAPE_PRIMITIVES, 100 | lax.div_p, 101 | ] 102 | 103 | 104 | def _build_sensitivity_ops() -> Mapping[ 105 | Primitive, graph_traversal.PrimitiveBacktransformFn]: 106 | """Builds functions to back-prop 'sensitivity' through individual primitives. 107 | 108 | Returns: 109 | Sensitivity computation functions, in the form suitable to be passed to 110 | `PropagationGraph.backward_propagation()`. 111 | """ 112 | sensitivity_primitive_ops = { 113 | primitive: functools.partial(_sensitivity_linear_op, primitive) 114 | for primitive in _LINEAR_PRIMITIVES} 115 | sensitivity_primitive_ops[synthetic_primitives.relu_p] = _sensitivity_relu 116 | # Through the sign function, we don't really have a sensitivity. 117 | sensitivity_sign = lambda outval, _: (jnp.zeros_like(outval),) 118 | sensitivity_primitive_ops[jax.lax.sign_p] = sensitivity_sign 119 | 120 | return sensitivity_primitive_ops 121 | 122 | 123 | sensitivity_backward_transform = graph_traversal.BackwardOpwiseTransform( 124 | _build_sensitivity_ops(), sum) 125 | 126 | 127 | class SensitivityAlgorithm(bound_propagation.PropagationAlgorithm[Tensor]): 128 | """Propagation algorithm computing output sensitivity to intermediate nodes.""" 129 | 130 | def __init__( 131 | self, 132 | forward_bound_transform: bound_propagation.BoundTransform, 133 | sensitivity_targets: Sequence[Index], 134 | output_sensitivity: Optional[Tensor] = None): 135 | """Define the sensitivity that needs to be computed. 136 | 137 | Args: 138 | forward_bound_transform: Transformation to use to compute intermediate 139 | bounds. 140 | sensitivity_targets: Index of the nodes for which we want to obtain 141 | sensitivities. 142 | output_sensitivity: (Optional) Linear coefficients for which we want the 143 | sensitivity, defined over the output. 
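
    Example (a sketch mirroring `backpropagation_test.py`; `logits_fn`,
    `input_bounds` and `output_sensitivity` are assumed to be defined as in
    that test):

      sensitivity_algorithm = SensitivityAlgorithm(
          jax_verify.ibp_transform, [(0,)], output_sensitivity)
      bound_propagation.bound_propagation(
          sensitivity_algorithm, logits_fn, input_bounds)
      input_sensitivities, = sensitivity_algorithm.target_sensitivities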
144 | """ 145 | self._forward_bnd_algorithm = bound_propagation.ForwardPropagationAlgorithm( 146 | forward_bound_transform) 147 | self._output_sensitivity = output_sensitivity 148 | self._sensitivity_targets = sensitivity_targets 149 | self.target_sensitivities = [] 150 | 151 | def propagate(self, graph: graph_traversal.PropagationGraph, 152 | *bounds: Nest[graph_traversal.GraphInput]): 153 | assert len(graph.outputs) == 1 154 | out, bound_env = self._forward_bnd_algorithm.propagate(graph, bounds) 155 | 156 | if self._output_sensitivity is not None: 157 | output_sensitivity = self._output_sensitivity 158 | else: 159 | output_sensitivity = -jnp.ones(out[0].shape) 160 | sensitivities, backward_env = graph.backward_propagation( 161 | sensitivity_backward_transform, bound_env, 162 | {graph.outputs[0]: output_sensitivity}, 163 | self._sensitivity_targets) 164 | 165 | self.target_sensitivities = sensitivities 166 | 167 | return out, backward_env 168 | -------------------------------------------------------------------------------- /jax_verify/src/branching/branch_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Implementation of utils for Branch-and-Bound algorithms. 17 | 18 | Contains for example algorithm to evaluate the inputs that concretize the 19 | backward linear bounds. 20 | """ 21 | from typing import Mapping, Tuple 22 | 23 | import jax 24 | import jax.numpy as jnp 25 | from jax_verify.src import bound_propagation 26 | from jax_verify.src import bound_utils 27 | from jax_verify.src import graph_traversal 28 | from jax_verify.src import utils 29 | from jax_verify.src.linear import backward_crown 30 | from jax_verify.src.linear import linear_relaxations 31 | from jax_verify.src.types import Index, Nest, Tensor # pylint: disable=g-multiple-import 32 | 33 | 34 | class NominalEvaluateConcretizingInputAlgorithm( 35 | bound_propagation.PropagationAlgorithm[Tensor]): 36 | """Find the input concretizing a backward linear transform and evaluate it. 
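  Backward linear bounds are propagated to the network inputs, the input
  minimising each bound is recovered, and the network is then evaluated
  nominally at that input (see `propagate` below).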
37 |   """
38 | 
39 |   def __init__(self,
40 |                intermediate_bounds: Mapping[Index, Tuple[Tensor, Tensor]],
41 |                backward_transform: backward_crown.LinearBoundBackwardTransform):
42 |     self._forward_bnd_algorithm = bound_propagation.ForwardPropagationAlgorithm(
43 |         bound_utils.FixedBoundApplier(intermediate_bounds))
44 |     self._backward_transform = backward_transform
45 | 
46 |   def propagate(self, graph: graph_traversal.PropagationGraph,
47 |                 *bounds: Nest[graph_traversal.GraphInput]):
48 |     assert len(graph.outputs) == 1
49 |     (out,), bound_env = self._forward_bnd_algorithm.propagate(graph, *bounds)
50 | 
51 |     max_output = (out.upper == out.upper.max()).astype(jnp.float32)
52 |     max_output = jnp.expand_dims(max_output, 0)
53 | 
54 |     initial_linear_expression = backward_crown.identity(-max_output)
55 |     flat_inputs, _ = jax.tree_util.tree_flatten(*bounds)
56 |     bound_inputs = [inp for inp in flat_inputs
57 |                     if isinstance(inp, bound_propagation.Bound)]
58 |     input_nodes_indices = [(i,) for i in range(len(bound_inputs))]
59 |     inputs_linfuns, back_env = graph.backward_propagation(
60 |         self._backward_transform, bound_env,
61 |         {graph.outputs[0]: initial_linear_expression},
62 |         input_nodes_indices)
63 | 
64 |     concretizing_bound_inputs = []
65 |     for input_linfun, inp_bound in zip(inputs_linfuns, bound_inputs):
66 |       if input_linfun is not None:
67 |         conc_inp = minimizing_concretizing_input(input_linfun, inp_bound)
68 |         concretizing_bound_inputs.append(conc_inp[0])
69 | 
70 |     def eval_model(*graph_inputs):
71 |       # We are only going to pass Tensors.
72 |       # Forward propagation simply evaluates the primitives when there are
73 |       # no bound inputs.
74 |       outvals, _ = graph.forward_propagation(None, graph_inputs)  # pytype: disable=wrong-arg-types
75 |       return outvals
76 |     eval_model_boundinps = utils.bind_nonbound_args(eval_model,
77 |                                                     *flat_inputs)
78 |     nominal_outs = eval_model_boundinps(*concretizing_bound_inputs)
79 | 
80 |     return nominal_outs, back_env
81 | 
82 | 
83 | def minimizing_concretizing_input(
84 |     backward_linexp: linear_relaxations.LinearExpression,
85 |     input_bound: bound_propagation.Bound) -> Tensor:
86 |   """Get the input that concretizes the backward bound to its lower bound.
87 | 
88 |   Args:
89 |     backward_linexp: Coefficients of linear functions. The leading batch
90 |       dimension corresponds to different output neurons that need to be
91 |       concretized.
92 |     input_bound: Bound on the activations of that layer. Its shape should
93 |       match the coefficients of the linear functions to concretize.
94 |   Returns:
95 |     concretizing_inp: The input that corresponds to the lower bound of the
96 |       linear function given by backward_linexp.
97 |   """
98 |   return concretizing_input_interval_bounds(backward_linexp, input_bound)
99 | 
100 | 
101 | def concretizing_input_interval_bounds(
102 |     backward_linexp: linear_relaxations.LinearExpression,
103 |     input_bound: bound_propagation.Bound) -> Tensor:
104 |   """Computes the input that achieves the lower bound of a linear function."""
105 |   act_lower = jnp.expand_dims(input_bound.lower, 0)
106 |   act_upper = jnp.expand_dims(input_bound.upper, 0)
107 |   return jnp.where(backward_linexp.lin_coeffs > 0., act_lower, act_upper)
108 | 
-------------------------------------------------------------------------------- /jax_verify/src/intersection.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Mechanism to combine input bounds from multiple methods.""" 17 | 18 | from typing import Optional 19 | 20 | import jax.numpy as jnp 21 | from jax_verify.src import bound_propagation 22 | from jax_verify.src import graph_traversal 23 | from jax_verify.src.types import Primitive, Tensor # pylint: disable=g-multiple-import 24 | 25 | 26 | class IntersectionBound(bound_propagation.Bound): 27 | """Concretises to intersection of constituent bounds.""" 28 | 29 | def __init__(self, *base_bounds: bound_propagation.Bound): 30 | self.base_bounds = base_bounds 31 | 32 | @property 33 | def lower(self) -> Tensor: 34 | return jnp.array([bound.lower for bound in self.base_bounds]).max(axis=0) 35 | 36 | @property 37 | def upper(self) -> Tensor: 38 | return jnp.array([bound.upper for bound in self.base_bounds]).min(axis=0) 39 | 40 | 41 | class ConstrainedBound(bound_propagation.Bound): 42 | """Wraps a Bound with additional concrete constraints.""" 43 | 44 | def __init__( 45 | self, 46 | base_bound: bound_propagation.Bound, 47 | lower: Optional[Tensor], 48 | upper: Optional[Tensor], 49 | ): 50 | self._base_bound = base_bound 51 | self._lower = lower 52 | self._upper = upper 53 | 54 | def unwrap(self) -> bound_propagation.Bound: 55 | return self._base_bound.unwrap() 56 | 57 | @property 58 | def lower(self) -> Tensor: 59 | if self._lower is None: 60 | return self._base_bound.lower 61 | else: 62 | return jnp.maximum(self._base_bound.lower, self._lower) 63 | 64 | @property 65 | def upper(self) -> Tensor: 66 | if self._upper is None: 67 | return self._base_bound.upper 68 | else: 69 | return jnp.minimum(self._base_bound.upper, self._upper) 70 | 71 | 72 | class IntersectionBoundTransform(bound_propagation.BoundTransform): 73 | """Aggregates several bound transforms, intersecting their concrete bounds.""" 74 | 75 | def __init__(self, *base_transforms: bound_propagation.BoundTransform): 76 | self._base_transforms = base_transforms 77 | 78 | def input_transform( 79 | self, 80 | context: bound_propagation.TransformContext, 81 | input_bound: graph_traversal.InputBound, 82 | ) -> IntersectionBound: 83 | """Constructs initial input bounds for each constituent bound type. 84 | 85 | Args: 86 | context: Transform context containing node index. 87 | input_bound: Original concrete bounds on the input. 88 | 89 | Returns: 90 | Intersection of the constituent input bounds. 91 | """ 92 | return IntersectionBound(*[ 93 | transform.input_transform(context, input_bound) 94 | for transform in self._base_transforms]) 95 | 96 | def primitive_transform( 97 | self, 98 | context: bound_propagation.TransformContext, 99 | primitive: Primitive, 100 | *args: bound_propagation.LayerInput, 101 | **kwargs, 102 | ) -> IntersectionBound: 103 | """Propagates bounds for each constituent bound type. 104 | 105 | Args: 106 | context: Transform context containing node index. 107 | primitive: Primitive Jax operation to transform. 
108 |       *args: Arguments of the primitive, wrapped as `IntersectionBound`s.
109 |       **kwargs: Keyword arguments of the primitive.
110 | 
111 |     Returns:
112 |       Intersection of the propagated constituent output bounds.
113 |     """
114 |     def base_args_for_arg(arg):
115 |       if isinstance(arg, bound_propagation.Bound):
116 |         return [ConstrainedBound(bound, arg.lower, arg.upper)
117 |                 for bound in arg.unwrap().base_bounds]
118 |       else:
119 |         # Broadcast over the intersection components.
120 |         return [arg for _ in self._base_transforms]
121 | 
122 |     base_args = [base_args_for_arg(arg) for arg in args]
123 |     return IntersectionBound(*[
124 |         transform.equation_transform(context, primitive, *args, **kwargs)[0]
125 |         for transform, *args in zip(self._base_transforms, *base_args)])
126 | 
127 |   def should_handle_as_subgraph(self, primitive: Primitive) -> bool:
128 |     # Handle as a sub-graph only if _all_ intersectands can.
129 |     # Otherwise handle at the higher (synthetic primitive) level.
130 |     return all(base_transform.should_handle_as_subgraph(primitive)
131 |                for base_transform in self._base_transforms)
132 | 
133 | 
-------------------------------------------------------------------------------- /jax_verify/src/mccormick.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """McCormick relaxations for bilinear terms.
17 | 
18 | Create McCormick relaxations of bilinear terms for boundprop and verification.
19 | """
20 | from typing import Callable, Tuple
21 | 
22 | import jax.numpy as jnp
23 | from jax_verify.src.types import Tensor
24 | 
25 | BilinearFun = Callable[[Tensor, Tensor], Tensor]
26 | 
27 | 
28 | def mccormick_ibp(
29 |     lx: Tensor,
30 |     ux: Tensor,
31 |     ly: Tensor,
32 |     uy: Tensor,
33 |     matrix: Tensor,
34 | ) -> Tuple[Tensor, Tensor]:
35 |   """Compute bounds on x.T * matrix * y s.t. x in [lx, ux], y in [ly, uy].
36 | 
37 |   Args:
38 |     lx: Lower bounds on x (d,)
39 |     ux: Upper bounds on x (d,)
40 |     ly: Lower bounds on y (d,)
41 |     uy: Upper bounds on y (d,)
42 |     matrix: (d, d) matrix
43 | 
44 |   Returns:
45 |     Lower and upper bounds on x.T * matrix * y
46 |   """
47 |   ll = matrix * jnp.outer(lx, ly)
48 |   lu = matrix * jnp.outer(lx, uy)
49 |   ul = matrix * jnp.outer(ux, ly)
50 |   uu = matrix * jnp.outer(ux, uy)
51 | 
52 |   lb_elementwise = jnp.minimum(jnp.minimum(ll, lu), jnp.minimum(ul, uu))
53 |   ub_elementwise = jnp.maximum(jnp.maximum(ll, lu), jnp.maximum(ul, uu))
54 |   return jnp.sum(lb_elementwise), jnp.sum(ub_elementwise)
55 | 
56 | 
57 | def posbilinear_mccormick_relaxations(
58 |     fn: BilinearFun,
59 |     x_lb: Tensor, x_ub: Tensor, y_lb: Tensor, y_ub: Tensor,
60 | ) -> Tuple[BilinearFun, BilinearFun, BilinearFun, BilinearFun]:
61 |   """Constructs all four McCormick relaxations of a positive bilinear primitive.
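
  The two lower (resp. upper) estimators are valid simultaneously; combining
  them with an elementwise max (resp. min), as `mccormick_outer_product` at
  the end of this module does, gives a tighter relaxation.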
62 | 
63 |   For x in [x_l, x_u] and y in [y_l, y_u], the bounds imposed are:
64 |     x·y >= x·y_l + x_l·y - x_l·y_l
65 |     x·y >= x·y_u + x_u·y - x_u·y_u
66 |     x·y <= x·y_u + x_l·y - x_l·y_u
67 |     x·y <= x·y_l + x_u·y - x_u·y_l
68 | 
69 |   Args:
70 |     fn: Bilinear function with nonnegative coefficients (a 'positive'
71 |       bilinear primitive).
72 |     x_lb: Lower bounds on x
73 |     x_ub: Upper bounds on x
74 |     y_lb: Lower bounds on y
75 |     y_ub: Upper bounds on y
76 |   Returns:
77 |     lb_fun0, lb_fun1, ub_fun0, ub_fun1
78 |   """
79 |   def lb_fun0(x, y):
80 |     return fn(x, y_lb) + fn(x_lb, y) - fn(x_lb, y_lb)
81 | 
82 |   def lb_fun1(x, y):
83 |     return fn(x, y_ub) + fn(x_ub, y) - fn(x_ub, y_ub)
84 | 
85 |   def ub_fun0(x, y):
86 |     return fn(x, y_ub) + fn(x_lb, y) - fn(x_lb, y_ub)
87 | 
88 |   def ub_fun1(x, y):
89 |     return fn(x, y_lb) + fn(x_ub, y) - fn(x_ub, y_lb)
90 | 
91 |   return lb_fun0, lb_fun1, ub_fun0, ub_fun1
92 | 
93 | 
94 | def mccormick_outer_product(x: Tensor,
95 |                             y: Tensor,
96 |                             x_lb: Tensor,
97 |                             x_ub: Tensor,
98 |                             y_lb: Tensor,
99 |                             y_ub: Tensor,
100 |                             ) -> Tuple[Tensor, Tensor]:
101 |   """McCormick bounds on the bilinear term x @ y.T.
102 | 
103 |   Args:
104 |     x: Input tensor
105 |     y: Input tensor
106 |     x_lb: Lower bounds on x
107 |     x_ub: Upper bounds on x
108 |     y_lb: Lower bounds on y
109 |     y_ub: Upper bounds on y
110 | 
111 |   Returns:
112 |     Elementwise lower and upper bounds on the outer product x @ y.T.
113 |   """
114 |   def outer(x, y):
115 |     x = jnp.reshape(x, [-1, 1])
116 |     y = jnp.reshape(y, [-1, 1])
117 |     return jnp.dot(x, y.T)
118 | 
119 |   output_lb_a, output_lb_b, output_ub_a, output_ub_b = [
120 |       relax_fn(x, y) for relax_fn in posbilinear_mccormick_relaxations(
121 |           outer, x_lb, x_ub, y_lb, y_ub)]
122 |   return (jnp.maximum(output_lb_a, output_lb_b),
123 |           jnp.minimum(output_ub_a, output_ub_b))
124 | 
-------------------------------------------------------------------------------- /jax_verify/src/mip_solver/solve_relaxation.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Methods for solving relaxations generated by relaxation.py.
17 | 
18 | This file mainly calls out to relaxation.RelaxationSolvers defined in other
19 | files, and provides a higher-level interface than using relaxation.py directly.
20 | """
21 | 
22 | from jax_verify.src import bound_propagation
23 | from jax_verify.src.mip_solver import cvxpy_relaxation_solver
24 | from jax_verify.src.mip_solver import relaxation
25 | 
26 | 
27 | def solve_planet_relaxation(
28 |     logits_fn, initial_bounds, boundprop_transform,
29 |     objective, objective_bias, index,
30 |     solver=cvxpy_relaxation_solver.CvxpySolver):
31 |   """Solves the "Planet" (Ehlers 17) or "triangle" relaxation.
32 | 
33 |   The general approach is to use jax_verify to generate constraints, which can
34 |   then be passed to generic solvers.
Note that using CVXPY will incur a large 35 | overhead when defining the LP, because we define all constraints element-wise, 36 | to avoid representing convolutional layers as a single matrix multiplication, 37 | which would be inefficient. In CVXPY, defining large numbers of constraints is 38 | slow. 39 | 40 | Args: 41 | logits_fn: Mapping from inputs (batch_size x input_size) -> (batch_size, 42 | num_classes) 43 | initial_bounds: `IntervalBound` with initial bounds on inputs, 44 | with lower and upper bounds of dimension (batch_size x input_size). 45 | boundprop_transform: bound_propagation.BoundTransform instance, such as 46 | `jax_verify.ibp_transform`. Used to pre-compute interval bounds for 47 | intermediate activations used in defining the Planet relaxation. 48 | objective: Objective to optimize, given as an array of coefficients to be 49 | applied to the output of logits_fn defining the objective to minimize 50 | objective_bias: Bias to add to objective 51 | index: Index in the batch for which to solve the relaxation 52 | solver: A relaxation.RelaxationSolver, which specifies the backend to solve 53 | the resulting LP. 54 | Returns: 55 | val: The optimal value from the relaxation 56 | solution: The optimal solution found by the solver 57 | status: The status of the relaxation solver 58 | """ 59 | relaxation_transform = relaxation.RelaxationTransform(boundprop_transform) 60 | variable, env = bound_propagation.bound_propagation( 61 | bound_propagation.ForwardPropagationAlgorithm(relaxation_transform), 62 | logits_fn, initial_bounds) 63 | value, solution, status = relaxation.solve_relaxation( 64 | solver, objective, objective_bias, variable, env, 65 | index=index, time_limit_millis=None) 66 | return value, solution, status 67 | -------------------------------------------------------------------------------- /jax_verify/src/nonconvex/methods.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Pre-canned non-convex methods.""" 17 | from typing import Callable 18 | 19 | from jax_verify.src import bound_propagation 20 | from jax_verify.src import graph_traversal 21 | from jax_verify.src import ibp 22 | from jax_verify.src import synthetic_primitives 23 | from jax_verify.src.nonconvex import duals 24 | from jax_verify.src.nonconvex import nonconvex 25 | from jax_verify.src.nonconvex import optimizers 26 | from jax_verify.src.types import Nest, Tensor # pylint: disable=g-multiple-import 27 | 28 | 29 | def nonconvex_ibp_bound_propagation( 30 | function: Callable[..., Nest[Tensor]], 31 | *bounds: Nest[graph_traversal.GraphInput], 32 | graph_simplifier=synthetic_primitives.default_simplifier, 33 | ) -> Nest[nonconvex.NonConvexBound]: 34 | """Builds the non-convex objective using IBP. 35 | 36 | Args: 37 | function: Function performing computation to obtain bounds for. 
Takes as 38 | only arguments the network inputs. 39 | *bounds: Bounds on the inputs of the function. 40 | graph_simplifier: What graph simplifier to use. 41 | Returns: 42 | output_bounds: NonConvex bounds that can be optimized with a solver. 43 | """ 44 | algorithm = nonconvex.nonconvex_algorithm( 45 | duals.WolfeNonConvexBound, 46 | nonconvex.BaseBoundConcretizer(), 47 | base_boundprop=ibp.bound_transform) 48 | output_bounds, _ = bound_propagation.bound_propagation( 49 | algorithm, function, *bounds, graph_simplifier=graph_simplifier) 50 | return output_bounds 51 | 52 | 53 | def nonconvex_constopt_bound_propagation( 54 | function: Callable[..., Nest[Tensor]], 55 | *bounds: Nest[graph_traversal.GraphInput], 56 | graph_simplifier=synthetic_primitives.default_simplifier, 57 | ) -> Nest[nonconvex.NonConvexBound]: 58 | """Builds the optimizable objective. 59 | 60 | Args: 61 | function: Function performing computation to obtain bounds for. Takes as 62 | only arguments the network inputs. 63 | *bounds: Bounds on the inputs of the function. 64 | graph_simplifier: What graph simplifier to use. 65 | Returns: 66 | output_bounds: NonConvex bounds that can be optimized with a solver. 67 | """ 68 | nostep_optimizer = optimizers.OptimizingConcretizer( 69 | optimizers.PGDOptimizer(0, 0., optimize_dual=False), 70 | max_parallel_nodes=512) 71 | algorithm = nonconvex.nonconvex_algorithm( 72 | duals.WolfeNonConvexBound, nostep_optimizer) 73 | output_bounds, _ = bound_propagation.bound_propagation( 74 | algorithm, function, *bounds, graph_simplifier=graph_simplifier) 75 | return output_bounds 76 | -------------------------------------------------------------------------------- /jax_verify/src/simplex_bound.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Bound with L1 constraints.""" 17 | import jax 18 | from jax import numpy as jnp 19 | 20 | from jax_verify.src import bound_propagation 21 | from jax_verify.src import opt_utils 22 | from jax_verify.src.types import Tensor 23 | 24 | 25 | class SimplexIntervalBound(bound_propagation.IntervalBound): 26 | """Represent a bound for which we have a constraint on the sum of coordinates. 27 | 28 | Each coordinate is subject to interval constraints, and the sum of all 29 | coordinates must be equal to a given value. 
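
  For instance (an illustrative sketch, not taken from the library's tests),
  probability vectors over 4 categories can be encoded as:

    bound = SimplexIntervalBound(jnp.zeros((4,)), jnp.ones((4,)), 1.)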
30 | 31 | """ 32 | 33 | def __init__(self, lower_bound: Tensor, upper_bound: Tensor, 34 | simplex_sum: float): 35 | super(SimplexIntervalBound, self).__init__(lower_bound, upper_bound) 36 | self._simplex_sum = simplex_sum 37 | 38 | @property 39 | def simplex_sum(self) -> float: 40 | return self._simplex_sum 41 | 42 | @classmethod 43 | def from_jittable( 44 | cls, 45 | jittable_simplexint_bound: bound_propagation.JittableInputBound 46 | ) -> 'SimplexIntervalBound': 47 | return cls(jittable_simplexint_bound.lower, 48 | jittable_simplexint_bound.upper, 49 | jittable_simplexint_bound.kwargs['simplex_sum']) 50 | 51 | def to_jittable(self) -> bound_propagation.JittableInputBound: 52 | return bound_propagation.JittableInputBound( 53 | self.lower, self.upper, {SimplexIntervalBound: None}, 54 | {'simplex_sum': self.simplex_sum}) 55 | 56 | def project_onto_bound(self, tensor: Tensor) -> Tensor: 57 | return opt_utils.project_onto_interval_simplex(self.lower, self.upper, 58 | self.simplex_sum, tensor) 59 | 60 | 61 | def concretize_linear_function_simplexinterval_constraints( 62 | linexp, input_bound: SimplexIntervalBound) -> Tensor: 63 | """Compute the lower bound of a linear function under Simplex constraints.""" 64 | 65 | solve_lin = jax.vmap(opt_utils.fractional_exact_knapsack, 66 | in_axes=(0, None, None, None)) 67 | 68 | # We are maximizing -lin_coeffs*x in order to minimize lin_coeffs*x 69 | neg_sum_lin_bound = solve_lin(-linexp.lin_coeffs, input_bound.simplex_sum, 70 | input_bound.lower, input_bound.upper) 71 | return linexp.offset - neg_sum_lin_bound 72 | 73 | 74 | def concretizing_input_simplexinterval_constraints( 75 | linexp, input_bound: SimplexIntervalBound) -> Tensor: 76 | """Compute the input that achieves the lower bound of a linear function.""" 77 | flat_lower = jnp.reshape(input_bound.lower, (-1,)) 78 | flat_upper = jnp.reshape(input_bound.upper, (-1,)) 79 | flat_lin_coeffs = jnp.reshape(linexp.lin_coeffs, (-1, flat_lower.size)) 80 | 81 | def single_linexpr_concretizing_inp(coeffs): 82 | _, sorted_lower, sorted_upper, sorted_idx = jax.lax.sort( 83 | (coeffs, flat_lower, flat_upper, jnp.arange(coeffs.size)), num_keys=1) 84 | sorted_assignment = opt_utils.sorted_knapsack(sorted_lower, sorted_upper, 85 | input_bound.simplex_sum, 86 | backward=False) 87 | # This is a cute trick to avoid using a jnp.take_along_axis, which is 88 | # usually quite slow, particularly on TPU, when you do permutation. 89 | # jax.lax.sort can take multiple arguments, and will sort them according 90 | # to the ordering of the first tensor. 91 | # When we did the sorting of the weights, we also sorted the index of each 92 | # coordinate. By sorting by it, we will recover the initial ordering. 93 | _, assignment = jax.lax.sort((sorted_idx, sorted_assignment), num_keys=1) 94 | return assignment 95 | 96 | flat_conc_input = jax.vmap(single_linexpr_concretizing_inp)(flat_lin_coeffs) 97 | return jnp.reshape(flat_conc_input, linexp.lin_coeffs.shape) 98 | -------------------------------------------------------------------------------- /jax_verify/src/types.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Common type definitions used by jax_verify.""" 17 | 18 | from typing import Any, Generic, Mapping, Sequence, Tuple, TypeVar, Union 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | import typing_extensions 23 | 24 | 25 | Primitive = jax.core.Primitive 26 | Tensor = jnp.ndarray 27 | T = TypeVar('T') 28 | U = TypeVar('U') 29 | Nest = Union[T, Sequence['Nest[T]'], Mapping[Any, 'Nest[T]']] 30 | Index = Tuple[int, ...] 31 | 32 | 33 | class TensorFun(typing_extensions.Protocol): 34 | 35 | def __call__(self, *inputs: Tensor) -> Tensor: 36 | pass 37 | 38 | 39 | class SpecFn(typing_extensions.Protocol): 40 | """Specification, expressed as all outputs are <=0.""" 41 | 42 | def __call__(self, *inputs: Nest[Tensor]) -> Nest[Tensor]: 43 | pass 44 | 45 | 46 | class ArgsKwargsCallable(typing_extensions.Protocol, Generic[T, U]): 47 | 48 | def __call__(self, *args: T, **kwargs) -> U: 49 | pass 50 | -------------------------------------------------------------------------------- /jax_verify/tests/backpropagation_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for backpropagation of sensitivity values.""" 17 | 18 | from absl.testing import absltest 19 | import chex 20 | import jax 21 | import jax.numpy as jnp 22 | import jax_verify 23 | from jax_verify.src import bound_propagation 24 | from jax_verify.src.branching import backpropagation 25 | 26 | 27 | class BackpropagationTest(chex.TestCase): 28 | 29 | def test_identity_network_leaves_sensitivities_unchanged(self): 30 | # Set up an identity network. 
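    # (An identity network has no linear ops or ReLUs to traverse, so the
    # back-propagated sensitivities should reach the inputs unchanged.)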
31 |     def logits_fn(x):
32 |       return x
33 |     input_bounds = jax_verify.IntervalBound(
34 |         lower_bound=jnp.array([-1., 0., 1.]),
35 |         upper_bound=jnp.array([2., 3., 4.]))
36 | 
37 |     # Backpropagation
38 |     output_sensitivities = jnp.array([.1, .2, -.3])
39 |     sensitivity_computation = backpropagation.SensitivityAlgorithm(
40 |         jax_verify.ibp_transform, [(0,)], output_sensitivities)
41 |     bound_propagation.bound_propagation(sensitivity_computation,  # pytype: disable=wrong-arg-types  # jax-ndarray
42 |                                         logits_fn, input_bounds)
43 |     input_sensitivities, = sensitivity_computation.target_sensitivities
44 | 
45 |     chex.assert_trees_all_close(input_sensitivities, jnp.array([.1, .2, -.3]))
46 | 
47 |   def test_relu_network_applies_chord_slopes_to_sensitivities(self):
48 |     # Set up some ReLUs, with a variety of input bounds:
49 |     # 1 blocking, 1 passing, and 3 'ambiguous' (straddling zero).
50 |     def logits_fn(x):
51 |       return jax.nn.relu(x)
52 |     input_bounds = jax_verify.IntervalBound(
53 |         lower_bound=jnp.array([-2., 1., -1., -4., -2.]),
54 |         upper_bound=jnp.array([-1., 2., 1., 1., 3.]))
55 | 
56 |     # Backpropagation.
57 |     output_sensitivities = jnp.array([10., 10., 10., 10., 10.])
58 |     sensitivity_computation = backpropagation.SensitivityAlgorithm(
59 |         jax_verify.ibp_transform, [(0,)], output_sensitivities)
60 |     bound_propagation.bound_propagation(sensitivity_computation,  # pytype: disable=wrong-arg-types  # jax-ndarray
61 |                                         logits_fn, input_bounds)
62 |     input_sensitivities, = sensitivity_computation.target_sensitivities
63 | 
64 |     # Expect blocking neurons to have no sensitivity, passing neurons to have
65 |     # full sensitivity, and ambiguous neurons to interpolate between the two.
66 |     chex.assert_trees_all_close(
67 |         input_sensitivities, jnp.array([0., 10., 5., 2., 6.]))
68 | 
69 |   def test_affine_network_applies_transpose_to_sensitivities(self):
70 |     # Set up a matmul with bias.
71 |     w = jnp.array([[1., 4., -5.], [2., -3., 6.]])
72 |     b = jnp.array([20., 30., 40.])
73 |     def logits_fn(x):
74 |       return x @ w + b
75 | 
76 |     input_bounds = jax_verify.IntervalBound(
77 |         lower_bound=jnp.zeros(shape=(1, 2)),
78 |         upper_bound=jnp.ones(shape=(1, 2)))
79 | 
80 |     # Backpropagation.
81 |     output_sensitivities = jnp.array([[1., 0., -1.]])
82 |     sensitivity_computation = backpropagation.SensitivityAlgorithm(
83 |         jax_verify.ibp_transform, [(0,)], output_sensitivities)
84 |     bound_propagation.bound_propagation(sensitivity_computation,  # pytype: disable=wrong-arg-types  # jax-ndarray
85 |                                         logits_fn, input_bounds)
86 |     input_sensitivities, = sensitivity_computation.target_sensitivities
87 |     # Expect the transpose of w to have been applied to the sensitivities.
88 |     # The bias will be ignored.
89 |     chex.assert_trees_all_close(
90 |         input_sensitivities, jnp.array([[6., -4.]]))
91 | 
92 | 
93 | if __name__ == '__main__':
94 |   absltest.main()
-------------------------------------------------------------------------------- /jax_verify/tests/branch_algorithm_test.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Tests for branching."""
17 | 
18 | import functools
19 | 
20 | from absl.testing import absltest
21 | import chex
22 | import haiku as hk
23 | import jax
24 | import jax.numpy as jnp
25 | import jax_verify
26 | from jax_verify.src import ibp
27 | from jax_verify.src.branching import branch_algorithm
28 | from jax_verify.src.branching import branch_selection
29 | 
30 | 
31 | class BranchAlgorithmTest(chex.TestCase):
32 | 
33 |   def test_upper_bound_with_branching_returns_bound_per_output(self):
34 |     # Set up a small network.
35 |     @hk.transform
36 |     def forward_fn(x):
37 |       x = hk.Linear(7)(x)
38 |       x = jax.nn.relu(x)
39 |       x = hk.Linear(5)(x)
40 |       return x
41 | 
42 |     input_bounds = jax_verify.IntervalBound(
43 |         lower_bound=jnp.array([-1., 0., 1.]),
44 |         upper_bound=jnp.array([2., 3., 4.]))
45 | 
46 |     params = forward_fn.init(jax.random.PRNGKey(0), input_bounds.lower)
47 |     spec_fn = functools.partial(forward_fn.apply, params, None)
48 | 
49 |     upper_bound = branch_algorithm.upper_bound_with_branching(
50 |         ibp.bound_transform,
51 |         branch_selection.ReluSelector(),
52 |         spec_fn,
53 |         input_bounds,
54 |         num_branches=5)
55 |     chex.assert_equal((5,), upper_bound.shape)
56 | 
57 | 
58 | if __name__ == '__main__':
59 |   absltest.main()
-------------------------------------------------------------------------------- /jax_verify/tests/branch_selection_test.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Tests for branching decisions."""
17 | 
18 | 
19 | from absl.testing import absltest
20 | import chex
21 | import jax.numpy as jnp
22 | from jax_verify.src import ibp
23 | from jax_verify.src.branching import branch_selection
24 | 
25 | 
26 | class BranchSelectionTest(chex.TestCase):
27 | 
28 |   def test_jittable_branching_decisions_enforced(self):
29 | 
30 |     free_bounds = ibp.IntervalBound(-jnp.ones(4,),
31 |                                     jnp.ones(4,))
32 | 
33 |     layer_index = (1, 2)
34 |     branching_decision_list = [
35 |         # Neuron 0 is greater than 0.
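        # (Each BranchDecision below specifies a neuron index within the
        # layer, a threshold, and a sign: +1 enforces neuron >= threshold,
        # -1 enforces neuron <= threshold.)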
36 | branch_selection.BranchDecision(layer_index, 0, 0., 1), 37 | # Neuron 1 is smaller than 0.5 38 | branch_selection.BranchDecision(layer_index, 1, 0.5, -1), 39 | # Neuron 2 is between -0.3 and 0.3 40 | branch_selection.BranchDecision(layer_index, 2, -0.3, 1), 41 | branch_selection.BranchDecision(layer_index, 2, 0.3, -1), 42 | # Neuron 3 is below 2., which is a spurious constraint 43 | branch_selection.BranchDecision(layer_index, 3, 2., -1) 44 | ] 45 | branching_decisions_tensors = branch_selection.branching_decisions_tensors( 46 | branching_decision_list, 3, 8) 47 | 48 | enforced_bounds = branch_selection.enforce_jittable_branching_decisions( 49 | branching_decisions_tensors, layer_index, free_bounds) 50 | 51 | chex.assert_trees_all_close((enforced_bounds.lower, enforced_bounds.upper), 52 | (jnp.array([0., -1., -0.3, -1.]), 53 | jnp.array([1., 0.5, 0.3, 1.]))) 54 | 55 | # check that the bounds are not modified when enforced on another layer. 56 | other_lay_bound = branch_selection.enforce_jittable_branching_decisions( 57 | branching_decisions_tensors, (1, 3), free_bounds) 58 | 59 | chex.assert_trees_all_close((free_bounds.lower, free_bounds.upper), 60 | (other_lay_bound.lower, other_lay_bound.upper)) 61 | 62 | def test_infeasible_bounds_detection(self): 63 | 64 | non_crossing_bounds = ibp.IntervalBound(jnp.zeros(3,), jnp.ones(3,)) 65 | crossing_bounds = ibp.IntervalBound(jnp.array([0., 0., 1.]), 66 | jnp.array([1., 1., 0.5])) 67 | 68 | non_crossing_infeasible = branch_selection.infeasible_bounds( 69 | non_crossing_bounds.to_jittable()) 70 | self.assertFalse(non_crossing_infeasible) 71 | 72 | crossing_infeasible = branch_selection.infeasible_bounds( 73 | crossing_bounds.to_jittable()) 74 | self.assertTrue(crossing_infeasible) 75 | 76 | if __name__ == '__main__': 77 | absltest.main() 78 | -------------------------------------------------------------------------------- /jax_verify/tests/branch_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Test the extraction of concretizing values.""" 17 | 18 | from absl.testing import absltest 19 | import chex 20 | 21 | import jax.numpy as jnp 22 | from jax_verify.src import ibp 23 | from jax_verify.src.branching import branch_utils 24 | from jax_verify.src.linear import linear_relaxations 25 | 26 | 27 | class ConcretizingInputTest(chex.TestCase): 28 | 29 | def test_linf_bound_concretizing_inputs(self): 30 | 31 | # Create a stack of two linear expression to test things. 
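    # (LinearExpression packs `lin_coeffs` with a leading dimension over the
    # stacked linear functions, here of shape (2, 2, 3), together with an
    # `offset` of shape (2,).)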
32 | linexp = linear_relaxations.LinearExpression( 33 | jnp.stack([jnp.ones((2, 3)), 34 | -jnp.ones((2, 3))]), 35 | jnp.zeros((2,))) 36 | 37 | input_bound = ibp.IntervalBound(-2 * jnp.ones((2, 3)), 38 | 2 * jnp.ones((2, 3))) 39 | 40 | concretizing_inp = branch_utils.minimizing_concretizing_input( 41 | linexp, input_bound) 42 | # Check that the shape of the concretizing inp for each linexp is of the 43 | # shape of the input. 44 | chex.assert_shape(concretizing_inp, (2, 2, 3)) 45 | # Evaluating the bound given by the concretizing inp 46 | bound_by_concinp = ((linexp.lin_coeffs * concretizing_inp).sum(axis=(1, 2)) 47 | + linexp.offset) 48 | 49 | # Result by concretizing directly: 50 | concretized_bound = linear_relaxations.concretize_linear_expression( 51 | linexp, input_bound) 52 | 53 | chex.assert_trees_all_close(bound_by_concinp, concretized_bound) 54 | 55 | if __name__ == '__main__': 56 | absltest.main() 57 | -------------------------------------------------------------------------------- /jax_verify/tests/crownibp_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for CrownIBP.""" 17 | 18 | import functools 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | 23 | import haiku as hk 24 | import jax 25 | import jax.numpy as jnp 26 | import jax_verify 27 | 28 | 29 | class CrownIBPBoundTest(parameterized.TestCase): 30 | 31 | def assertArrayAlmostEqual(self, lhs, rhs): 32 | diff = jnp.abs(lhs - rhs).max() 33 | self.assertAlmostEqual(diff, 0., delta=1e-5) 34 | 35 | def test_fc_crownibp(self): 36 | 37 | @hk.without_apply_rng 38 | @hk.transform 39 | def linear_model(inp): 40 | return hk.Linear(1)(inp) 41 | 42 | z = jnp.array([[1., 2., 3.]]) 43 | params = {'linear': 44 | {'w': jnp.ones((3, 1), dtype=jnp.float32), 45 | 'b': jnp.array([2.])}} 46 | input_bounds = jax_verify.IntervalBound(z-1., z+1.) 47 | fun = functools.partial(linear_model.apply, params) 48 | output_bounds = jax_verify.crownibp_bound_propagation( 49 | fun, input_bounds) 50 | 51 | self.assertAlmostEqual(5., output_bounds.lower) 52 | self.assertAlmostEqual(11., output_bounds.upper) 53 | 54 | def test_conv2d_crownibp(self): 55 | 56 | @hk.without_apply_rng 57 | @hk.transform 58 | def conv2d_model(inp): 59 | return hk.Conv2D(output_channels=1, kernel_shape=(2, 2), 60 | padding='VALID', stride=1, with_bias=True)(inp) 61 | 62 | z = jnp.array([1., 2., 3., 4.]) 63 | z = jnp.reshape(z, [1, 2, 2, 1]) 64 | 65 | params = {'conv2_d': 66 | {'w': jnp.ones((2, 2, 1, 1), dtype=jnp.float32), 67 | 'b': jnp.array([2.])}} 68 | 69 | fun = functools.partial(conv2d_model.apply, params) 70 | input_bounds = jax_verify.IntervalBound(z - 1., z + 1.) 
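    # With an all-ones 2x2 kernel and bias 2, the nominal output is
    # 1 + 2 + 3 + 4 + 2 = 12; each of the four inputs can move by 1, so the
    # expected interval is [8, 16], as asserted below.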
71 | output_bounds = jax_verify.crownibp_bound_propagation( 72 | fun, input_bounds) 73 | 74 | self.assertAlmostEqual(8., output_bounds.lower) 75 | self.assertAlmostEqual(16., output_bounds.upper) 76 | 77 | def test_relu_crownibp(self): 78 | def relu_model(inp): 79 | return jax.nn.relu(inp) 80 | z = jnp.array([[-2., 3.]]) 81 | 82 | input_bounds = jax_verify.IntervalBound(z - 1., z + 1.) 83 | output_bounds = jax_verify.crownibp_bound_propagation( 84 | relu_model, input_bounds) 85 | 86 | self.assertArrayAlmostEqual(jnp.array([[0., 2.]]), output_bounds.lower) 87 | self.assertArrayAlmostEqual(jnp.array([[0., 4.]]), output_bounds.upper) 88 | 89 | 90 | if __name__ == '__main__': 91 | absltest.main() 92 | -------------------------------------------------------------------------------- /jax_verify/tests/cvxpy_relaxation_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for solving the convex relaxation using CVXPY.""" 17 | 18 | import functools 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | 23 | import haiku as hk 24 | import jax 25 | import jax.numpy as jnp 26 | import jax_verify 27 | from jax_verify.src import bound_propagation 28 | from jax_verify.src.mip_solver import cvxpy_relaxation_solver 29 | from jax_verify.src.mip_solver import relaxation 30 | 31 | 32 | class CVXPYRelaxationTest(parameterized.TestCase): 33 | 34 | def assertArrayAlmostEqual(self, lhs, rhs): 35 | diff = jnp.abs(lhs - rhs).max() 36 | self.assertAlmostEqual(diff, 0.) 37 | 38 | def get_bounds(self, fun, input_bounds): 39 | output = fun(input_bounds.lower) 40 | 41 | boundprop_transform = jax_verify.ibp_transform 42 | relaxation_transform = relaxation.RelaxationTransform(boundprop_transform) 43 | var, env = bound_propagation.bound_propagation( 44 | bound_propagation.ForwardPropagationAlgorithm(relaxation_transform), 45 | fun, input_bounds) 46 | 47 | objective_bias = 0. 
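    # One LP per output coordinate, as in the loop below: minimising the
    # one-hot objective yields the lower bound, and minimising its negation
    # yields (minus) the upper bound.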
48 | index = 0 49 | 50 | lower_bounds = [] 51 | upper_bounds = [] 52 | for output_idx in range(output.size): 53 | objective = (jnp.arange(output.size) == output_idx).astype(jnp.float32) 54 | 55 | lower_bound, _, _ = relaxation.solve_relaxation( 56 | cvxpy_relaxation_solver.CvxpySolver, objective, objective_bias, 57 | var, env, index) 58 | 59 | neg_upper_bound, _, _ = relaxation.solve_relaxation( 60 | cvxpy_relaxation_solver.CvxpySolver, -objective, objective_bias, 61 | var, env, index) 62 | lower_bounds.append(lower_bound) 63 | upper_bounds.append(-neg_upper_bound) 64 | 65 | return jnp.array(lower_bounds), jnp.array(upper_bounds) 66 | 67 | def test_linear_cvxpy_relaxation(self): 68 | 69 | def linear_model(inp): 70 | return hk.Linear(1)(inp) 71 | 72 | z = jnp.array([[1., 2., 3.]]) 73 | params = {'linear': 74 | {'w': jnp.ones((3, 1), dtype=jnp.float32), 75 | 'b': jnp.array([2.])}} 76 | 77 | fun = functools.partial( 78 | hk.without_apply_rng(hk.transform(linear_model)).apply, 79 | params) 80 | input_bounds = jax_verify.IntervalBound(z - 1., z + 1.) 81 | 82 | lower_bounds, upper_bounds = self.get_bounds(fun, input_bounds) 83 | self.assertAlmostEqual(5., lower_bounds) 84 | self.assertAlmostEqual(11., upper_bounds) 85 | 86 | def test_conv1d_cvxpy_relaxation(self): 87 | 88 | def conv1d_model(inp): 89 | return hk.Conv1D(output_channels=1, kernel_shape=2, 90 | padding='VALID', stride=1, with_bias=True)(inp) 91 | z = jnp.array([3., 4.]) 92 | z = jnp.reshape(z, [1, 2, 1]) 93 | 94 | params = {'conv1_d': 95 | {'w': jnp.ones((2, 1, 1), dtype=jnp.float32), 96 | 'b': jnp.array([2.])}} 97 | 98 | fun = functools.partial( 99 | hk.without_apply_rng(hk.transform(conv1d_model)).apply, 100 | params) 101 | input_bounds = jax_verify.IntervalBound(z - 1., z + 1.) 102 | 103 | lower_bounds, upper_bounds = self.get_bounds(fun, input_bounds) 104 | 105 | self.assertAlmostEqual(7., lower_bounds, delta=1e-5) 106 | self.assertAlmostEqual(11., upper_bounds, delta=1e-5) 107 | 108 | def test_conv2d_cvxpy_relaxation(self): 109 | def conv2d_model(inp): 110 | return hk.Conv2D(output_channels=1, kernel_shape=(2, 2), 111 | padding='VALID', stride=1, with_bias=True)(inp) 112 | z = jnp.array([1., 2., 3., 4.]) 113 | z = jnp.reshape(z, [1, 2, 2, 1]) 114 | 115 | params = {'conv2_d': 116 | {'w': jnp.ones((2, 2, 1, 1), dtype=jnp.float32), 117 | 'b': jnp.array([2.])}} 118 | 119 | fun = functools.partial( 120 | hk.without_apply_rng(hk.transform(conv2d_model)).apply, 121 | params) 122 | input_bounds = jax_verify.IntervalBound(z - 1., z + 1.) 123 | 124 | lower_bounds, upper_bounds = self.get_bounds(fun, input_bounds) 125 | self.assertAlmostEqual(8., lower_bounds) 126 | self.assertAlmostEqual(16., upper_bounds) 127 | 128 | def test_relu_cvxpy_relaxation(self): 129 | def relu_model(inp): 130 | return jax.nn.relu(inp) 131 | z = jnp.array([[-2., 3.]]) 132 | 133 | input_bounds = jax_verify.IntervalBound(z - 1., z + 1.) 134 | lower_bounds, upper_bounds = self.get_bounds(relu_model, input_bounds) 135 | 136 | self.assertArrayAlmostEqual(jnp.array([[0., 2.]]), lower_bounds) 137 | self.assertArrayAlmostEqual(jnp.array([[0., 4.]]), upper_bounds) 138 | 139 | 140 | if __name__ == '__main__': 141 | absltest.main() 142 | -------------------------------------------------------------------------------- /jax_verify/tests/functional_lagrangian/attacks_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Attacks test.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import chex 21 | import haiku as hk 22 | import jax 23 | import jax.numpy as jnp 24 | import jax_verify 25 | from jax_verify.extensions.functional_lagrangian import attacks 26 | from jax_verify.extensions.functional_lagrangian import verify_utils 27 | 28 | EPS = 0.1 29 | 30 | 31 | def make_data_spec(prng_key): 32 | """Create data specification from config.""" 33 | x = jax.random.normal(prng_key, [8]) 34 | input_bounds = (x - EPS, x + EPS) 35 | 36 | return verify_utils.DataSpec( 37 | input=x, 38 | true_label=0, 39 | target_label=1, 40 | epsilon=EPS, 41 | input_bounds=input_bounds) 42 | 43 | 44 | def make_params(prng_key, dropout_rate=0.0, std=None): 45 | prng_key_seq = hk.PRNGSequence(prng_key) 46 | 47 | w1 = jax.random.normal(next(prng_key_seq), [8, 4]) 48 | b1 = jax.random.normal(next(prng_key_seq), [4]) 49 | 50 | w2 = jax.random.normal(next(prng_key_seq), [4, 2]) 51 | b2 = jax.random.normal(next(prng_key_seq), [2]) 52 | 53 | if std is not None: 54 | w1_std = std * jnp.ones([8, 4]) 55 | b1_std = std * jnp.ones([4]) 56 | w1_bound = jax_verify.IntervalBound(w1 - 3 * w1_std, w1 + 3 * w1_std) 57 | b1_bound = jax_verify.IntervalBound(b1 - 3 * b1_std, b1 + 3 * b1_std) 58 | else: 59 | w1_std, b1_std, w1_bound, b1_bound = None, None, None, None 60 | 61 | params = [ 62 | verify_utils.FCParams( 63 | w=w1, 64 | b=b1, 65 | w_std=w1_std, 66 | b_std=b1_std, 67 | w_bound=w1_bound, 68 | b_bound=b1_bound, 69 | ), 70 | verify_utils.FCParams( 71 | w=w2, 72 | b=b2, 73 | dropout_rate=dropout_rate, 74 | ) 75 | ] 76 | return params 77 | 78 | 79 | class AttacksTest(parameterized.TestCase): 80 | 81 | def setUp(self): 82 | super().setUp() 83 | self.prng_seq = hk.PRNGSequence(1234) 84 | self.data_spec = make_data_spec(next(self.prng_seq)) 85 | 86 | def test_forward_deterministic(self): 87 | params = make_params(next(self.prng_seq)) 88 | self._check_deterministic_behavior(params) 89 | 90 | def test_forward_almost_no_randomness(self): 91 | params = make_params(next(self.prng_seq), std=1e-8, dropout_rate=1e-8) 92 | self._check_deterministic_behavior(params) 93 | 94 | def test_forward_gaussian(self): 95 | params = make_params(next(self.prng_seq), std=1.0) 96 | self._check_stochastic_behavior(params) 97 | 98 | def test_forward_dropout(self): 99 | params = make_params(next(self.prng_seq), dropout_rate=0.8) 100 | self._check_stochastic_behavior(params) 101 | 102 | def test_adversarial_integration(self): 103 | spec_type = verify_utils.SpecType.ADVERSARIAL 104 | params = make_params(next(self.prng_seq), std=0.1, dropout_rate=0.2) 105 | attacks.adversarial_attack( 106 | params, 107 | self.data_spec, 108 | spec_type, 109 | next(self.prng_seq), 110 | num_steps=5, 111 | learning_rate=0.1, 112 | num_samples=3) 113 | 114 | def test_adversarial_uncertainty_integration(self): 115 | spec_type = verify_utils.SpecType.ADVERSARIAL 
116 | params = make_params(next(self.prng_seq), std=0.1, dropout_rate=0.2) 117 | attacks.adversarial_attack( 118 | params, 119 | self.data_spec, 120 | spec_type, 121 | next(self.prng_seq), 122 | num_steps=5, 123 | learning_rate=0.1, 124 | num_samples=3) 125 | 126 | def _make_value_and_grad(self, params, num_samples): 127 | forward_fn = attacks.make_forward(params, num_samples) 128 | 129 | def objective_fn(x, prng_key): 130 | out = jnp.reshape(forward_fn(x, prng_key), [2]) 131 | return out[1] - out[0] 132 | 133 | return jax.value_and_grad(objective_fn) 134 | 135 | def _check_deterministic_behavior(self, params): 136 | # build function with 1 sample 137 | value_and_grad_fn = self._make_value_and_grad(params, num_samples=1) 138 | # forward first time 139 | out_1 = value_and_grad_fn(self.data_spec.input, next(self.prng_seq)) 140 | 141 | # forward again gives the same result 142 | out_1_again = value_and_grad_fn(self.data_spec.input, next(self.prng_seq)) 143 | chex.assert_trees_all_close(out_1, out_1_again, rtol=1e-5) 144 | 145 | # forward with 3 samples should still give the same result 146 | value_and_grad_fn = self._make_value_and_grad(params, num_samples=3) 147 | out_3 = value_and_grad_fn(self.data_spec.input, next(self.prng_seq)) 148 | chex.assert_trees_all_close(out_3, out_1, rtol=1e-5) 149 | 150 | def _check_stochastic_behavior(self, params): 151 | value_and_grad_fn = self._make_value_and_grad(params, num_samples=2) 152 | prng = next(self.prng_seq) 153 | 154 | # forward a first time 155 | out_2 = value_and_grad_fn(self.data_spec.input, prng) 156 | 157 | # forward with a different seed does not give the same result 158 | out_2_diff = value_and_grad_fn(self.data_spec.input, next(self.prng_seq)) 159 | with self.assertRaises(AssertionError): 160 | chex.assert_trees_all_close(out_2, out_2_diff) 161 | 162 | # forward with 3 samples and the same prng is not the same 163 | value_and_grad_fn = self._make_value_and_grad(params, num_samples=3) 164 | out_3_same_prng = value_and_grad_fn(self.data_spec.input, prng) 165 | with self.assertRaises(AssertionError): 166 | chex.assert_trees_all_close(out_2, out_3_same_prng) 167 | 168 | 169 | if __name__ == '__main__': 170 | absltest.main() 171 | -------------------------------------------------------------------------------- /jax_verify/tests/functional_lagrangian/lagrangian_form_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Elementary tests for the Lagrangian forms.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import chex 21 | import haiku as hk 22 | import jax 23 | from jax_verify.extensions.functional_lagrangian import lagrangian_form 24 | 25 | INPUT_SHAPES = (('batched_0d', [1]), ('batched_1d', [1, 2]), 26 | ('batched_2d', [1, 2, 3]), ('batched_3d', [1, 2, 3, 4])) 27 | 28 | 29 | class ShapeTest(chex.TestCase): 30 | 31 | def setUp(self): 32 | super(ShapeTest, self).setUp() 33 | self._prng_seq = hk.PRNGSequence(13579) 34 | 35 | def _assert_output_shape(self, form, shape): 36 | x = jax.random.normal(next(self._prng_seq), shape) 37 | params = form.init_params(next(self._prng_seq), x.shape[1:]) 38 | out = form.apply(x, params, step=0) 39 | assert out.ndim == 1 40 | 41 | @parameterized.named_parameters(*INPUT_SHAPES) 42 | def test_linear(self, shape): 43 | form = lagrangian_form.Linear() 44 | self._assert_output_shape(form, shape) 45 | 46 | @parameterized.named_parameters(*INPUT_SHAPES) 47 | def test_linear_exp(self, shape): 48 | form = lagrangian_form.LinearExp() 49 | self._assert_output_shape(form, shape) 50 | 51 | 52 | if __name__ == '__main__': 53 | absltest.main() 54 | -------------------------------------------------------------------------------- /jax_verify/tests/functional_lagrangian/lp_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Unit-test for linear Lagrangian.""" 17 | 18 | from absl.testing import absltest 19 | import chex 20 | import jax 21 | import jax.numpy as jnp 22 | import jax_verify 23 | from jax_verify.extensions.functional_lagrangian import dual_build 24 | from jax_verify.extensions.functional_lagrangian import dual_solve 25 | from jax_verify.extensions.functional_lagrangian import lagrangian_form 26 | from jax_verify.extensions.functional_lagrangian import verify_utils 27 | from jax_verify.extensions.functional_lagrangian.inner_solvers import lp 28 | from jax_verify.extensions.sdp_verify import utils as sdp_utils 29 | from jax_verify.src import bound_propagation 30 | from jax_verify.src.mip_solver import cvxpy_relaxation_solver 31 | from jax_verify.src.mip_solver import relaxation 32 | from jax_verify.tests.sdp_verify import test_utils as sdp_test_utils 33 | import ml_collections 34 | import numpy as np 35 | 36 | NUM_SAMPLES = 1 37 | LAYER_SIZES = [3, 4, 5, 6] 38 | 39 | 40 | def create_inputs(prng_key): 41 | return jax.random.uniform( 42 | prng_key, [NUM_SAMPLES, LAYER_SIZES[0]], minval=0.0, maxval=1.0) 43 | 44 | 45 | def make_model_fn(params): 46 | 47 | def model_fn(inputs): 48 | inputs = np.reshape(inputs, (inputs.shape[0], -1)) 49 | return sdp_utils.predict_mlp(params, inputs) 50 | 51 | return model_fn 52 | 53 | 54 | def get_config(): 55 | config = ml_collections.ConfigDict() 56 | 57 | config.outer_opt = ml_collections.ConfigDict() 58 | config.outer_opt.lr_init = 0.001 59 | config.outer_opt.steps_per_anneal = 500 60 | config.outer_opt.anneal_lengths = '' 61 | config.outer_opt.anneal_factor = 0.1 62 | config.outer_opt.num_anneals = 1 63 | config.outer_opt.opt_name = 'adam' 64 | config.outer_opt.opt_kwargs = {} 65 | 66 | return config 67 | 68 | 69 | class LinearTest(chex.TestCase): 70 | 71 | def setUp(self): 72 | super(LinearTest, self).setUp() 73 | 74 | self.target_label = 1 75 | self.label = 0 76 | self.input_bounds = (0.0, 1.0) 77 | self.layer_sizes = LAYER_SIZES 78 | self.eps = 0.1 79 | 80 | prng_key = jax.random.PRNGKey(13579) 81 | 82 | self.keys = jax.random.split(prng_key, 5) 83 | self.network_params = sdp_test_utils.make_mlp_params( 84 | self.layer_sizes, self.keys[0]) 85 | 86 | self.inputs = create_inputs(self.keys[1]) 87 | 88 | objective = jnp.zeros(self.layer_sizes[-1]) 89 | objective = objective.at[self.target_label].add(1) 90 | objective = objective.at[self.label].add(-1) 91 | self.objective = objective 92 | self.objective_bias = jax.random.normal(self.keys[2], []) 93 | 94 | def solve_with_jax_verify(self): 95 | lower_bound = jnp.minimum(jnp.maximum(self.inputs - self.eps, 0.0), 1.0) 96 | upper_bound = jnp.minimum(jnp.maximum(self.inputs + self.eps, 0.0), 1.0) 97 | init_bound = jax_verify.IntervalBound(lower_bound, upper_bound) 98 | 99 | logits_fn = make_model_fn(self.network_params) 100 | 101 | solver = cvxpy_relaxation_solver.CvxpySolver 102 | relaxation_transform = relaxation.RelaxationTransform( 103 | jax_verify.ibp_transform) 104 | 105 | var, env = bound_propagation.bound_propagation( 106 | bound_propagation.ForwardPropagationAlgorithm(relaxation_transform), 107 | logits_fn, init_bound) 108 | 109 | # This solver minimizes the objective -> get max with -min(-objective) 110 | neg_value_opt, _, _ = relaxation.solve_relaxation( 111 | solver, 112 | -self.objective, 113 | -self.objective_bias, 114 | var, 115 | env, 116 | index=0, 117 | time_limit_millis=None) 118 | value_opt = -neg_value_opt 119 | 120 | return value_opt 121 | 122 | def solve_with_functional_lagrangian(self): 
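"""Solve the same problem with the functional Lagrangian dual.

Uses the LP inner solver with a linear Lagrangian form per layer; the
resulting dual bound should match the LP relaxation value computed by
solve_with_jax_verify (the test below checks agreement to rtol=1e-3).
"""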
123 | config = get_config() 124 | 125 | init_bound = sdp_utils.init_bound( 126 | self.inputs[0], self.eps, input_bounds=self.input_bounds) 127 | bounds = sdp_utils.boundprop( 128 | self.network_params + [(self.objective, self.objective_bias)], 129 | init_bound) 130 | 131 | logits_fn = make_model_fn(self.network_params) 132 | 133 | def spec_fn(inputs): 134 | return jnp.matmul(logits_fn(inputs), self.objective) + self.objective_bias 135 | 136 | input_bounds = jax_verify.IntervalBound(bounds[0].lb, bounds[0].ub) 137 | 138 | lagrangian_form_per_layer = lagrangian_form.Linear() 139 | lagrangian_form_per_layer = [lagrangian_form_per_layer for bd in bounds] 140 | inner_opt = lp.LpStrategy() 141 | env, dual_params, dual_params_types = inner_opt.init_duals( 142 | jax_verify.ibp_transform, verify_utils.SpecType.ADVERSARIAL, False, 143 | spec_fn, self.keys[3], lagrangian_form_per_layer, input_bounds) 144 | opt, num_steps = dual_build.make_opt_and_num_steps(config.outer_opt) 145 | dual_state = ml_collections.ConfigDict(type_safe=False) 146 | dual_solve.solve_dual_train( 147 | env, 148 | key=self.keys[4], 149 | num_steps=num_steps, 150 | opt=opt, 151 | dual_params=dual_params, 152 | dual_params_types=dual_params_types, 153 | dual_state=dual_state, 154 | affine_before_relu=False, 155 | spec_type=verify_utils.SpecType.ADVERSARIAL, 156 | inner_opt=inner_opt, 157 | logger=(lambda *args: None), 158 | ) 159 | 160 | return dual_state.loss 161 | 162 | def test_lp_against_jax_verify_relaxation(self): 163 | value_jax_verify = self.solve_with_jax_verify() 164 | value_functional_lagrangian = self.solve_with_functional_lagrangian() 165 | 166 | np.testing.assert_allclose( 167 | value_jax_verify, value_functional_lagrangian, rtol=1e-3) 168 | 169 | 170 | if __name__ == '__main__': 171 | absltest.main() 172 | -------------------------------------------------------------------------------- /jax_verify/tests/functional_lagrangian/uncertainty_spec_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Unit-test for uncertainty spec inner max.""" 17 | 18 | from absl.testing import absltest 19 | import chex 20 | import haiku as hk 21 | import jax 22 | import jax.numpy as jnp 23 | from jax_verify.extensions.functional_lagrangian import lagrangian_form as lag_form 24 | from jax_verify.extensions.functional_lagrangian import verify_utils 25 | from jax_verify.extensions.functional_lagrangian.inner_solvers import uncertainty_spec 26 | from jax_verify.extensions.sdp_verify import utils as sdp_utils 27 | import numpy as np 28 | 29 | X_SHAPE = [1, 7] 30 | 31 | 32 | class UncertaintySpecTest(chex.TestCase): 33 | 34 | def setUp(self): 35 | super(UncertaintySpecTest, self).setUp() 36 | 37 | self._prng_seq = hk.PRNGSequence(13579) 38 | self._n_classes = X_SHAPE[1] 39 | 40 | self.bounds = [ 41 | sdp_utils.IntBound( 42 | lb_pre=-0.1 * jnp.ones(X_SHAPE), 43 | ub_pre=0.1 * jnp.ones(X_SHAPE), 44 | lb=None, 45 | ub=None) 46 | ] 47 | 48 | def test_softmax_upper(self): 49 | rand_class = jax.random.randint( 50 | next(self._prng_seq), shape=(), minval=0, maxval=self._n_classes) 51 | objective = jnp.arange(self._n_classes) == rand_class 52 | constant = jax.random.uniform(next(self._prng_seq), ()) 53 | 54 | affine_fn = lambda x: jnp.sum(x * objective) + constant 55 | 56 | lagrangian_form = lag_form.Linear() 57 | lp_pre = lagrangian_form.init_params( 58 | next(self._prng_seq), l_shape=X_SHAPE, init_zeros=False) 59 | 60 | opt_instance = verify_utils.InnerVerifInstance( 61 | affine_fns=[affine_fn], 62 | bounds=self.bounds, 63 | lagrangian_form_pre=lagrangian_form, 64 | lagrangian_form_post=lagrangian_form, 65 | is_first=False, 66 | is_last=True, 67 | lagrange_params_pre=lp_pre, 68 | lagrange_params_post=None, 69 | idx=0, 70 | spec_type=verify_utils.SpecType.UNCERTAINTY, 71 | affine_before_relu=True) 72 | 73 | # run PGA to find approximate max 74 | pga_opt = uncertainty_spec.UncertaintySpecStrategy( 75 | n_iter=10_000, 76 | n_pieces=0, 77 | solve_max=uncertainty_spec.MaxType.EXP, 78 | ) 79 | 80 | value_pga = pga_opt.solve_max( 81 | inner_dual_vars=None, 82 | opt_instance=opt_instance, 83 | key=next(self._prng_seq), 84 | step=0) 85 | 86 | # use cvxpy to find upper bound 87 | cvx_opt = uncertainty_spec.UncertaintySpecStrategy( 88 | n_iter=0, 89 | n_pieces=10, 90 | solve_max=uncertainty_spec.MaxType.EXP_BOUND, 91 | ) 92 | 93 | value_cvx = cvx_opt.solve_max( 94 | inner_dual_vars=None, 95 | opt_instance=opt_instance, 96 | key=next(self._prng_seq), 97 | step=0) 98 | 99 | # evaluate objective function on an arbitrarily chosen feasible point 100 | def objective_fn(x): 101 | return (jnp.squeeze(affine_fn(jax.nn.softmax(x)), ()) - 102 | jnp.squeeze(lagrangian_form.apply(x, lp_pre, step=0), ())) 103 | 104 | middle_x = 0.5 * self.bounds[0].lb_pre + 0.5 * self.bounds[0].ub_pre 105 | value_middle = objective_fn(middle_x) 106 | 107 | np.testing.assert_array_less(value_middle, value_pga) 108 | np.testing.assert_array_less(value_pga, value_cvx + 1e-5) 109 | 110 | 111 | if __name__ == '__main__': 112 | absltest.main() 113 | -------------------------------------------------------------------------------- /jax_verify/tests/model_zoo.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Example model architectures used in tests.
17 | """
18 |
19 | import haiku as hk
20 | import jax
21 |
22 |
23 | class SmallResidualModel(hk.Module):
24 | """Small network with residual connections.
25 |
26 | Smaller version of ResidualModel.
27 | """
28 |
29 | def __init__(self):
30 | super().__init__()
31 | bn_config = {'create_scale': True,
32 | 'create_offset': True,
33 | 'decay_rate': 0.999}
34 |
35 | # Definition of the modules.
36 | self.conv_block = hk.Sequential([
37 | hk.Conv2D(1, (3, 3), stride=3, rate=1), jax.nn.relu,
38 | hk.Conv2D(1, (3, 3), stride=3, rate=1), jax.nn.relu,
39 | ])
40 |
41 | self.conv_res_block = hk.Sequential([
42 | hk.Conv2D(1, (1, 1), stride=1, rate=1), jax.nn.relu,
43 | hk.Conv2D(1, (1, 1), stride=1, rate=1), jax.nn.relu,
44 | ])
45 |
46 | self.reshape_mod = hk.Flatten()
47 |
48 | self.lin_res_block = [
49 | (hk.Linear(16), hk.BatchNorm(name='lin_batchnorm_0', **bn_config))
50 | ]
51 |
52 | self.final_linear = hk.Linear(10)
53 |
54 | def call_all_act(self, inputs, is_training, test_local_stats=False):
55 | """Evaluate the model, returning its intermediate activations.
56 |
57 | Args:
58 | inputs: BHWC array of images.
59 | is_training: Boolean flag, whether this is during training.
60 | test_local_stats: Boolean flag, whether local stats are used
61 | when is_training=False (for batchnorm).
62 | Returns:
63 | all_acts: List with the intermediate activations of interest.
64 | """
65 | all_acts = []
66 | all_acts.append(inputs)
67 |
68 | ## Forward propagation.
69 | # First conv layer.
70 | act = self.conv_block(inputs)
71 | all_acts.append(act)
72 | # Convolutional residual block.
73 | act = act + self.conv_res_block(act)
74 | all_acts.append(act)
75 | # Reshape before fully connected part.
76 | act = self.reshape_mod(act)
77 | all_acts.append(act)
78 | # Fully connected residual block.
79 | lin_block_act = act
80 | for lin_i, bn_i in self.lin_res_block:
81 | lin_block_act = lin_i(lin_block_act)
82 | lin_block_act = bn_i(lin_block_act, is_training, test_local_stats)
83 | lin_block_act = jax.nn.relu(lin_block_act)
84 | act = act + lin_block_act
85 | all_acts.append(act)
86 | # Final layer.
87 | act = self.final_linear(act)
88 | all_acts.append(act)
89 | return all_acts
90 |
91 | def __call__(self, inputs, is_training, test_local_stats=False):
92 | """Return only the final prediction of the model.
93 |
94 | Args:
95 | inputs: BHWC array of images.
96 | is_training: Boolean flag, whether this is during training.
97 | test_local_stats: Boolean flag, whether local stats are used
98 | when is_training=False (for batchnorm).
99 | Returns:
100 | pred: Array with the predictions, corresponding to the last activations.
101 | """
102 | all_acts = self.call_all_act(inputs, is_training, test_local_stats)
103 | return all_acts[-1]
104 |
105 |
106 | class TinyModel(hk.Module):
107 | """Tiny network.
108 |
109 | A single hidden linear layer (despite living next to the conv models above).
110 | """
111 |
112 | def __init__(self):
113 | super().__init__()
114 | # Definition of the modules.
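# TinyModel is: flatten -> 20-unit linear+ReLU block -> 10-way linear head.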
115 | self.reshape_mod = hk.Flatten()
116 |
117 | self.lin_block = hk.Sequential([
118 | hk.Linear(20), jax.nn.relu,
119 | ])
120 |
121 | self.final_linear = hk.Linear(10)
122 |
123 | def call_all_act(self, inputs, is_training, test_local_stats=False):
124 | """Evaluate the model, returning its intermediate activations.
125 |
126 | Args:
127 | inputs: BHWC array of images.
128 | is_training: Boolean flag, whether this is during training.
129 | test_local_stats: Boolean flag, whether local stats are used
130 | when is_training=False (for batchnorm).
131 | Returns:
132 | all_acts: List with the intermediate activations of interest.
133 | """
134 | all_acts = []
135 | all_acts.append(inputs)
136 | act = inputs
137 | ## Forward propagation.
138 | act = self.reshape_mod(act)
139 | all_acts.append(act)
140 |
141 | # First linear layer.
142 | act = self.lin_block(act)
143 | all_acts.append(act)
144 | # Final layer.
145 | act = self.final_linear(act)
146 | all_acts.append(act)
147 | return all_acts
148 |
149 | def __call__(self, inputs, is_training, test_local_stats=False):
150 | """Return only the final prediction of the model.
151 |
152 | Args:
153 | inputs: BHWC array of images.
154 | is_training: Boolean flag, whether this is during training.
155 | test_local_stats: Boolean flag, whether local stats are used
156 | when is_training=False (for batchnorm).
157 | Returns:
158 | pred: Array with the predictions, corresponding to the last activations.
159 | """
160 | all_acts = self.call_all_act(inputs, is_training, test_local_stats)
161 | return all_acts[-1]
162 | --------------------------------------------------------------------------------
/jax_verify/tests/opt_utils_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for jax_verify opt_utils."""
17 | import functools
18 | from absl.testing import absltest
19 |
20 | import chex
21 | import jax
22 | from jax import numpy as jnp
23 | from jax_verify.src import opt_utils
24 | import numpy as np
25 |
26 |
27 | class OptUtilsTest(absltest.TestCase):
28 |
29 | def test_greedy_assign(self):
30 |
31 | # Build a list of upper bounds, total sum, and the expected greedy assignment.
32 | problems = [
33 | (0.5 * jnp.ones(5,), 2.5, 0.5 * jnp.ones(5,)),
34 | (0.5 * jnp.ones(5,), 1.0, jnp.array([0.5, 0.5, 0., 0., 0.])),
35 | (0.5 * jnp.ones(5,), 0.75, jnp.array([0.5, 0.25, 0., 0., 0.])),
36 | (0.5 * jnp.ones(5,), 0.3, jnp.array([0.3, 0., 0., 0., 0.])),
37 | (jnp.array([0., 1., 0., 0.5]), 1.2, jnp.array([0., 1., 0., 0.2])),
38 | (jnp.array([1., 2., 3.]), 2.5, jnp.array([1., 1.5, 0.]))
39 | ]
40 |
41 | for upper, total_sum, ref_answer in problems:
42 | # Try the forward assignment.
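# greedy_assign fills coordinates left to right, capping each entry at its
# upper bound, until the requested total sum is exhausted (cf. the expected
# assignments listed above).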
43 | pred = opt_utils.greedy_assign(upper, total_sum)
44 | chex.assert_trees_all_close(pred, ref_answer)
45 |
46 | def test_1d_binary_search(self):
47 | for seed in range(10):
48 | argmax = jax.random.uniform(jax.random.PRNGKey(seed), ())
49 |
50 | # Try out two possible types of concave function for which we know the
51 | # maximum.
52 | ccv_fun = lambda x, argmax=argmax: -(x - argmax)**2
53 | pred_argmax, max_val = opt_utils.concave_1d_max(
54 | ccv_fun, jnp.zeros(()), jnp.ones(()), num_steps=64)
55 | self.assertAlmostEqual(max_val, 0., delta=1e-6) # pytype: disable=wrong-arg-types # jax-ndarray
56 | self.assertAlmostEqual(pred_argmax, argmax, delta=1e-6)
57 |
58 | alt_ccv_fun = lambda x, argmax=argmax: -jnp.abs(x - argmax)
59 | pred_argmax, max_val = opt_utils.concave_1d_max(
60 | alt_ccv_fun, jnp.zeros(()), jnp.ones(()), num_steps=64)
61 | self.assertAlmostEqual(max_val, 0., delta=1e-6) # pytype: disable=wrong-arg-types # jax-ndarray
62 | self.assertAlmostEqual(pred_argmax, argmax, delta=1e-6)
63 |
64 | x, y = opt_utils.concave_1d_max(
65 | lambda x: -x**2 + 4.*x - 3., # max at x=2, y=1
66 | jnp.array([0., -11., 10.]),
67 | jnp.array([3., -10., 11.]),
68 | )
69 | np.testing.assert_array_almost_equal(x, np.array([2., -10., 10.]),
70 | decimal=3)
71 | np.testing.assert_array_almost_equal(y, np.array([1., -143., -63.]),
72 | decimal=4)
73 |
74 | def test_simplex_projection_fully_constrained(self):
75 | # Test the edge case of a simplex sum with one element.
76 | # This should always give the simplex_sum if it's in the valid bounds.
77 | all_initial_values = jnp.expand_dims(jnp.linspace(-10., 10., 100), 1)
78 |
79 | project_onto_01 = functools.partial(opt_utils.project_onto_interval_simplex,
80 | jnp.zeros((1,)), jnp.ones((1,)),
81 | 1.0)
82 | batch_project_onto_01 = jax.vmap(project_onto_01)
83 | all_res = batch_project_onto_01(all_initial_values)
84 |
85 | self.assertAlmostEqual(all_res.min(), 1.0, delta=1e-6)
86 | self.assertAlmostEqual(all_res.max(), 1.0, delta=1e-6)
87 |
88 | project_onto_03 = functools.partial(opt_utils.project_onto_interval_simplex,
89 | jnp.zeros((1,)), 3*jnp.ones((1,)),
90 | 1.0)
91 | batch_project_onto_03 = jax.vmap(project_onto_03)
92 | all_res = batch_project_onto_03(all_initial_values)
93 | self.assertAlmostEqual(all_res.min(), 1.0, delta=1e-6)
94 | self.assertAlmostEqual(all_res.max(), 1.0, delta=1e-6)
95 |
96 | key = jax.random.PRNGKey(0)
97 | initial_values = jax.random.uniform(key, (100, 5), minval=-10, maxval=10)
98 | # There is only one valid solution to this problem: everything is 1.
99 | project = functools.partial(opt_utils.project_onto_interval_simplex,
100 | jnp.zeros((5,)), jnp.ones((5,)), 5.0)
101 | batch_project = jax.vmap(project)
102 | all_res = batch_project(initial_values)
103 | self.assertAlmostEqual(all_res.min(), 1.0, delta=1e-6)
104 | self.assertAlmostEqual(all_res.max(), 1.0, delta=1e-6)
105 |
106 |
107 | if __name__ == '__main__':
108 | absltest.main()
109 | --------------------------------------------------------------------------------
/jax_verify/tests/sdp_verify/boundprop_utils_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for boundprop_utils.py."""
17 |
18 | import functools
19 | import pickle
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | import jax
23 | import jax.numpy as jnp
24 | import jax_verify
25 | from jax_verify.extensions.sdp_verify import boundprop_utils
26 | from jax_verify.extensions.sdp_verify import utils
27 | import numpy as np
28 |
29 |
30 | class BoundpropTest(parameterized.TestCase):
31 |
32 | def test_crown_boundprop(self):
33 | """Test CROWN bounds vs FGSM on Wong-Small MNIST CNN."""
34 | crown_boundprop = functools.partial(boundprop_utils.boundprop,
35 | boundprop_type='crown_ibp')
36 | self._test_boundprop(crown_boundprop)
37 |
38 | def test_nonconvex_boundprop(self):
39 | """Test Nonconvex bounds vs FGSM on Wong-Small MNIST CNN."""
40 | # Minimal test, since this already takes 70s.
41 | nonconvex_boundprop = functools.partial(
42 | boundprop_utils.boundprop, boundprop_type='nonconvex',
43 | nonconvex_boundprop_steps=2)
44 | self._test_boundprop(nonconvex_boundprop, num_idxs_to_test=1)
45 |
46 | def test_ibp_boundprop(self):
47 | def boundprop(params, x, epsilon, input_bounds):
48 | assert len(x.shape) == 4 and x.shape[0] == 1, f'shape check {x.shape}'
49 | init_bound = utils.init_bound(x[0], epsilon, input_bounds=input_bounds)
50 | return utils.boundprop(params, init_bound)
51 | self._test_boundprop(boundprop)
52 |
53 | def _test_boundprop(self, boundprop_method, num_idxs_to_test=5):
54 | """Test `boundprop_method` on Wong-Small MNIST CNN."""
55 | with jax_verify.open_file('mnist/x_test_first100.npy', 'rb') as f:
56 | xs = np.load(f)
57 | model_name = 'models/mnist_wongsmall_eps_10_adv.pkl'
58 | with jax_verify.open_file(model_name, 'rb') as f:
59 | params = pickle.load(f)
60 | x = xs[0]
61 | eps = 0.1
62 |
63 | bounds = boundprop_method(params, np.expand_dims(x, axis=0), eps,
64 | input_bounds=(0., 1.))
65 | crown_lbs = utils.flatten([b.lb_pre for b in bounds[1:]])
66 | crown_ubs = utils.flatten([b.ub_pre for b in bounds[1:]])
67 |
68 | max_idx = crown_lbs.shape[0]
69 | np.random.seed(0)
70 | test_idxs = np.random.randint(max_idx, size=num_idxs_to_test)
71 |
72 | @jax.jit
73 | def fwd(x):
74 | _, acts = utils.predict_cnn(params, jnp.expand_dims(x, 0),
75 | include_preactivations=True)
76 | return acts
77 |
78 | get_act = lambda x, idx: utils.flatten(fwd(x), backend=jnp)[idx]
79 |
80 | print('Number of activations:', crown_lbs.shape[0])
81 | print('Bound shape', [b.lb.shape for b in bounds])
82 | print('Activation shape', [a.shape for a in fwd(x)])
83 | assert utils.flatten(fwd(x)).shape == crown_lbs.shape, (
84 | f'bad shape {crown_lbs.shape}, {utils.flatten(fwd(x)).shape}')
85 |
86 | for idx in test_idxs:
87 | nom = get_act(x, idx)
88 | crown_lb = crown_lbs[idx]
89 | crown_ub = crown_ubs[idx]
90 |
91 | adv_loss = lambda x: get_act(x, idx) # pylint: disable=cell-var-from-loop
92 | x_lb = utils.pgd(adv_loss, x, eps, 5, 0.01)
93 | fgsm_lb = get_act(x_lb, idx)
94 |
95 | adv_loss = lambda x: -get_act(x, idx) # pylint: disable=cell-var-from-loop
96 | x_ub = utils.pgd(adv_loss, x, eps, 5,
0.01) 97 | fgsm_ub = get_act(x_ub, idx) 98 | 99 | print(f'Idx {idx}: Boundprop LB {crown_lb}, FGSM LB {fgsm_lb}, ' 100 | f'Nominal {nom}, FGSM UB {fgsm_ub}, Boundprop UB {crown_ub}') 101 | margin = 1e-5 102 | assert crown_lb <= fgsm_lb + margin, f'Bad lower bound. Idx {idx}.' 103 | assert crown_ub >= fgsm_ub - margin, f'Bad upper bound. Idx {idx}.' 104 | 105 | crown_lb_post, fgsm_lb_post = max(crown_lb, 0), max(fgsm_lb, 0) 106 | crown_ub_post, fgsm_ub_post = max(crown_ub, 0), max(fgsm_ub, 0) 107 | assert crown_lb_post <= fgsm_lb_post + margin, f'Idx {idx}.' 108 | assert crown_ub_post >= fgsm_ub_post - margin, f'Idx {idx}.' 109 | 110 | 111 | if __name__ == '__main__': 112 | absltest.main() 113 | -------------------------------------------------------------------------------- /jax_verify/tests/sdp_verify/cvxpy_verify_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for cvxpy_verify.py.""" 17 | 18 | import unittest 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from cvxpy.reductions.solvers.defines import INSTALLED_MI_SOLVERS as MIP_SOLVERS 22 | import jax.numpy as jnp 23 | from jax_verify.extensions.sdp_verify import cvxpy_verify 24 | from jax_verify.extensions.sdp_verify import utils 25 | from jax_verify.tests.sdp_verify import test_utils 26 | 27 | NO_MIP_SOLVERS_MESSAGE = 'No mixed-integer solver is installed.' 28 | 29 | 30 | class CvxpyTest(parameterized.TestCase): 31 | 32 | @unittest.skipUnless(MIP_SOLVERS, NO_MIP_SOLVERS_MESSAGE) 33 | def test_mip_status(self): 34 | """Test toy MIP is solved optimally by cvxpy.""" 35 | for seed in range(10): 36 | verif_instance = test_utils.make_toy_verif_instance(seed) 37 | val, info = cvxpy_verify.solve_mip_mlp_elided(verif_instance) 38 | status = info['problem'].status 39 | assert val is not None 40 | assert status in ('optimal', 'optimal_inaccurate'), f'Status is {status}.' 41 | 42 | def test_sdp_status(self): 43 | """Test toy SDP is solved optimally by cvxpy.""" 44 | for seed in range(10): 45 | verif_instance = test_utils.make_toy_verif_instance(seed) 46 | val, info = cvxpy_verify.solve_sdp_mlp_elided(verif_instance) 47 | status = info['problem'].status 48 | assert val is not None 49 | assert status in ('optimal', 'optimal_inaccurate'), f'Status is {status}.' 
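# Helper for the bound-ordering tests below: runs an FGSM-style attack on the
# elided MLP and returns the adversarial example together with its objective
# value, which lower-bounds the exact (MIP) optimum.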
50 |
51 |
52 | def _fgsm_example_and_bound(params, target_label, label):
53 | model_fn = lambda x: utils.predict_mlp(params, x)
54 | x = 0.5 * jnp.ones(utils.nn_layer_sizes(params)[0])
55 | epsilon = 0.5
56 | x_adv = utils.fgsm_single(model_fn, x, label, target_label, epsilon,
57 | num_steps=30, step_size=0.03)
58 | return x_adv, utils.adv_objective(model_fn, x_adv, label, target_label)
59 |
60 | MARGIN = 1e-6
61 |
62 |
63 | class CrossingBoundsTest(parameterized.TestCase):
64 | """Check FGSM lower bound <= MIP <= IBP/SDP relaxation upper bounds."""
65 |
66 | @unittest.skipUnless(MIP_SOLVERS, NO_MIP_SOLVERS_MESSAGE)
67 | def test_fgsm_vs_mip(self):
68 | num_repeats = 5
69 | target_label, label = 1, 2
70 | for seed in range(num_repeats):
71 | verif_instance = test_utils.make_toy_verif_instance(
72 | seed, target_label=target_label, label=label)
73 | mip_val, _ = cvxpy_verify.solve_mip_mlp_elided(verif_instance)
74 | _, fgsm_val = _fgsm_example_and_bound(
75 | verif_instance.params_full, target_label=target_label, label=label)
76 | assert mip_val > fgsm_val - MARGIN, (
77 | 'MIP exact solution should be greater than FGSM lower bound.')
78 |
79 | @unittest.skipUnless(MIP_SOLVERS, NO_MIP_SOLVERS_MESSAGE)
80 | def test_sdp_vs_mip(self):
81 | num_repeats = 5
82 | loss_margin = 1e-3 # fixed via runs_per_test=300 with random seeds
83 | for seed in range(num_repeats):
84 | verif_instance = test_utils.make_toy_verif_instance(seed)
85 | mip_val, _ = cvxpy_verify.solve_mip_mlp_elided(verif_instance)
86 | sdp_val, _ = cvxpy_verify.solve_sdp_mlp_elided(verif_instance)
87 | assert sdp_val > mip_val - loss_margin, (
88 | 'SDP relaxation should be greater than MIP exact solution. '
89 | f'Vals are MIP: {mip_val} SDP: {sdp_val}')
90 |
91 |
92 | class MatchingBoundsTest(parameterized.TestCase):
93 |
94 | @unittest.skipUnless(MIP_SOLVERS, NO_MIP_SOLVERS_MESSAGE)
95 | def test_fgsm_vs_mip(self):
96 | """Check FGSM and MIP reach same solution/value most of the time."""
97 | # Note this test only works with fixed seeds
98 | num_repeats = 5
99 | expected_successes = 4
100 | num_successes = 0
101 | loss_margin = 0.01
102 | target_label, label = 1, 2
103 | for seed in range(num_repeats):
104 | verif_instance = test_utils.make_toy_verif_instance(
105 | seed, target_label=target_label, label=label)
106 | mip_val, _ = cvxpy_verify.solve_mip_mlp_elided(verif_instance)
107 | _, fgsm_val = _fgsm_example_and_bound(
108 | verif_instance.params_full, target_label=target_label, label=label)
109 | if abs(mip_val - fgsm_val) < loss_margin:
110 | num_successes += 1
111 | assert num_successes >= expected_successes, f'Successes: {num_successes}'
112 |
113 |
114 | class SdpTest(parameterized.TestCase):
115 |
116 | def test_constraints_numpy(self):
117 | num_repeats = 5
118 | margin = 3e-4
119 | for seed in range(num_repeats):
120 | verif_instance = test_utils.make_toy_verif_instance(
121 | seed=seed, label=1, target_label=2)
122 | obj_value, info = cvxpy_verify.solve_sdp_mlp_elided(verif_instance)
123 | obj_np, violations = cvxpy_verify.check_sdp_bounds_numpy(
124 | info['P'].value, verif_instance)
125 | assert abs(obj_np - obj_value) < margin, 'objective does not match'
126 | for k, v in violations.items():
127 | assert v < margin, f'violation of {k} by {v}'
128 |
129 |
130 | if __name__ == '__main__':
131 | absltest.main()
132 | --------------------------------------------------------------------------------
/jax_verify/tests/sdp_verify/problem_from_graph_test.py:
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 DeepMind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for problem_from_graph.py.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import jax 21 | from jax import numpy as jnp 22 | import jax_verify 23 | from jax_verify.extensions.sdp_verify import problem 24 | from jax_verify.extensions.sdp_verify import problem_from_graph 25 | from jax_verify.extensions.sdp_verify import sdp_verify 26 | from jax_verify.extensions.sdp_verify import utils 27 | from jax_verify.src import ibp 28 | from jax_verify.tests.sdp_verify import test_utils 29 | 30 | 31 | class SdpProblemTest(parameterized.TestCase): 32 | 33 | def assertArrayAlmostEqual(self, lhs, rhs): 34 | self.assertEqual(lhs is None, rhs is None) 35 | if lhs is not None: 36 | diff = jnp.abs(lhs - rhs).max() 37 | self.assertAlmostEqual(diff, 0., places=5) 38 | 39 | def test_sdp_problem_equivalent_to_sdp_verify(self): 40 | # Set up a verification problem for test purposes. 41 | verif_instance = test_utils.make_toy_verif_instance(label=2, target_label=1) 42 | 43 | # Set up a spec function that replicates the test problem. 44 | inputs = jnp.zeros((1, 5)) 45 | input_bounds = jax_verify.IntervalBound( 46 | jnp.zeros_like(inputs), jnp.ones_like(inputs)) 47 | boundprop_transform = ibp.bound_transform 48 | def spec_fn(x): 49 | x = utils.predict_mlp(verif_instance.params, x) 50 | x = jax.nn.relu(x) 51 | return jnp.sum( 52 | jnp.reshape(x, (-1,)) * verif_instance.obj) + verif_instance.const 53 | 54 | # Build an SDP verification instance using the code under test. 55 | sdp_relu_problem = problem_from_graph.SdpReluProblem( 56 | boundprop_transform, spec_fn, input_bounds) 57 | sdp_problem_vi = sdp_relu_problem.build_sdp_verification_instance() 58 | 59 | # Build an SDP verification instance using existing `sdp_verify` code. 60 | sdp_verify_vi = problem.make_sdp_verif_instance(verif_instance) 61 | 62 | self._assert_verif_instances_equal(sdp_problem_vi, sdp_verify_vi) 63 | 64 | def _assert_verif_instances_equal(self, sdp_problem_vi, sdp_verify_vi): 65 | # Assert that bounds are the same. 66 | self.assertEqual(len(sdp_problem_vi.bounds), len(sdp_verify_vi.bounds)) 67 | for sdp_problem_bound, sdp_verify_bound in zip( 68 | sdp_problem_vi.bounds, sdp_verify_vi.bounds): 69 | self.assertArrayAlmostEqual(sdp_problem_bound.lb, sdp_verify_bound.lb) 70 | self.assertArrayAlmostEqual(sdp_problem_bound.ub, sdp_verify_bound.ub) 71 | 72 | # Don't compare dual shapes/types in detail, because the different 73 | # implementations can and do represent them in different 74 | # (but equivalent) ways. 75 | # They should have the same length, though. 
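# (Their dual objectives are compared below on shared random dual variables.)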
76 | self.assertEqual(len(sdp_problem_vi.dual_shapes),
77 | len(sdp_verify_vi.dual_shapes))
78 | self.assertEqual(len(sdp_problem_vi.dual_types),
79 | len(sdp_verify_vi.dual_types))
80 |
81 | # Evaluate each problem's dual objective on the same random dual variables.
82 | def random_dual_fun(verif_instance):
83 | key = jax.random.PRNGKey(103)
84 | random_like = lambda x: jax.random.uniform(key, x.shape, x.dtype)
85 | duals = sdp_verify.init_duals(verif_instance, None)
86 | duals = jax.tree_map(random_like, duals)
87 | return sdp_verify.dual_fun(verif_instance, duals)
88 |
89 | self.assertAlmostEqual(
90 | random_dual_fun(sdp_problem_vi), random_dual_fun(sdp_verify_vi),
91 | places=5)
92 |
93 |
94 | if __name__ == '__main__':
95 | absltest.main()
96 | --------------------------------------------------------------------------------
/jax_verify/tests/sdp_verify/test_utilfuns.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for parameter extraction in utils.py."""
17 |
18 | import functools
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | import jax.numpy as jnp
22 | import jax.random as random
23 | from jax_verify.extensions.sdp_verify import utils
24 | from jax_verify.tests.sdp_verify import test_utils
25 |
26 |
27 | class ParamExtractionTest(parameterized.TestCase):
28 | """Test the functions extracting network parameters from functions."""
29 |
30 | def check_fun_extract(self, fun_to_extract, example_inputs):
31 | extracted_params = utils.get_layer_params(fun_to_extract, example_inputs)
32 |
33 | eval_original = fun_to_extract(example_inputs)
34 | eval_extracted = utils.predict_cnn(extracted_params, example_inputs)
35 |
36 | self.assertAlmostEqual(jnp.abs(eval_original - eval_extracted).max(), 0.0,
37 | places=6)
38 |
39 | def test_cnn_extract(self):
40 | """Test that weights from a CNN can be extracted."""
41 | key = random.PRNGKey(0)
42 | k1, k2 = random.split(key)
43 |
44 | input_sizes = (1, 2, 2, 1)
45 | layer_sizes = [input_sizes, {
46 | 'n_h': 2,
47 | 'n_w': 2,
48 | 'n_cout': 2,
49 | 'padding': 'VALID',
50 | 'stride': 1,
51 | 'n_cin': 1
52 | }, 3]
53 | cnn_params = test_utils.make_cnn_params(layer_sizes, k1)
54 |
55 | fun_to_extract = functools.partial(utils.predict_cnn, cnn_params)
56 | example_inputs = random.normal(k2, input_sizes)
57 |
58 | self.check_fun_extract(fun_to_extract, example_inputs)
59 |
60 | def test_cnn_withpreproc(self):
61 | """Test extraction of weights from a CNN with input preprocessing."""
62 | key = random.PRNGKey(0)
63 | k1, k2, k3, k4 = random.split(key, num=4)
64 |
65 | input_sizes = (1, 2, 2, 3)
66 | layer_sizes = [input_sizes, {
67 | 'n_h': 2,
68 | 'n_w': 2,
69 | 'n_cout': 2,
70 | 'padding': 'VALID',
71 | 'stride': 1,
72 | 'n_cin': 3
73 | }, 3]
74 | cnn_params = test_utils.make_cnn_params(layer_sizes, k1)
75 | example_inputs
= random.normal(k2, input_sizes)
76 | input_mean = random.normal(k3, (3,))
77 | input_std = random.normal(k4, (3,))
78 |
79 | def fun_to_extract(inputs):
80 | inp = (inputs - input_mean) / input_std
81 | return utils.predict_cnn(cnn_params, inp)
82 |
83 | self.check_fun_extract(fun_to_extract, example_inputs)
84 |
85 | def test_mlp_extract(self):
86 | """Test that weights from an MLP can be extracted."""
87 | key = random.PRNGKey(0)
88 | k1, k2 = random.split(key)
89 |
90 | input_sizes = (5,)
91 | layer_sizes = (5, 8, 5)
92 | mlp_params = test_utils.make_mlp_params(layer_sizes, k1)
93 |
94 | fun_to_extract = functools.partial(utils.predict_mlp, mlp_params)
95 | example_inputs = random.normal(k2, input_sizes)
96 | self.check_fun_extract(fun_to_extract, example_inputs)
97 |
98 | def test_mlp_withpreproc(self):
99 | """Test extraction of weights from an MLP with input preprocessing."""
100 | key = random.PRNGKey(0)
101 | k1, k2, k3, k4 = random.split(key, num=4)
102 |
103 | input_sizes = (5,)
104 | layer_sizes = (5, 8, 5)
105 | mlp_params = test_utils.make_mlp_params(layer_sizes, k1)
106 | example_inputs = random.normal(k2, input_sizes)
107 | input_mean = random.normal(k3, input_sizes)
108 | input_std = random.normal(k4, input_sizes)
109 |
110 | def fun_to_extract(inputs):
111 | inp = (inputs - input_mean) / input_std
112 | return utils.predict_mlp(mlp_params, inp)
113 |
114 | self.check_fun_extract(fun_to_extract, example_inputs)
115 |
116 |
117 | if __name__ == '__main__':
118 | absltest.main()
119 | --------------------------------------------------------------------------------
/jax_verify/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utility functions for writing jax_verify tests."""
17 | import functools
18 | from typing import Tuple
19 |
20 | import jax
21 | import jax.numpy as jnp
22 | from jax_verify.extensions.sdp_verify import utils
23 | from jax_verify.src import opt_utils
24 | from jax_verify.tests.sdp_verify import test_utils as sdp_test_utils
25 |
26 |
27 | def sample_bounds(key: jnp.ndarray,
28 | shape: Tuple[int, ...],
29 | minval: float = -2.,
30 | maxval: float = 2.) -> Tuple[jnp.ndarray, jnp.ndarray]:
31 | """Sample some bounds of the required shape.
32 |
33 | Args:
34 | key: Random number generator.
35 | shape: Shape of the bounds to generate.
36 | minval: Optional, smallest value that the bounds could take.
37 | maxval: Optional, largest value that the bounds could take.
38 | Returns:
39 | lb, ub: Lower and upper bound tensors.
40 | """
41 | key_0, key_1 = jax.random.split(key)
42 | bound_1 = jax.random.uniform(key_0, shape, minval=minval, maxval=maxval)
43 | bound_2 = jax.random.uniform(key_1, shape, minval=minval, maxval=maxval)
44 | lb = jnp.minimum(bound_1, bound_2)
45 | ub = jnp.maximum(bound_1, bound_2)
46 | return lb, ub
47 |
48 |
49 | def sample_bounded_points(key: jnp.ndarray,
50 | bounds: Tuple[jnp.ndarray, jnp.ndarray],
51 | nb_points: int,
52 | axis: int = 0) -> jnp.ndarray:
53 | """Sample uniformly some points respecting the bounds.
54 |
55 | Args:
56 | key: Random number generator.
57 | bounds: Tuple containing [lower bound, upper bound].
58 | nb_points: How many points to sample.
59 | axis: Which dimension to add to correspond to the number of points.
60 | Returns:
61 | points: Points contained between the given bounds.
62 | """
63 | lb, ub = bounds
64 | act_shape = lb.shape
65 | to_sample_shape = act_shape[:axis] + (nb_points,) + act_shape[axis:]
66 | unif_samples = jax.random.uniform(key, to_sample_shape)
67 |
68 | broad_lb = jnp.expand_dims(lb, axis)
69 | broad_ub = jnp.expand_dims(ub, axis)
70 |
71 | bound_range = broad_ub - broad_lb
72 | return broad_lb + unif_samples * bound_range
73 |
74 |
75 | def sample_bounded_simplex_points(key: jnp.ndarray,
76 | bounds: Tuple[jnp.ndarray, jnp.ndarray],
77 | simplex_sum: float,
78 | nb_points: int) -> jnp.ndarray:
79 | """Sample some points respecting the bounds as well as a simplex constraint.
80 |
81 | Args:
82 | key: Random number generator.
83 | bounds: Tuple containing [lower bound, upper bound].
84 | simplex_sum: Value that each datapoint should sum to.
85 | nb_points: How many points to sample.
86 | Returns:
87 | Points contained between the given bounds.
88 |
89 | """
90 | lb, ub = bounds
91 | points = sample_bounded_points(key, bounds, nb_points)
92 | project_fun = functools.partial(opt_utils.project_onto_interval_simplex,
93 | lb, ub, simplex_sum)
94 | batch_project_fun = jax.vmap(project_fun)
95 | return batch_project_fun(points)
96 |
97 |
98 | def set_up_toy_problem(rng_key, batch_size, architecture):
99 | key_1, key_2 = jax.random.split(rng_key)
100 | params = sdp_test_utils.make_mlp_params(architecture, key_2)
101 |
102 | inputs = jax.random.uniform(key_1, (batch_size, architecture[0]))
103 | eps = 0.1
104 | lb = jnp.maximum(jnp.minimum(inputs - eps, 1.), 0.)
105 | ub = jnp.maximum(jnp.minimum(inputs + eps, 1.), 0.)
106 | fun = functools.partial(utils.predict_cnn, params)
107 | return fun, (lb, ub)
108 | --------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | cvxpy
3 | dm-tree
4 | jax>=0.4.6
5 | jaxlib>=0.4.6
6 | numpy
7 | optax
8 | dm-haiku
9 | einshape @ git+https://github.com/deepmind/einshape.git
10 | ml_collections
11 | --------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 DeepMind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Setup for pip package."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import unittest
23 | from setuptools import find_namespace_packages
24 | from setuptools import setup
25 |
26 |
27 | def _parse_requirements(requirements_txt_path):
28 | with open(requirements_txt_path) as fp:
29 | return fp.read().splitlines()
30 |
31 |
32 | def test_suite():
33 | test_loader = unittest.TestLoader()
34 | all_tests = test_loader.discover('jax_verify/tests',
35 | pattern='*_test.py')
36 | return all_tests
37 |
38 | setup(
39 | name='jax_verify',
40 | version='1.0',
41 | description='A library for neural network verification.',
42 | url='https://github.com/deepmind/jax_verify',
43 | author='DeepMind',
44 | author_email='jax_verify@google.com',
45 | # Contained modules and scripts.
46 | packages=find_namespace_packages(exclude=['*_test.py']),
47 | install_requires=_parse_requirements('requirements.txt'),
48 | python_requires='>=3.6',
49 | platforms=['any'],
50 | license='Apache 2.0',
51 | test_suite='setup.test_suite',
52 | include_package_data=True,
53 | zip_safe=False,
54 | )
55 | --------------------------------------------------------------------------------