├── tests ├── __init__.py ├── conftest.py ├── test_linops.py ├── test_problem_data.py ├── test_qcp_cpu_analytical.py ├── helpers.py ├── test_qcp_gpu_analytical.py └── test_cone_projectors.py ├── experiments ├── __init__.py ├── cpu_baseline.py ├── cvx_problem_generator.py ├── cpu_experiment.py ├── heterogeneous_experiment2.py ├── direct_solve_experiment.py ├── heterogeneous_experiment.py └── gpu_experiment.py ├── .python-version ├── diffqcp ├── cones │ ├── __init__.py │ ├── abstract_projector.py │ └── pow.py ├── __init__.py ├── _helpers.py ├── linops.py ├── qcp_derivs.py ├── problem_data.py └── qcp.py ├── .gitignore ├── .github └── workflows │ ├── test.yml │ └── python-publish.yml ├── pyproject.toml ├── README.md └── LICENSE.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /diffqcp/cones/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /diffqcp/__init__.py: -------------------------------------------------------------------------------- 1 | from .qcp import ( 2 | DeviceQCP as DeviceQCP, 3 | HostQCP as HostQCP 4 | ) 5 | 6 | from .problem_data import ( 7 | QCPStructureCPU as QCPStructureCPU, 8 | QCPStructureGPU as QCPStructureGPU, 9 | QCPStructureLayers as QCPStructureLayers 10 | ) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # mix of .gitignore from `equinox` and the original one (Numpy/SciPy and Torch branches) in `diffqcp` 2 | # and default GitHub ignore. 
3 | __pycache__/ 4 | *.py[cod] 5 | *.egg-info/ 6 | *.egg 7 | dist/ 8 | build/ 9 | .env 10 | .venv 11 | *.lock 12 | 13 | # MyPy & pytest 14 | .mypy_cache/ 15 | .pytest_cache/ 16 | 17 | # IDE/editor files 18 | .vscode/ 19 | *.swp 20 | .DS_Store -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Shamelessly taken from https://github.com/patrick-kidger/lineax/blob/main/tests/conftest.py 2 | """ 3 | 4 | import equinox.internal as eqxi 5 | import jax 6 | import pytest 7 | 8 | 9 | jax.config.update("jax_enable_x64", True) 10 | jax.config.update("jax_numpy_dtype_promotion", "strict") 11 | jax.config.update("jax_numpy_rank_promotion", "raise") 12 | 13 | 14 | @pytest.fixture 15 | def getkey(): 16 | return eqxi.GetKey() -------------------------------------------------------------------------------- /diffqcp/cones/abstract_projector.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import equinox as eqx 4 | from lineax import AbstractLinearOperator 5 | from jaxtyping import Float, Array 6 | 7 | class AbstractConeProjector(eqx.Module): 8 | 9 | @abstractmethod 10 | def proj_dproj(self, x: Float[Array, " _n"]) -> tuple[Float[Array, " _n"], AbstractLinearOperator]: 11 | pass 12 | 13 | def __call__(self, x: Float[Array, " _n"]): 14 | return self.proj_dproj(x) -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Example 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | - 'release/**' 9 | tags: 10 | - '*' 11 | 12 | jobs: 13 | uv-example: 14 | name: python 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | 20 | - name: Install uv 21 | uses: astral-sh/setup-uv@v6 22 | 23 | - name: Set up Python 24 | run: uv python install 25 | 26 | - name: Run tests 27 | # For example, using `pytest` 28 | run: uv run pytest tests 29 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | 2 | 3 | [project] 4 | name = "diffqcp" 5 | version = "0.4.4" 6 | description = "Engine to compute Jacobian-vector and vector-Jacobian products for (convex) quadratic cone programs." 
7 | readme = "README.md" 8 | requires-python = ">=3.11" 9 | dependencies = [ 10 | "equinox>=0.12.2", 11 | "jax>=0.7.2", 12 | "jaxtyping>=0.3.2", 13 | "lineax", 14 | "numpy>=2.3.1", 15 | "scipy>=1.15.3", 16 | ] 17 | 18 | [project.optional-dependencies] 19 | gpu = [ 20 | "cupy-cuda12x>=13.6.0", 21 | "jax[cuda12]>=0.7.2", 22 | "nvmath-python[cu12]>=0.7.0", 23 | ] 24 | 25 | [tool.setuptools.packages.find] 26 | where = ["."] 27 | include = ["diffqcp*"] 28 | 29 | [dependency-groups] 30 | dev = [ 31 | "cvxpy>=1.6.6", 32 | "diffcp>=1.1.4", 33 | "diffqcp[gpu]", 34 | "juliacall>=0.9.26", 35 | "matplotlib>=3.10.5", 36 | "patdb>=0.1.0", 37 | "pytest>=8.4.1", 38 | "ruff>=0.12.0", 39 | ] 40 | 41 | [tool.ruff.lint] 42 | ignore = ["F722"] 43 | 44 | [tool.uv.sources] 45 | lineax = { git = "https://github.com/patrick-kidger/lineax.git" } 46 | diffqcp = { workspace = true } 47 | -------------------------------------------------------------------------------- /diffqcp/_helpers.py: -------------------------------------------------------------------------------- 1 | """Helper/utility functions used throughout `diffqcp`.""" 2 | from typing import TYPE_CHECKING 3 | 4 | import numpy as np 5 | from jax.numpy import argsort, stack 6 | from jax.experimental.sparse import BCOO, BCSR 7 | import equinox as eqx 8 | from jaxtyping import Float, Integer, Array 9 | 10 | def _to_int_list(v: np.ndarray) -> list[int]: 11 | """ 12 | Utility function to convert the entries of `v` to plain Python ints so that `eqx.filter_{...}` transforms treat them as static (non-array) leaves. 13 | 14 | Parameters 15 | ---------- 16 | v : np.ndarray 17 | Should only contain integers. 18 | """ 19 | return [int(val) for val in v] 20 | 21 | class _TransposeCSRInfo(eqx.Module): 22 | indices: Integer[Array, "..."] 23 | indptr: Integer[Array, "..."] 24 | sorting_perm: Integer[Array, "..."] 25 | 26 | 27 | def _coo_to_csr_transpose_map(mat: Float[BCOO, "_m _n"]) -> _TransposeCSRInfo: 28 | """ 29 | we need `sorting_perm` (the permutation that puts the transposed entries into CSR/row-major order); otherwise we could just `.T` the BCOO array. 30 | """ 31 | num_rows = mat.shape[0] 32 | rowsT, colsT = mat.indices[:, 1], mat.indices[:, 0] 33 | transposed_val_ordering_unsorted = rowsT * num_rows + colsT # = cols * num_rows + rows 34 | sorting_perm = argsort(transposed_val_ordering_unsorted) 35 | transposed_indices = stack([rowsT[sorting_perm], colsT[sorting_perm]], axis=1) 36 | mat_transposed = BCOO((mat.data[sorting_perm], transposed_indices), 37 | shape=(mat.shape[1], mat.shape[0])) 38 | mat_transposed_csr = BCSR.from_bcoo(mat_transposed) 39 | return _TransposeCSRInfo(indices=mat_transposed_csr.indices, 40 | indptr=mat_transposed_csr.indptr, 41 | sorting_perm=sorting_perm) 42 | 43 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package to PyPI when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation.
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | release-build: 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v4 24 | 25 | - uses: actions/setup-python@v5 26 | with: 27 | python-version: "3.x" 28 | 29 | - name: Build release distributions 30 | run: | 31 | # NOTE: put your own distribution build steps here. 32 | python -m pip install build 33 | python -m build 34 | 35 | - name: Upload distributions 36 | uses: actions/upload-artifact@v4 37 | with: 38 | name: release-dists 39 | path: dist/ 40 | 41 | pypi-publish: 42 | runs-on: ubuntu-latest 43 | needs: 44 | - release-build 45 | permissions: 46 | # IMPORTANT: this permission is mandatory for trusted publishing 47 | id-token: write 48 | 49 | # Dedicated environments with protections for publishing are strongly recommended. 50 | # For more information, see: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment#deployment-protection-rules 51 | environment: 52 | name: pypi 53 | # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status: 54 | url: https://pypi.org/p/diffqcp 55 | # 56 | # ALTERNATIVE: if your GitHub Release name is the PyPI project version string 57 | # ALTERNATIVE: exactly, uncomment the following line instead: 58 | # url: https://pypi.org/project/YOURPROJECT/${{ github.event.release.name }} 59 | 60 | steps: 61 | - name: Retrieve release distributions 62 | uses: actions/download-artifact@v4 63 | with: 64 | name: release-dists 65 | path: dist/ 66 | 67 | - name: Publish release distributions to PyPI 68 | uses: pypa/gh-action-pypi-publish@release/v1 69 | with: 70 | packages-dir: dist/ 71 | -------------------------------------------------------------------------------- /tests/test_linops.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import jax.random as jr 4 | import lineax as lx 5 | 6 | from diffqcp.linops import _BlockLinearOperator 7 | 8 | from .helpers import tree_allclose 9 | 10 | 11 | def test_block_operator(getkey): 12 | # test `mv` 13 | # test `.transpose.mv` 14 | # test `in_structure` and `out_structure 15 | # test under vmap` 16 | n = 10 17 | m = 5 18 | 19 | x = jr.normal(getkey(), n) 20 | A = jr.normal(getkey(), (m, n)) 21 | op1 = lx.DiagonalLinearOperator(x) 22 | op2 = lx.MatrixLinearOperator(A) 23 | _fn = lambda y: A.T @ y 24 | in_struc_fn = lambda: jnp.arange(m, dtype=x.dtype) 25 | op3 = lx.FunctionLinearOperator(_fn, input_structure=jax.eval_shape(in_struc_fn)) 26 | ops = [op1, op2, op3] 27 | block_op = _BlockLinearOperator(ops) 28 | 29 | in_dim = out_dim = 2 * n + m 30 | assert block_op.in_size() == in_dim 31 | assert block_op.out_size() == out_dim 32 | assert block_op.in_structure().shape == (in_dim,) 33 | assert block_op.out_structure().shape == (out_dim,) 34 | 35 | v = jr.normal(getkey(), in_dim) 36 | out1 = op1.mv(v[0:n]) 37 | out2 = op2.mv(v[n:2*n]) 38 | out3 = op3.mv(v[2*n:2*n+m]) 39 | out_correct = jnp.concatenate([out1, out2, out3]) 40 | assert tree_allclose(out_correct, block_op.mv(v)) 41 | 42 | # --- test vmap --- 43 | 44 | v = jr.normal(getkey(), (5, in_dim)) 45 | out1 = jax.vmap(op1.mv)(v[:, 0:n]) 46 | out2 = jax.vmap(op2.mv)(v[:, n:2*n]) 47 | out3 = jax.vmap(op3.mv)(v[:, 2*n:2*n+m]) 48 | out_correct = jnp.concatenate([out1, out2, out3], axis=1) 49 | assert tree_allclose(out_correct, 
jax.vmap(block_op.mv)(v)) 50 | 51 | # === test transpose === 52 | 53 | u = jr.normal(getkey(), out_dim) 54 | out1 = op1.transpose().mv(u[0:n]) 55 | out2 = op2.transpose().mv(u[n:n+m]) 56 | out3 = op3.transpose().mv(u[n+m:2*n+m]) 57 | out_correct = jnp.concatenate([out1, out2, out3]) 58 | assert tree_allclose(out_correct, block_op.transpose().mv(u)) 59 | 60 | # --- test vmap --- 61 | 62 | u = jr.normal(getkey(), (5, out_dim)) 63 | out1 = jax.vmap(op1.transpose().mv)(u[:, 0:n]) 64 | out2 = jax.vmap(op2.transpose().mv)(u[:, n:n+m]) 65 | out3 = jax.vmap(op3.transpose().mv)(u[:, n+m:2*n+m]) 66 | out_correct = jnp.concatenate([out1, out2, out3], axis=1) 67 | assert tree_allclose(out_correct, jax.vmap(block_op.transpose().mv)(u)) 68 | 69 | # TODO(quill): will need to create a wrapper function so I can batch ops on top of each other 70 | # NOTE(quill): a `vmap` test in `lineax` does this. 71 | 72 | # TODO(quill): add test to ensure BlockOperator is symmetric if its blocks are symmetric. 73 | # (skipping for now since this is irrelevant to `diffqcp`.) -------------------------------------------------------------------------------- /tests/test_problem_data.py: -------------------------------------------------------------------------------- 1 | """Mainly testing adjoint 2 | """ 3 | 4 | from diffqcp import QCPStructureCPU, QCPStructureGPU, QCPStructureLayers 5 | import numpy as np 6 | import jax.numpy as jnp 7 | from jax.experimental.sparse import BCOO, BCSR 8 | 9 | def _make_upper_tri_bcoo(n, rng): 10 | M = rng.standard_normal(n) 11 | M = np.triu(M) # keep upper triangular (including diag) 12 | Md = jnp.array(M) 13 | return BCOO.fromdense(Md), Md 14 | 15 | def _make_dense_bcsr(m, n, rng): 16 | M = rng.standard_normal((m, n)) 17 | Md = jnp.array(M) 18 | return BCSR.fromdense(Md), Md 19 | 20 | def test_qcpstructurecpu_obj_matrix_and_mv(): 21 | rng = np.random.default_rng(0) 22 | n = 8 23 | m = 5 24 | 25 | P_bcoo, P_upper = _make_upper_tri_bcoo(n, rng) 26 | A_bcoo = BCOO.fromdense(jnp.array(rng.standard_normal((m, n)))) 27 | 28 | s = QCPStructureCPU(P_bcoo, A_bcoo, {}) 29 | 30 | # form ObjMatrixCPU and check mv equals full symmetric multiplication 31 | obj = s.form_obj(P_bcoo) 32 | v = jnp.array(rng.standard_normal(n)) 33 | 34 | # full symmetric matrix = P_upper + P_upper.T - diag(P_upper) 35 | full_sym = P_upper + P_upper.T - jnp.diag(jnp.diag(P_upper)) 36 | res_expected = full_sym @ v 37 | res_actual = obj.mv(v) 38 | 39 | assert jnp.allclose(res_actual, res_expected, atol=1e-12, rtol=1e-12) 40 | 41 | # metadata checks 42 | nz_rows = np.array(s.P_nonzero_rows) 43 | nz_cols = np.array(s.P_nonzero_cols) 44 | # positions reported should match nonzero positions of the upper triangular matrix 45 | mask = (np.triu(P_upper) != 0) 46 | rows, cols = np.where(np.asarray(mask)) 47 | assert set(zip(rows.tolist(), cols.tolist())) == set(zip(nz_rows.tolist(), nz_cols.tolist())) 48 | 49 | def test_qcpstructuregpu_A_transpose_inner_product(): 50 | rng = np.random.default_rng(1) 51 | n = 10 52 | m = 7 53 | 54 | # P can be simple full matrix for obj init 55 | P_dense = jnp.array(rng.standard_normal((n, n))) 56 | P_bcsr = BCSR.fromdense(P_dense) 57 | 58 | A_bcsr, A_dense = _make_dense_bcsr(m, n, rng) 59 | 60 | s = QCPStructureGPU(P_bcsr, A_bcsr, {}) 61 | 62 | # Make random vectors x (n,) and y (m,) 63 | x = jnp.array(rng.standard_normal((n,))) 64 | y = jnp.array(rng.standard_normal((m,))) 65 | 66 | # compute 67 | Ax = A_bcsr @ x 68 | left = y @ Ax 69 | 70 | # form transpose via structure and compute 71 | A_T = 
s.form_A_transpose(A_bcsr) 72 | ATy = A_T @ y 73 | right = x @ ATy 74 | 75 | assert jnp.allclose(left, right, atol=1e-12, rtol=1e-12) 76 | 77 | # additionally, check that form_A_transpose produces a BCSR whose dense equals A.T 78 | assert jnp.allclose(jnp.array(A_T.todense()), jnp.array(A_dense.T), atol=1e-12) 79 | 80 | -------------------------------------------------------------------------------- /diffqcp/linops.py: -------------------------------------------------------------------------------- 1 | """General (i.e., not cone-specific) linear operators that are not implemented in `lineax`. 2 | 3 | Note that these operators were purposefully made "private" since they are solely implemented 4 | to support functionality required by `diffqcp`. They **should not** be accessed as if they 5 | were true atoms implemented in `lineax`. 6 | """ 7 | import numpy as np 8 | from jax import ShapeDtypeStruct 9 | import jax.numpy as jnp 10 | import lineax as lx 11 | import equinox as eqx 12 | 13 | from diffqcp._helpers import _to_int_list 14 | 15 | 16 | class _BlockLinearOperator(lx.AbstractLinearOperator): 17 | """Represents a block matrix (without explicitly forming zeros). 18 | 19 | TODO(quill): Support operating on PyTrees (clearly the way I handle `input_structure` 20 | and `output_structure` isn't compatible with PyTrees.) 21 | """ 22 | 23 | blocks: list[lx.AbstractLinearOperator] 24 | num_blocks: int 25 | # _in_sizes: list[int] 26 | # _out_sizes: list[int] 27 | # NOTE(quill): either use the non-static defined `split_indices` along with `eqx.filter_{...}`, 28 | # or use regular JAX function transforms with `split_indices` declared as static. 29 | # I'm personally a fan of the explicit declaration, but it seems that this is not the 30 | # suggested approach: https://github.com/patrick-kidger/equinox/issues/154. 31 | # (It is worth noting that `lineax` itself does use explicit static declarations, such as 32 | # in `PyTreeLinearOperator`.) 33 | # split_indices: list[int] 34 | split_indices: list[int] = eqx.field(static=True) 35 | # TODO(quill): make this a JAX array so goes onto device. 36 | 37 | def __init__( 38 | self, 39 | blocks: list[lx.AbstractLinearOperator] 40 | ): 41 | """ 42 | Parameters 43 | ---------- 44 | `blocks`: list[lx.AbstractLinearOperator] 45 | """ 46 | self.blocks = blocks 47 | self.num_blocks = len(blocks) 48 | 49 | in_sizes = [block.in_size() for block in self.blocks] 50 | # NOTE(quill): `int(idx)` is needed else `eqx.filter_{...}` doesn't filter out these indices 51 | # (Since I've declared `split_indices` as static this isn't necessary, but there's no true cost 52 | # to keeping.) 53 | self.split_indices = _to_int_list(np.cumsum(in_sizes[:-1])) 54 | 55 | def mv(self, x): 56 | chunks = jnp.split(x, self.split_indices, axis=-1) 57 | results = [op.mv(xi) for op, xi in zip(self.blocks, chunks)] 58 | return jnp.concatenate(results, axis=-1) 59 | 60 | def as_matrix(self): 61 | """uses output dtype 62 | 63 | not meant to be efficient. 
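A dense form, if ever needed, could plausibly be assembled as `jax.scipy.linalg.block_diag(*[block.as_matrix() for block in self.blocks])` (a hypothetical sketch mirroring the commented-out code below, not something this module currently provides); for now this method simply raises.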
64 | """ 65 | # dtype = self.blocks[0].out_structure().dtype 66 | # zeros_block = jnp.zeros((self._out_size, self._in_size), dtype=dtype) 67 | # n, m = 0, 0 68 | # for i in range(self.num_blocks): 69 | # ni, mi = self._in_sizes[i], self._out_sizes[i] 70 | # zeros_block.at[m:m+mi, n:n+ni].set(self.blocks[i].as_matrix()) 71 | # n += ni 72 | # m += mi 73 | raise NotImplementedError("`_BlockLinearOperator`'s `as_matrix` is not implemented.") 74 | 75 | def transpose(self): 76 | return _BlockLinearOperator([block.T for block in self.blocks]) 77 | 78 | def in_structure(self): 79 | if len(self.blocks[0].in_structure().shape) == 2: 80 | num_batches = self.blocks[0].in_structure().shape[0] 81 | idx = 1 82 | else: 83 | num_batches = 0 84 | idx = 0 85 | in_size = 0 86 | for block in self.blocks: 87 | in_size += block.in_structure().shape[idx] 88 | dtype = self.blocks[0].in_structure().dtype 89 | in_shape = (num_batches, in_size) if num_batches > 0 else (in_size,) 90 | return ShapeDtypeStruct(shape=in_shape, dtype=dtype) 91 | 92 | def out_structure(self): 93 | if len(self.blocks[0].out_structure().shape) == 2: 94 | num_batches = self.blocks[0].out_structure().shape[0] 95 | idx = 1 96 | else: 97 | num_batches = 0 98 | idx = 0 99 | out_size = 0 100 | for block in self.blocks: 101 | out_size += block.out_structure().shape[idx] 102 | dtype = self.blocks[0].out_structure().dtype 103 | in_shape = (num_batches, out_size) if num_batches > 0 else (out_size,) 104 | return ShapeDtypeStruct(shape=in_shape, dtype=dtype) 105 | 106 | @lx.is_symmetric.register(_BlockLinearOperator) 107 | def _(op): 108 | return all(lx.is_symmetric(block) for block in op.blocks) 109 | 110 | @lx.conj.register(_BlockLinearOperator) 111 | def _(op): 112 | return op -------------------------------------------------------------------------------- /tests/test_qcp_cpu_analytical.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import jax 5 | jax.config.update("jax_platform_name", "cpu") 6 | import jax.numpy as jnp 7 | import scipy.linalg as la 8 | import jax.random as jr 9 | import cvxpy as cvx 10 | import equinox as eqx 11 | 12 | from diffqcp import HostQCP, QCPStructureCPU 13 | from .helpers import (quad_data_and_soln_from_qcp_coo as quad_data_and_soln_from_qcp, 14 | scoo_to_bcoo, QCPProbData, get_zeros_like_coo) 15 | 16 | # TODO(quill): configure so don't run GPU tests when no GPU present 17 | # => does require utilizing BCOO vs. BCSR matrices, so probably 18 | # have to create different tests. 19 | 20 | def test_least_squares_cpu(getkey): 21 | """ 22 | The least squares (approximation) problem 23 | 24 | minimize ||Ax - b||^2, 25 | 26 | <=> 27 | 28 | minimize ||r||^2 29 | subject to r = Ax - b, 30 | 31 | where A is a (m x n)-matrix with rank A = n, has 32 | the analytical solution 33 | 34 | x^star = (A^T A)^-1 A^T b. 35 | 36 | Considering x^star as a function of b, we know 37 | 38 | Dx^star(b) = (A^T A)^-1 A^T. 39 | 40 | This test checks the accuracy of `diffqcp`'s derivative computations by 41 | comparing DS(Data)dData to Dx^star(b)db. 42 | 43 | **Notes:** 44 | - `dData == (0, 0, 0, db)`, and other canonicalization considerations must be made 45 | (hence the `data_and_soln_from_cvxpy_problem` function call and associated data declaration.) 
46 | """ 47 | 48 | # TODO(quill): update the testing to follow best practices 49 | 50 | np.random.seed(0) 51 | 52 | for _ in range(10): 53 | np.random.seed(0) 54 | n = np.random.randint(low=10, high=15) 55 | m = n + np.random.randint(low=5, high=15) 56 | # n = np.random.randint(low=1_000, high=1_500) 57 | # m = n + np.random.randint(low=500, high=1_000) 58 | 59 | A = np.random.randn(m, n) 60 | b = np.random.randn(m) 61 | 62 | x = cvx.Variable(n) 63 | r = cvx.Variable(m) 64 | f0 = cvx.sum_squares(r) 65 | problem = cvx.Problem(cvx.Minimize(f0), [r == A@x - b]) 66 | 67 | data = QCPProbData(problem) 68 | 69 | P = scoo_to_bcoo(data.Pcoo) 70 | Pupper = scoo_to_bcoo(data.Pupper_coo) 71 | A_orig = A 72 | A = scoo_to_bcoo(data.Acoo) 73 | q = jnp.array(data.q) 74 | b_orig = b 75 | b = jnp.array(data.b) 76 | x = jnp.array(data.x) 77 | y = jnp.array(data.y) 78 | s = jnp.array(data.s) 79 | 80 | qcp_struc = QCPStructureCPU(Pupper, A, data.scs_cones) 81 | qcp = HostQCP(P, A, q, b, x, y, s, qcp_struc) 82 | 83 | print("N = ", qcp_struc.N) 84 | print("n = ", qcp_struc.n) 85 | print("m = ", qcp_struc.m) 86 | 87 | dP = get_zeros_like_coo(data.Pupper_coo) 88 | dP = scoo_to_bcoo(dP) 89 | dA = get_zeros_like_coo(data.Acoo) 90 | dA = scoo_to_bcoo(dA) 91 | assert b_orig.size == b.size 92 | np.testing.assert_allclose(-b_orig, b) # sanity check 93 | db = 1e-6 * jr.normal(getkey(), shape=jnp.size(b)) 94 | dq = jnp.zeros_like(q) 95 | 96 | Dx_b = jnp.array(la.solve(A_orig.T @ A_orig, A_orig.T)) 97 | 98 | # start = time.perf_counter() 99 | # dx, dy, ds = qcp.jvp(dP, dA, dq, -db) 100 | # tol = jnp.abs(dx) 101 | # end = time.perf_counter() 102 | # print(f"compile + solve time = {end - start}..") 103 | 104 | true_result = Dx_b @ db 105 | 106 | # patdb.debug() 107 | 108 | # assert jnp.allclose(true_result, dx[m:], atol=1e-8) 109 | 110 | # assert False # DEBUG 111 | 112 | def is_array_and_dtype(dtype): 113 | def _predicate(x): 114 | return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, dtype) 115 | return _predicate 116 | 117 | # Partition qcp into (traced, static) components 118 | qcp_traced, qcp_static = eqx.partition(qcp, is_array_and_dtype(jnp.floating)) 119 | 120 | # Partition inputs similarly 121 | jvp_inputs = (dP, dA, dq, -db) 122 | inputs_traced, inputs_static = eqx.partition(jvp_inputs, is_array_and_dtype(jnp.floating)) 123 | 124 | # Define a wrapper that takes only the traced inputs 125 | def jvp_wrapped(qcp_traced, inputs_traced): 126 | # Recombine with the static parts 127 | qcp_full = eqx.combine(qcp_traced, qcp_static) 128 | inputs_full = eqx.combine(inputs_traced, inputs_static) 129 | return qcp_full.jvp(*inputs_full) 130 | 131 | # Compile it 132 | jvp_compiled = eqx.filter_jit(jvp_wrapped) 133 | 134 | # print out static vs traced inputs 135 | 136 | # Call it 137 | start = time.perf_counter() 138 | dx, dy, ds = jvp_compiled(qcp_traced, inputs_traced) 139 | tol = np.asarray(dx) 140 | end = time.perf_counter() 141 | print(f"compile + solve time = {end - start}..") 142 | 143 | start = time.perf_counter() 144 | dx, dy, ds = jvp_compiled(qcp_traced, inputs_traced) 145 | tol = np.asarray(dx) 146 | end = time.perf_counter() 147 | print(f"solve only time = {end - start}..") 148 | 149 | # dx, dy, ds = jvp(dP, dA, dq, -db) 150 | 151 | true_result = Dx_b @ db 152 | 153 | print("true result shape: ", jnp.shape(true_result)) 154 | print("dx shape: ", jnp.shape(dx[m:])) 155 | 156 | assert jnp.allclose(true_result, dx[m:], atol=1e-8) -------------------------------------------------------------------------------- 
/experiments/cpu_baseline.py: -------------------------------------------------------------------------------- 1 | """Clarabel+diffcp learning loop for paper experiment.""" 2 | 3 | import time 4 | from dataclasses import dataclass, field 5 | import os 6 | 7 | import numpy as np 8 | from numpy import ndarray 9 | import scipy.linalg as la 10 | from scipy.sparse import (spmatrix, sparray, csr_matrix, 11 | csr_array, coo_matrix, coo_array, 12 | csc_matrix, csc_array) 13 | import cvxpy as cvx 14 | from diffcp import solve_and_derivative 15 | from jaxtyping import Float 16 | 17 | import experiments.cvx_problem_generator as prob_generator 18 | import patdb 19 | import matplotlib.pyplot as plt 20 | 21 | type SP = spmatrix | sparray 22 | type SCSR = csr_matrix | csr_array 23 | type SCSC = csc_matrix | csc_array 24 | type SCOO = coo_matrix | coo_array 25 | 26 | @dataclass 27 | class CPProbData: 28 | """(linear) Cone Program (CP) problem data.""" 29 | 30 | problem: cvx.Problem 31 | 32 | Acsc: SCSC = field(init=False) 33 | c: np.ndarray = field(init=False) 34 | b: np.ndarray = field(init=False) 35 | scs_cones: dict[int, int | list[int] | list[float]] = field(init=False) 36 | n: int = field(init=False) 37 | m: int = field(init=False) 38 | 39 | def __post_init__(self): 40 | 41 | probdata, _, _ = self.problem.get_problem_data(cvx.CLARABEL, ignore_dpp=True, solver_opts={'use_quad_obj': False}) 42 | self.A = probdata["A"].tocsc() 43 | self.c, self.b = probdata["c"], probdata["b"] 44 | self.scs_cones = cvx.reductions.solvers.conic_solvers.scs_conif.dims_to_solver_dict(probdata["dims"]) 45 | self.n = np.size(self.c) 46 | self.m = np.size(self.b) 47 | 48 | 49 | def f0( 50 | target_x: Float[ndarray, " n"], 51 | target_y: Float[ndarray, " m"], 52 | target_s: Float[ndarray, " m"], 53 | x: Float[ndarray, " n"], 54 | y: Float[ndarray, " m"], 55 | s: Float[ndarray, " m"] 56 | ) -> float: 57 | return (0.5 * la.norm(x - target_x)**2 + 0.5 * la.norm(y - target_y)**2 58 | + 0.5 * la.norm(s - target_s)**2) 59 | 60 | 61 | def grad_desc( 62 | prob_data: CPProbData, 63 | target_x, 64 | target_y, 65 | target_s, 66 | num_iter: int=500, 67 | step_size: float = 1e-5 68 | ): 69 | curr_iter = 0 70 | losses = [] 71 | 72 | while curr_iter < num_iter: 73 | 74 | xk, yk, sk, _, DT = solve_and_derivative(prob_data.A, 75 | prob_data.b, 76 | prob_data.c, 77 | prob_data.scs_cones, 78 | solve_method="CLARABEL") 79 | losses.append(f0(target_x, target_y, target_s, xk, yk, sk)) 80 | 81 | dA, db, dc = DT(xk - target_x, yk - target_y, sk - target_s) 82 | 83 | prob_data.A += -step_size * dA 84 | prob_data.c += -step_size * dc 85 | prob_data.b += -step_size * db 86 | 87 | curr_iter += 1 88 | 89 | return losses 90 | 91 | 92 | if __name__ == "__main__": 93 | 94 | np.random.seed(28) 95 | 96 | # SMALL 97 | m = 20 98 | n = 10 99 | # MEDIUM-ish 100 | # m = 200 101 | # n = 100 102 | # LARGE-ish 103 | # m = 2_000 104 | # n = 1_000 105 | start_time = time.perf_counter() 106 | # target_problem = prob_generator.generate_least_squares_eq(m=m, n=n) 107 | target_problem = prob_generator.generate_group_lasso_logistic(m=m, n=m) 108 | prob_data = CPProbData(target_problem) 109 | end_time = time.perf_counter() 110 | print("Time to generate the target problem and" 111 | + f" canonicalize it: {end_time - start_time} seconds") 112 | 113 | start_time = time.perf_counter() 114 | target_x, target_y, target_s, _, DT = solve_and_derivative(prob_data.A, 115 | prob_data.b, 116 | prob_data.c, 117 | prob_data.scs_cones, 118 | solve_method="CLARABEL") 119 | end_time = 
time.perf_counter() 120 | print("Time to solve target problem + precompute some derivative info:" 121 | + f" {end_time - start_time} seconds.") 122 | 123 | fake_x = 1e-3 * np.arange(np.size(prob_data.c), dtype=prob_data.c.dtype) 124 | fake_y = 1e-3 * np.arange(np.size(prob_data.b), dtype=prob_data.b.dtype) 125 | fake_s = 1e-3 * np.arange(np.size(prob_data.b), dtype=prob_data.b.dtype) 126 | 127 | start_time = time.perf_counter() 128 | _, _, _ = DT(fake_x - target_x, 129 | fake_y - target_y, 130 | fake_s - target_s) 131 | end_time = time.perf_counter() 132 | print("Time to do the main diffcp computations:" 133 | + f" {end_time - start_time}") 134 | 135 | start_time = time.perf_counter() 136 | # initial_problem = prob_generator.generate_least_squares_eq(m=m, n=n) 137 | initial_problem = prob_generator.generate_group_lasso_logistic(m=m, n=m) 138 | prob_data = CPProbData(initial_problem) 139 | end_time = time.perf_counter() 140 | print("Time to generate the initial (starting point) problem and" 141 | + f" canonicalize it: {end_time - start_time} seconds") 142 | print(f"Canonicalized n is: {prob_data.n}") 143 | print(f"Canonicalized m is: {prob_data.m}") 144 | 145 | # num_iter = 5 146 | # num_iter=100 147 | num_iter= 100 148 | print("starting loop:") 149 | start_time = time.perf_counter() 150 | losses = grad_desc(prob_data, target_x, target_y, target_s, num_iter) 151 | end_time = time.perf_counter() 152 | print("Learning loop time: ", end_time - start_time) 153 | print(f"Avg. iteration (solve + VJP) time: {(end_time - start_time) / num_iter}") 154 | print("starting loss: ", losses[0]) 155 | print("final loss: ", losses[-1]) 156 | 157 | plt.figure(figsize=(8, 6)) 158 | plt.plot(range(num_iter), losses, label="Objective Trajectory") 159 | plt.xlabel("num. iterations") 160 | plt.ylabel("Objective function") 161 | plt.legend() 162 | plt.title(label="diffcp") 163 | results_dir = os.path.join(os.path.dirname(__file__), "results") 164 | if n > 999: 165 | output_path = os.path.join(results_dir, "diffcp_logistic_lasso_large.svg") 166 | else: 167 | output_path = os.path.join(results_dir, "diffcp_logistic_lasso_small.svg") 168 | 169 | plt.savefig(output_path, format="svg") 170 | plt.close() 171 | 172 | 173 | 174 | -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import numpy as np 4 | from scipy import sparse 5 | from scipy.sparse import (spmatrix, sparray, csr_matrix, 6 | csr_array, coo_matrix, coo_array, 7 | csc_matrix, csc_array) 8 | import cvxpy as cvx 9 | import clarabel 10 | import jax 11 | import jax.numpy as jnp 12 | import equinox as eqx 13 | from jaxtyping import Float, Array 14 | from jax.experimental.sparse import BCSR, BCOO 15 | import patdb 16 | 17 | CPU = jax.devices("cpu")[0] 18 | 19 | type SP = spmatrix | sparray 20 | type SCSR = csr_matrix | csr_array 21 | type SCSC = csc_matrix | csc_array 22 | type SCOO = coo_matrix | coo_array 23 | 24 | def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8): 25 | return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol) 26 | 27 | 28 | def get_cpu_int(a: Float[Array, " 1"]): 29 | return int(jnp.squeeze(jax.device_put(a, device=CPU))) 30 | 31 | 32 | def scoo_to_bcoo(coo_mat: SCOO) -> BCOO: 33 | """just assume `coo_mat` is in correct form.""" 34 | row_indices = coo_mat.row 35 | col_indices = coo_mat.col 36 | indices = list(zip(row_indices, col_indices)) 37 | if 
len(indices) == 0: 38 | return BCOO.fromdense(jnp.zeros(coo_mat.shape)) 39 | else: 40 | return BCOO((coo_mat.data, indices), shape=coo_mat.shape) 41 | 42 | 43 | def scsr_to_bcsr(csr_mat: SCSR) -> BCSR: 44 | if len(csr_mat.data) > 0: 45 | return BCSR((csr_mat.data, csr_mat.indices, csr_mat.indptr), 46 | shape=csr_mat.shape) 47 | else: 48 | return BCSR.fromdense(jnp.zeros(csr_mat.shape)) 49 | 50 | 51 | def quad_data_and_soln_from_qcp(problem: cvx.Problem, return_csr: bool = True): 52 | """ 53 | note that we could grab the qcp problem data in a linear canonical form. 54 | """ 55 | clarabel_probdata, _, _ = problem.get_problem_data(cvx.CLARABEL) 56 | 57 | Pfull = clarabel_probdata['P'] 58 | P_upper = sparse.triu(Pfull).tocsc() 59 | A = clarabel_probdata['A'] 60 | q = clarabel_probdata['c'] 61 | b = clarabel_probdata['b'] 62 | 63 | clarabel_cones = cvx.reductions.solvers.conic_solvers.clarabel_conif.dims_to_solver_cones(clarabel_probdata["dims"]) 64 | scs_cone_dict = cvx.reductions.solvers.conic_solvers.scs_conif.dims_to_solver_dict(clarabel_probdata["dims"]) 65 | 66 | solver_settings = clarabel.DefaultSettings() 67 | solver_settings.verbose = False 68 | solver = clarabel.DefaultSolver(P_upper, q, A, b, clarabel_cones, solver_settings) 69 | soln = solver.solve() 70 | 71 | if return_csr: 72 | Pfull = Pfull.tocsr() 73 | P_upper = P_upper.tocsr() 74 | A = A.tocsr() 75 | else: 76 | Pfull = Pfull.tocoo() 77 | P_upper = P_upper.tocoo() 78 | A = A.tocoo() 79 | 80 | return Pfull, P_upper, A, q, b, np.array(soln.x), np.array(soln.z), np.array(soln.s), scs_cone_dict, clarabel_cones 81 | 82 | quad_data_and_soln_from_qcp_coo = lambda prob: quad_data_and_soln_from_qcp(prob, return_csr=False) 83 | quad_data_and_soln_from_qcp_csr = lambda prob: quad_data_and_soln_from_qcp(prob, return_csr=True) 84 | 85 | @dataclass 86 | class QCPProbData: 87 | 88 | problem: cvx.Problem 89 | 90 | Pcsc: SCSC = field(init=False) 91 | Pcsr: SCSR = field(init=False) 92 | Pcoo: SCOO = field(init=False) 93 | 94 | Pupper_csc: SCSC = field(init=False) 95 | Pupper_csr: SCSR = field(init=False) 96 | Pupper_coo: SCOO = field(init=False) 97 | 98 | Acsc: SCSC = field(init=False) 99 | Acsr: SCSR = field(init=False) 100 | Acoo: SCOO = field(init=False) 101 | 102 | q: np.ndarray = field(init=False) 103 | b: np.ndarray = field(init=False) 104 | 105 | n: np.ndarray = field(init=False) 106 | m: np.ndarray = field(init=False) 107 | 108 | x: np.ndarray = field(init=False) 109 | y: np.ndarray = field(init=False) 110 | s: np.ndarray = field(init=False) 111 | 112 | scs_cones: dict[int, int | list[int] | list[float]] = field(init=False) 113 | clarabel_cones: list = field(init=False) 114 | 115 | def __post_init__(self): 116 | """ 117 | 118 | **Note** 119 | - `get_problem_data` seems to be returning CSR matrices/arrays. 120 | - for `P`, it returns the whole array, not just the upper triangular part. 
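(Presumably this is also why this dataclass keeps both the full `Pcoo`/`Pcsr`/`Pcsc` and the upper-triangular `Pupper_*` variants: e.g. the CPU analytical test builds `QCPStructureCPU` from `Pupper_coo` but `HostQCP` from the full `Pcoo`.)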
121 | To check this consider the following example: 122 | 123 | ```python 124 | import cvxpy as cvx 125 | import numpy as np 126 | 127 | # Define problem data 128 | Q = np.array([[4.0, 1.0, 0.5], 129 | [1.0, 3.0, 1.5], 130 | [0.5, 1.5, 2.0]]) # Non-diagonal and symmetric 131 | 132 | c = np.array([-1.0, 0.0, 1.0]) 133 | A = np.array([[1.0, 2.0, 3.0]]) 134 | b = np.array([1.0]) 135 | 136 | # Variable 137 | x = cvx.Variable(3) 138 | 139 | # Objective (standard QP form) 140 | objective = cvx.Minimize(0.5 * cvx.quad_form(x, Q) + c @ x) 141 | 142 | # Constraints 143 | constraints = [A @ x <= b] 144 | 145 | # Problem 146 | prob = cvx.Problem(objective, constraints) 147 | ``` 148 | """ 149 | clarabel_probdata, _, _ = self.problem.get_problem_data(cvx.CLARABEL, ignore_dpp=True, solver_opts={'use_quad_obj': True}) 150 | 151 | # Always get q and n first, since we need n for shape 152 | self.q = clarabel_probdata["c"] 153 | self.n = np.size(self.q) 154 | self.b = clarabel_probdata["b"] 155 | self.m = np.size(self.b) 156 | 157 | # Handle P (quadratic term) possibly missing 158 | if "P" in clarabel_probdata: 159 | self.Pcsr = clarabel_probdata["P"].tocsr() 160 | self.Pcsc = self.Pcsr.tocsc() 161 | self.Pcoo = self.Pcsr.tocoo() 162 | self.Pupper_csr = sparse.triu(self.Pcsr).tocsr() 163 | self.Pupper_csc = self.Pupper_csr.tocsc() 164 | self.Pupper_coo = self.Pupper_csr.tocoo() 165 | else: 166 | # Create zero matrices of shape (n, n) 167 | P = np.zeros((self.n, self.n)) # NOTE(quill): hack for now 168 | self.Pcsr = sparse.csr_matrix(P) 169 | self.Pcsc = self.Pcsr.tocsc() 170 | self.Pcoo = self.Pcsr.tocoo() 171 | self.Pupper_csr = self.Pcsr.copy() 172 | self.Pupper_csc = self.Pcsc.copy() 173 | self.Pupper_coo = self.Pcoo.copy() 174 | 175 | self.Acsr = clarabel_probdata["A"].tocsr() 176 | self.Acsc = self.Acsr.tocsc() 177 | self.Acoo = self.Acsr.tocoo() 178 | 179 | # NOTE(quill): that both reductions use the clarabel problem data 180 | # (So we are getting clarabel canonical form cones, but in scs dict on second line down.) 181 | self.clarabel_cones = cvx.reductions.solvers.conic_solvers.clarabel_conif.dims_to_solver_cones(clarabel_probdata["dims"]) 182 | self.scs_cones = cvx.reductions.solvers.conic_solvers.scs_conif.dims_to_solver_dict(clarabel_probdata["dims"]) 183 | 184 | solver_settings = clarabel.DefaultSettings() 185 | solver_settings.verbose = False 186 | solver = clarabel.DefaultSolver(self.Pupper_csc, 187 | self.q, 188 | self.Acsc, 189 | self.b, 190 | self.clarabel_cones, 191 | solver_settings) 192 | soln = solver.solve() 193 | self.x = np.array(soln.x) 194 | self.y = np.array(soln.z) 195 | self.s = np.array(soln.s) 196 | 197 | def get_zeros_like_coo(A: SCOO): 198 | return coo_array((np.zeros(A.size), A.nonzero()), shape=A.shape) 199 | 200 | def get_zeros_like_csr(A: SCSR): 201 | return csr_array((np.zeros(np.size(A.data)), A.indices, A.indptr), shape=A.shape, dtype=A.dtype) -------------------------------------------------------------------------------- /experiments/cvx_problem_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cvxpy as cvx 3 | import scipy.sparse as sparse 4 | import scipy.linalg as la 5 | 6 | def randn_symm(n, random_array): 7 | A = random_array(n, n) 8 | return (A + A.T) / 2 9 | 10 | 11 | def generate_sdp(n, p) -> cvx.Problem: 12 | """ 13 | Taken from https://www.cvxpy.org/examples/basic/sdp.html. 
14 | """ 15 | C = randn_symm(n, np.random.randn) 16 | A = [] 17 | b = [] 18 | for _ in range(p): 19 | # A.append(np.random.randn(n, n)) 20 | Ai = randn_symm(n, np.random.randn) 21 | A.append(Ai) 22 | b.append(float(np.random.randn())) 23 | 24 | # Define and solve the CVXPY problem. 25 | # Create a symmetric matrix variable. 26 | X = cvx.Variable((n,n), symmetric=True) 27 | # The operator >> denotes matrix inequality. 28 | constraints = [X >> 0] 29 | constraints += [ 30 | cvx.trace(A[i] @ X) == b[i] for i in range(p) 31 | ] 32 | prob = cvx.Problem(cvx.Minimize(cvx.trace(C @ X)), 33 | constraints) 34 | return prob 35 | 36 | 37 | def generate_feasible_sdp(n, p, rank=3): 38 | # Step 1: Pick a random feasible X* 39 | Z = np.random.randn(n, n) 40 | X_star = Z @ Z.T # PSD and full rank 41 | 42 | # Step 2: Generate A_i and b_i consistent with X* 43 | A = [randn_symm(n, np.random.randn) for _ in range(p)] 44 | b = [float(np.trace(Ai @ X_star)) for Ai in A] 45 | 46 | # Step 3: Random objective matrix 47 | C = randn_symm(n, np.random.randn) 48 | 49 | # Step 4: Build the CVXPY problem 50 | X = cvx.Variable((n, n), symmetric=True) 51 | constraints = [X >> 0] 52 | constraints += [cvx.trace(A[i] @ X) == b[i] for i in range(p)] 53 | prob = cvx.Problem(cvx.Minimize(cvx.trace(C @ X)), constraints) 54 | return prob 55 | 56 | 57 | def generate_portfolio_problem(n) -> cvx.Problem: 58 | mu = cvx.Parameter(n) 59 | mu.value = np.random.randn(n) 60 | Sigma = np.random.randn(n, n) 61 | Sigma = Sigma.T @ Sigma 62 | Sigma_sqrt = cvx.Parameter((n, n)) 63 | w = cvx.Variable((n, 1)) 64 | gamma = 3.43046929e+01 # fix the risk-aversion parameter. 65 | ret = mu.T @ w 66 | # risk = cvx.quad_form(w, Sigma) 67 | risk = cvx.sum_squares(Sigma_sqrt @ w) 68 | Sigma_sqrt.value = la.sqrtm(Sigma) 69 | problem = cvx.Problem(cvx.Maximize(ret - gamma * risk), [cvx.sum(w) == 1, w >= 0]) 70 | 71 | return problem 72 | 73 | 74 | def generate_least_squares_eq(m, n) -> cvx.Problem: 75 | """Generate a conic problem with unique solution. 76 | Taken from diffcp. 
77 | """ 78 | assert m >= n 79 | x = cvx.Variable(n) 80 | b = cvx.Parameter(m) 81 | b.value = np.random.randn(m) 82 | A = cvx.Parameter((m, n)) 83 | A.value = np.random.randn(m, n) 84 | assert np.linalg.matrix_rank(A.value) == n 85 | # objective = cvx.pnorm(A @ x - b, 2) 86 | objective = cvx.sum_squares(A@x - b) 87 | constraints = [x >= 0, cvx.sum(x) == 1.0] 88 | problem = cvx.Problem(cvx.Minimize(objective), constraints) 89 | assert problem.is_dpp() 90 | return problem 91 | 92 | 93 | def generate_LS_problem(m, n) -> cvx.Problem: 94 | A = np.random.randn(m, n) 95 | b = np.random.randn(m) 96 | 97 | x = cvx.Variable(n) 98 | r = cvx.Variable(m) 99 | f0 = cvx.sum_squares(r) 100 | problem = cvx.Problem(cvx.Minimize(f0), [r == A@x - b]) 101 | return problem 102 | 103 | 104 | def sigmoid(z): 105 | return 1/(1 + np.exp(-z)) 106 | 107 | def generate_group_lasso_logistic(n: int, m: int) -> cvx.Problem: 108 | X = np.random.randn(m, 10 * n) 109 | true_beta = np.zeros(10 * n) 110 | true_beta[:10 * n // 100] = 1.0 111 | y = np.round(sigmoid(X @ true_beta + np.random.randn(m)*0.5)) 112 | 113 | beta = cvx.Variable(10 * n) 114 | lambd = 0.1 115 | loss = -cvx.sum(cvx.multiply(y, X @ beta) - cvx.logistic(X @ beta)) 116 | reg = lambd * cvx.sum( cvx.norm( beta.reshape((-1, 10), 'C'), axis=1 ) ) 117 | 118 | prob = cvx.Problem(cvx.Minimize(loss + reg)) 119 | 120 | return prob 121 | 122 | def generate_group_lasso(n: int, m: int) -> cvx.Problem: 123 | X = cvx.Parameter((m, 10*n)) 124 | X.value = np.random.randn(m, 10 * n) 125 | true_beta = np.zeros(10 * n) 126 | true_beta[:10 * n // 100] = 1.0 127 | y = X @ true_beta + np.random.randn(m)*0.5 128 | 129 | beta = cvx.Variable(10 * n) 130 | lambd = cvx.Parameter(pos=True) 131 | lambd.value = 0.1 132 | loss = cvx.sum_squares(y - X @ beta) 133 | reg = lambd * cvx.sum( cvx.norm( beta.reshape((-1, 10), 'C'), axis=1 ) ) 134 | 135 | prob = cvx.Problem(cvx.Minimize(loss + reg)) 136 | 137 | assert prob.is_dpp() 138 | 139 | return prob 140 | 141 | def generate_robust_mvdr_beamformer(n: int) -> cvx.Problem: 142 | """`n` is the number of sensors.""" 143 | 144 | w = cvx.Variable((n, 1), complex=True) 145 | 146 | Sigma = np.random.randn(n, n) + 1j * np.random.randn(n, n) 147 | Sigma = Sigma @ Sigma.conj().T + 0.1 * np.eye(n) # Make Hermitian PSD 148 | Sigma_sqrt = cvx.Parameter((n, n), complex=True) 149 | Sigma_sqrt.value = la.sqrtm(Sigma) 150 | 151 | a_hat = 5 * np.random.randn(n) + 1j * np.random.randn(n) # Fake array manifold/response 152 | P = cvx.Parameter((n, n), complex=True) # uncertainty matrix 153 | P.value = np.random.randn(n, n) + 1j * np.random.randn(n, n) 154 | 155 | # f0 = cvx.real(cvx.sum_squares(Sigma_sqrt @ w)) 156 | f0 = cvx.sum_squares(Sigma_sqrt @ w) 157 | obj = cvx.Minimize(f0) 158 | 159 | gamma = 1.0 # desired signal constraint 160 | delta = 0.5 # uncertainty size 161 | 162 | constraints = [ 163 | cvx.real(a_hat.conj().T @ w) >= gamma, 164 | cvx.norm(P.conj().T @ w, 2) <= delta 165 | ] 166 | 167 | prob = cvx.Problem(obj, constraints) 168 | assert prob.is_dpp() 169 | return prob 170 | 171 | 172 | def generate_kalman_smoother( 173 | random_inputs, random_noise, T: int=5, n: int=100 174 | ) -> cvx.Problem: 175 | """ 176 | `n` is number of time steps. 177 | `T` is time horizon. 178 | 179 | Largely taken from: https://www.cvxpy.org/examples/applications/robust_kalman.html. 
180 | """ 181 | _, delt = np.linspace(0,T,n,endpoint=True, retstep=True) 182 | gamma = .05 # damping, 0 is no damping 183 | 184 | A = np.zeros((4,4)) 185 | B = np.zeros((4,2)) 186 | C = np.zeros((2,4)) 187 | 188 | A[0,0] = 1 189 | A[1,1] = 1 190 | A[0,2] = (1-gamma*delt/2)*delt 191 | A[1,3] = (1-gamma*delt/2)*delt 192 | A[2,2] = 1 - gamma*delt 193 | A[3,3] = 1 - gamma*delt 194 | 195 | B[0,0] = delt**2/2 196 | B[1,1] = delt**2/2 197 | B[2,0] = delt 198 | B[3,1] = delt 199 | 200 | C[0,0] = 1 201 | C[1,1] = 1 202 | 203 | x = np.zeros((4,n+1)) 204 | x[:,0] = [0,0,0,0] 205 | y = np.zeros((2,n)) 206 | 207 | # generate random input and noise vectors 208 | w = random_inputs 209 | v = random_noise 210 | 211 | # simulate the system forward in time 212 | for t in range(n): 213 | y[:,t] = C @ x[:,t] + v[:,t] 214 | x[:,t+1] = A @ x[:,t] + B @ w[:,t] 215 | 216 | x = cvx.Variable(shape=(4, n+1)) 217 | w = cvx.Variable(shape=(2, n)) 218 | v = cvx.Variable(shape=(2, n)) 219 | 220 | tau = cvx.Parameter(pos=True) 221 | tau.value = np.random.uniform(0.1, 5) 222 | 223 | obj = cvx.sum_squares(w) + tau*cvx.sum_squares(v) 224 | obj = cvx.Minimize(obj) 225 | 226 | constr = [] 227 | for t in range(n): 228 | constr += [ x[:,t+1] == A@x[:,t] + B@w[:,t] , 229 | y[:,t] == C@x[:,t] + v[:,t] ] 230 | 231 | prob = cvx.Problem(obj, constr) 232 | assert prob.is_dpp() 233 | return prob 234 | 235 | 236 | def generate_pow_projection_problem( 237 | n: int 238 | ) -> cvx.Problem: 239 | """Project x onto the product of 3D power cones with given alphas using CVXPY.""" 240 | assert n % 3 == 0 241 | x = np.random.randn(n) 242 | num_cones = n // 3 243 | var = cvx.Variable(n) 244 | constraints = [] 245 | for i in range(num_cones): 246 | alpha = np.maximum(np.random.rand(), 0.01) 247 | constraints.append(cvx.PowCone3D(var[3*i], var[3*i+1], var[3*i+2], alpha)) 248 | objective = cvx.Minimize(cvx.sum_squares(var - x)) 249 | prob = cvx.Problem(objective, constraints) 250 | return prob -------------------------------------------------------------------------------- /diffqcp/qcp_derivs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import TYPE_CHECKING 3 | 4 | from jax import ShapeDtypeStruct 5 | import jax.numpy as jnp 6 | import equinox as eqx 7 | import lineax as lx 8 | from lineax import AbstractLinearOperator 9 | from jaxtyping import Array, Float, Integer 10 | from jax.experimental.sparse import BCOO, BCSR 11 | 12 | from diffqcp.problem_data import ObjMatrix 13 | 14 | # NOTE(quill): the last bit of that would fail since `dtau * self.q` would be 1D array * 2D array 15 | # So I guess the somewhat challenging aspect of this is the fact that the first two bits 16 | # in the expression are fine, so we don't actually want to vmap those... 17 | # NOTE(quill): UPDATE. This is NOT TRUE. If in the batched case then `dtau` is also a 2D array! 
18 | 19 | class _DuQAdjoint(AbstractLinearOperator): 20 | P: ObjMatrix 21 | Px: Float[Array, " n"] 22 | xTPx: Float[Array, ""] 23 | A: Float[BCOO | BCSR, "m n"] 24 | AT: Float[BCOO | BCSR, "n m"] 25 | q: Float[Array, " n"] 26 | b: Float[Array, " m"] 27 | x: Float[Array, " n"] 28 | tau: Float[Array, ""] 29 | n: int = eqx.field(static=True) 30 | m: int = eqx.field(static=True) 31 | 32 | def mv(self, dv): 33 | dv1: Float[Array, " n"] = dv[:self.n] 34 | dv2: Float[Array, " m"] = dv[self.n:-1] 35 | dv3: Float[Array, ""] = dv[-1] 36 | out1 = self.P.mv(dv1) - self.AT @ dv2 + ( -(2/self.tau) * self.Px - self.q) * dv3 37 | out2 = self.A @ dv1 - dv3 * self.b 38 | out3 = self.q @ dv1 + self.b @ dv2 + (1/self.tau**2) * dv3 * self.xTPx 39 | return jnp.concatenate([out1, out2, jnp.array([out3])]) 40 | 41 | def as_matrix(self): 42 | raise NotImplementedError(f"{self.__class__.__name__}'s `as_matrix` method is" 43 | + " not yet implemented.") 44 | 45 | def transpose(self) -> _DuQ: 46 | return _DuQ(self.P, self.Px, self.xTPx, self.A, self.AT, self.q, 47 | self.b, self.x, self.tau, self.n, self.m) 48 | 49 | def in_structure(self): 50 | return ShapeDtypeStruct(shape=(self.n + self.m + 1,), 51 | dtype=self.A.dtype) 52 | 53 | def out_structure(self): 54 | return self.in_structure() 55 | 56 | 57 | class _DuQ(AbstractLinearOperator): 58 | """ 59 | NOTE(quill): we know at compile time if this is batched or not. 60 | """ 61 | P: ObjMatrix 62 | Px: Float[Array, " n"] 63 | xTPx: Float[Array, ""] 64 | A: Float[BCOO | BCSR, "m n"] 65 | AT: Float[BCOO | BCSR, "n m"] 66 | q: Float[Array, " n"] 67 | b: Float[Array, " m"] 68 | x: Float[Array, " n"] 69 | tau: Float[Array, ""] 70 | n: int = eqx.field(static=True) 71 | m: int = eqx.field(static=True) 72 | 73 | def mv(self, du: Float[Array, " n+m+1"]): 74 | dx, dy, dtau = du[:self.n], du[self.n:-1], du[-1] 75 | Pdx = self.P.mv(dx) 76 | out1 = Pdx + self.AT @ dy + dtau * self.q 77 | out2 = self.A @ (-dx) + dtau * self.b 78 | out3 = ((-2/self.tau) * self.x @ Pdx - self.q @ dx - self.b @ dy 79 | + (1/self.tau**2) * dtau * self.xTPx) 80 | return jnp.concatenate([out1, out2, jnp.array([out3])]) 81 | 82 | def as_matrix(self): 83 | raise NotImplementedError(f"{self.__class__.__name__}'s `as_matrix` method is" 84 | + " not yet implemented.") 85 | 86 | def transpose(self) -> _DuQAdjoint: 87 | return _DuQAdjoint(self.P, self.Px, self.xTPx, self.A, self.AT, self.q, 88 | self.b, self.x, self.tau, self.n, self.m) 89 | 90 | def in_structure(self): 91 | return ShapeDtypeStruct(shape=(self.n + self.m + 1,), 92 | dtype=self.A.dtype) 93 | 94 | def out_structure(self): 95 | return self.in_structure() 96 | 97 | @lx.is_symmetric.register(_DuQAdjoint) 98 | def _(op): 99 | return False 100 | 101 | @lx.conj.register(_DuQAdjoint) 102 | def _(op): 103 | return op 104 | 105 | @lx.is_symmetric.register(_DuQ) 106 | def _(op): 107 | return False 108 | 109 | @lx.conj.register(_DuQ) 110 | def _(op): 111 | return op 112 | 113 | def _d_data_Q( 114 | x: Float[Array, " n"], 115 | y: Float[Array, " m"], 116 | tau: Float[Array, ""], 117 | dP: ObjMatrix, 118 | dA: Float[BCOO | BCSR, "m n"], 119 | dAT: Float[BCOO | BCSR, "n m"], 120 | dq: Float[Array, " n"], 121 | db: Float[Array, " m"] 122 | ) -> Float[Array, " n+m+1"]: 123 | """The Jacobian-vector product D_data Q(u, data)[d_data]. 124 | 125 | More specifically, returns D_data Q(u, data)[d_data], where 126 | d_data = (dP, dA, dq, db), Q is the nonlinear homogeneous embedding 127 | and D_data is the derivative operator w.r.t. data = (P, A, q, b).
128 | 129 | u, dP, dA, dq, and db are the exact objects defined in the diffqcp paper. 130 | Specifically, note that dP should be the true perturbation to the matrix P, 131 | **not just the upper triangular part.** 132 | """ 133 | 134 | dPx = dP.mv(x) 135 | out1 = dPx + dAT @ y + tau * dq 136 | out2 = dA @ -x + tau * db 137 | out3 = -(1 / tau) * (x @ dPx) - dq @ x - db @ y 138 | 139 | return jnp.concatenate([out1, out2, jnp.array([out3])]) 140 | 141 | # NOTE(quill): what's going to happen when these get `jit`ted? 142 | 143 | 144 | def _adjoint_values( 145 | x: Float[Array, " n"], 146 | y: Float[Array, " m"], 147 | tau: Float[Array, ""], 148 | w1: Float[Array, " n"], 149 | w2: Float[Array, " m"], 150 | w3: Float[Array, ""], 151 | P_rows: Integer[Array, "..."], 152 | P_cols: Integer[Array, "..."], 153 | A_rows: Integer[Array, "..."], 154 | A_cols: Integer[Array, "..."], 155 | ) -> tuple[Float[Array, "..."], Float[Array, "..."], Float[Array, " n"], Float[Array, " m"]]: 156 | dP_values = (0.5 * ( w1[P_rows] * x[P_cols] + x[P_rows] * w1[P_cols] ) 157 | - (w3 / tau) * x[P_rows] * x[P_cols]) 158 | dA_values = y[A_rows] * w1[A_cols] - w2[A_rows] * x[A_cols] 159 | dq = tau * w1 - w3 * x 160 | db = tau * w2 - w3 * y 161 | 162 | return (dP_values, dA_values, dq, db) 163 | 164 | 165 | def _d_data_Q_adjoint_cpu( 166 | x: Float[Array, " n"], 167 | y: Float[Array, " m"], 168 | tau: Float[Array, ""], 169 | w1: Float[Array, " n"], 170 | w2: Float[Array, " m"], 171 | w3: Float[Array, ""], 172 | P_rows: Integer[Array, "..."], 173 | P_cols: Integer[Array, "..."], 174 | A_rows: Integer[Array, "..."], 175 | A_cols: Integer[Array, "..."], 176 | n: int, 177 | m: int 178 | ) -> tuple[ 179 | Float[BCOO, "n n"], Float[BCOO, "m n"], Float[Array, " n"], Float[Array, " m"] 180 | ]: 181 | """The vector-Jacobian product D_data(u, data)^T[w]. 
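In other words, this assembles the adjoint product D_data Q(u, data)^T w = (dP, dA, dq, db) from the entries computed by `_adjoint_values`, with dP and dA returned as BCOO matrices restricted to the sparsity patterns given by (P_rows, P_cols) and (A_rows, A_cols).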
182 | """ 183 | dP_values, dA_values, dq, db = _adjoint_values(x, y, tau, w1, w2, w3, 184 | P_rows, P_cols, A_rows, A_cols) 185 | 186 | P_indices = jnp.stack([P_rows, P_cols], axis=1) 187 | dP = BCOO((dP_values, P_indices), shape=(n, n)) 188 | A_indices = jnp.stack([A_rows, A_cols], axis=1) 189 | dA = BCOO((dA_values, A_indices), shape=(m, n)) 190 | 191 | return (dP, dA, dq, db) 192 | 193 | 194 | def _d_data_Q_adjoint_gpu( 195 | x: Float[Array, " n"], 196 | y: Float[Array, " m"], 197 | tau: Float[Array, ""], 198 | w1: Float[Array, " n"], 199 | w2: Float[Array, " m"], 200 | w3: Float[Array, ""], 201 | P_rows: Integer[Array, "..."], 202 | P_cols: Integer[Array, "..."], 203 | P_csr_indices: Integer[Array, "..."], 204 | P_csr_indtpr: Integer[Array, "..."], 205 | A_rows: Integer[Array, "..."], 206 | A_cols: Integer[Array, "..."], 207 | A_csr_indices: Integer[Array, "..."], 208 | A_csr_indtpr: Integer[Array, "..."], 209 | n: int, 210 | m: int 211 | ) -> tuple[ 212 | Float[BCSR, "n n"], Float[BCSR, "m n"], Float[Array, " n"], Float[Array, " m"] 213 | ]: 214 | """The vector-Jacobian product D_data(u, data)^T[w].""" 215 | dP_values, dA_values, dq, db = _adjoint_values(x, y, tau, w1, w2, w3, 216 | P_rows, P_cols, A_rows, A_cols) 217 | 218 | dP = BCSR((dP_values, P_csr_indices, P_csr_indtpr), shape=(n, n)) 219 | dA = BCSR((dA_values, A_csr_indices, A_csr_indtpr), shape=(m, n)) 220 | 221 | return (dP, dA, dq, db) 222 | 223 | -------------------------------------------------------------------------------- /tests/test_qcp_gpu_analytical.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import jax 5 | import jax.numpy as jnp 6 | import scipy.linalg as la 7 | import jax.random as jr 8 | import cvxpy as cvx 9 | import equinox as eqx 10 | 11 | try: 12 | from nvmath.sparse.advanced import DirectSolver 13 | except ImportError: 14 | DirectSolver = None 15 | 16 | from diffqcp import DeviceQCP, QCPStructureGPU 17 | from .helpers import (quad_data_and_soln_from_qcp_coo as quad_data_and_soln_from_qcp, 18 | scsr_to_bcsr, QCPProbData, get_zeros_like_csr) 19 | 20 | def test_least_squares(getkey): 21 | """ 22 | The least squares (approximation) problem 23 | 24 | minimize ||Ax - b||^2, 25 | 26 | <=> 27 | 28 | minimize ||r||^2 29 | subject to r = Ax - b, 30 | 31 | where A is a (m x n)-matrix with rank A = n, has 32 | the analytical solution 33 | 34 | x^star = (A^T A)^-1 A^T b. 35 | 36 | Considering x^star as a function of b, we know 37 | 38 | Dx^star(b) = (A^T A)^-1 A^T. 39 | 40 | This test checks the accuracy of `diffqcp`'s derivative computations by 41 | comparing DS(Data)dData to Dx^star(b)db. 42 | 43 | **Notes:** 44 | - `dData == (0, 0, 0, db)`, and other canonicalization considerations must be made 45 | (hence the `data_and_soln_from_cvxpy_problem` function call and associated data declaration.) 
46 | """ 47 | 48 | # TODO(quill): update the testing to follow best practices 49 | 50 | np.random.seed(0) 51 | 52 | for i in range(10): 53 | print(f"iteration {i}") 54 | np.random.seed(0) 55 | n = np.random.randint(low=10, high=15) 56 | m = n + np.random.randint(low=5, high=15) 57 | # n = np.random.randint(low=1_000, high=1_500) 58 | # m = n + np.random.randint(low=500, high=1_000) 59 | 60 | A = np.random.randn(m, n) 61 | b = np.random.randn(m) 62 | 63 | x = cvx.Variable(n) 64 | r = cvx.Variable(m) 65 | f0 = cvx.sum_squares(r) 66 | problem = cvx.Problem(cvx.Minimize(f0), [r == A@x - b]) 67 | 68 | data = QCPProbData(problem) 69 | 70 | P = scsr_to_bcsr(data.Pcsr) 71 | A_orig = A 72 | A = scsr_to_bcsr(data.Acsr) 73 | q = jnp.array(data.q) 74 | b_orig = b 75 | b = jnp.array(data.b) 76 | x = jnp.array(data.x) 77 | y = jnp.array(data.y) 78 | s = jnp.array(data.s) 79 | 80 | qcp_struc = QCPStructureGPU(P, A, data.scs_cones) 81 | qcp = DeviceQCP(P, A, q, b, x, y, s, qcp_struc) 82 | 83 | dP = get_zeros_like_csr(data.Pcsr) 84 | dP = scsr_to_bcsr(dP) 85 | dA = get_zeros_like_csr(data.Acsr) 86 | dA = scsr_to_bcsr(dA) 87 | assert b_orig.size == b.size 88 | np.testing.assert_allclose(-b_orig, b) # sanity check 89 | db = jr.normal(getkey(), shape=jnp.size(b)) 90 | dq = jnp.zeros_like(q) 91 | 92 | Dx_b = jnp.array(la.solve(A_orig.T @ A_orig, A_orig.T)) 93 | 94 | true_result = Dx_b @ db 95 | 96 | # patdb.debug() 97 | 98 | # assert jnp.allclose(true_result, dx[m:], atol=1e-8) 99 | 100 | # assert False # DEBUG 101 | 102 | def is_array_and_dtype(dtype): 103 | def _predicate(x): 104 | return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, dtype) 105 | return _predicate 106 | 107 | # Partition qcp into (traced, static) components 108 | qcp_traced, qcp_static = eqx.partition(qcp, is_array_and_dtype(jnp.floating)) 109 | 110 | # Partition inputs similarly 111 | jvp_inputs = (dP, dA, dq, -db, "jax-lsmr") 112 | inputs_traced, inputs_static = eqx.partition(jvp_inputs, is_array_and_dtype(jnp.floating)) 113 | 114 | # Define a wrapper that takes only the traced inputs 115 | def jvp_wrapped(qcp_traced, inputs_traced): 116 | # Recombine with the static parts 117 | qcp_full = eqx.combine(qcp_traced, qcp_static) 118 | inputs_full = eqx.combine(inputs_traced, inputs_static) 119 | return qcp_full.jvp(*inputs_full) 120 | 121 | # Compile it 122 | jvp_compiled = eqx.filter_jit(jvp_wrapped) 123 | 124 | # print out static vs traced inputs 125 | 126 | # Call it 127 | start = time.perf_counter() 128 | dx, dy, ds = jvp_compiled(qcp_traced, inputs_traced) 129 | dx.block_until_ready() 130 | end = time.perf_counter() 131 | print(f"compile + solve time = {end - start}..") 132 | 133 | start = time.perf_counter() 134 | dx, dy, ds = jvp_compiled(qcp_traced, inputs_traced) 135 | # tol = jnp.abs(dx) 136 | dx.block_until_ready() 137 | end = time.perf_counter() 138 | print(f"solve only time = {end - start}..") 139 | 140 | true_result = Dx_b @ db 141 | 142 | print("true result shape: ", jnp.shape(true_result)) 143 | print("dx shape: ", jnp.shape(dx[m:])) 144 | 145 | print("SMALL TRUTH: ", Dx_b @ (1e-6 * db)) 146 | print("REAL TRUTH: ", true_result) 147 | print("COMPUTED: ", dx[m:]) 148 | 149 | assert jnp.allclose(dx[m:], true_result, atol=1e-6) 150 | 151 | def test_least_squares_direct_solve(getkey): 152 | """ 153 | The least squares (approximation) problem 154 | 155 | minimize ||Ax - b||^2, 156 | 157 | <=> 158 | 159 | minimize ||r||^2 160 | subject to r = Ax - b, 161 | 162 | where A is a (m x n)-matrix with rank A = n, has 163 | the 
analytical solution 164 | 165 | x^star = (A^T A)^-1 A^T b. 166 | 167 | Considering x^star as a function of b, we know 168 | 169 | Dx^star(b) = (A^T A)^-1 A^T. 170 | 171 | This test checks the accuracy of `diffqcp`'s derivative computations by 172 | comparing DS(Data)dData to Dx^star(b)db. 173 | 174 | **Notes:** 175 | - `dData == (0, 0, 0, db)`, and other canonicalization considerations must be made 176 | (hence the `data_and_soln_from_cvxpy_problem` function call and associated data declaration.) 177 | """ 178 | 179 | # NOTE(quill): this is a bit sloppy; asserting first device is a 180 | # gpu device. 181 | jax_gpu_enabled = jax.devices()[0].platform == "gpu" 182 | if DirectSolver is not None and jax_gpu_enabled: 183 | solvers = ["jax-lu", "nvmath-direct"] 184 | else: 185 | solvers = ["jax-lu"] 186 | 187 | for solve_method in solvers: 188 | np.random.seed(0) 189 | for i in range(10): 190 | print(f"== iteration {i} ===") 191 | print("!!! JAX devices: ", jax.devices()) 192 | np.random.seed(0) 193 | n = np.random.randint(low=10, high=15) 194 | m = n + np.random.randint(low=5, high=15) 195 | # n = np.random.randint(low=1_000, high=1_500) 196 | # m = n + np.random.randint(low=500, high=1_000) 197 | 198 | A = np.random.randn(m, n) 199 | b = np.random.randn(m) 200 | 201 | x = cvx.Variable(n) 202 | r = cvx.Variable(m) 203 | f0 = cvx.sum_squares(r) 204 | problem = cvx.Problem(cvx.Minimize(f0), [r == A@x - b]) 205 | 206 | data = QCPProbData(problem) 207 | 208 | P = scsr_to_bcsr(data.Pcsr) 209 | A_orig = A 210 | A = scsr_to_bcsr(data.Acsr) 211 | q = jnp.array(data.q) 212 | b_orig = b 213 | b = jnp.array(data.b) 214 | x = jnp.array(data.x) 215 | y = jnp.array(data.y) 216 | s = jnp.array(data.s) 217 | 218 | qcp_struc = QCPStructureGPU(P, A, data.scs_cones) 219 | qcp = DeviceQCP(P, A, q, b, x, y, s, qcp_struc) 220 | 221 | print("N = ", qcp_struc.N) 222 | print("n = ", qcp_struc.n) 223 | print("m = ", qcp_struc.m) 224 | 225 | dP = get_zeros_like_csr(data.Pcsr) 226 | dP = scsr_to_bcsr(dP) 227 | dA = get_zeros_like_csr(data.Acsr) 228 | dA = scsr_to_bcsr(dA) 229 | assert b_orig.size == b.size 230 | np.testing.assert_allclose(-b_orig, b) # sanity check 231 | db = jr.normal(getkey(), shape=jnp.size(b)) 232 | dq = jnp.zeros_like(q) 233 | 234 | Dx_b = jnp.array(la.solve(A_orig.T @ A_orig, A_orig.T)) 235 | 236 | true_result = Dx_b @ db 237 | 238 | dx, _, _ = qcp.jvp(dP, dA, dq, -db, solve_method=solve_method) 239 | 240 | print("true result shape: ", jnp.shape(true_result)) 241 | print("dx shape: ", jnp.shape(dx[m:])) 242 | 243 | assert jnp.allclose(dx[m:], true_result, atol=1e-8) -------------------------------------------------------------------------------- /experiments/cpu_experiment.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | from dataclasses import dataclass 4 | import numpy as np 5 | import jax 6 | jax.config.update("jax_enable_x64", True) 7 | jax.config.update("jax_platform_name", "cpu") 8 | import jax.numpy as jnp 9 | import jax.numpy.linalg as la 10 | from jaxtyping import Float, Array 11 | import equinox as eqx 12 | from jax.experimental.sparse import BCOO 13 | import clarabel 14 | from scipy.sparse import (spmatrix, sparray, 15 | csc_matrix, csc_array) 16 | import patdb 17 | import matplotlib.pyplot as plt 18 | 19 | from diffqcp import HostQCP, QCPStructureCPU 20 | import experiments.cvx_problem_generator as prob_generator 21 | from tests.helpers import QCPProbData, scoo_to_bcoo 22 | 23 | type SP = spmatrix | sparray 24 | type 
SCSC = csc_matrix | csc_array 25 | 26 | @dataclass 27 | class SolverData: 28 | 29 | Pupper_csc: SCSC 30 | A: SCSC 31 | q: np.ndarray 32 | b: np.ndarray 33 | 34 | def compute_loss(target_x, target_y, target_s, x, y, s): 35 | return (0.5 * la.norm(x - target_x)**2 + 0.5 * la.norm(y - target_y)**2 36 | + 0.5 * la.norm(s - target_s)**2) 37 | 38 | @eqx.filter_jit 39 | @eqx.debug.assert_max_traces(max_traces=1) 40 | def make_step( 41 | qcp: HostQCP, 42 | target_x: Float[Array, " n"], 43 | target_y: Float[Array, " m"], 44 | target_s: Float[Array, " m"], 45 | Pdata: Float[Array, "..."], 46 | Adata: Float[Array, "..."], 47 | q: Float[Array, " n"], 48 | b: Float[Array, " m"], 49 | step_size: float 50 | ) -> tuple[Float[Array, ""], Float[Array, "..."], Float[Array, "..."], 51 | Float[Array, " n"], Float[Array, " m"]]: 52 | loss = compute_loss(target_x, target_y, target_s, qcp.x, qcp.y, qcp.s) 53 | dP, dA, dq, db = qcp.vjp(qcp.x - target_x, 54 | qcp.y - target_y, 55 | qcp.s - target_s) 56 | new_Pdata = Pdata - step_size * dP.data 57 | new_Adata = Adata - step_size * dA.data 58 | new_q = q - step_size * dq 59 | new_b = b - step_size * db 60 | return (loss, new_Pdata, new_Adata, new_q, new_b) 61 | 62 | def grad_desc( 63 | Pk: Float[BCOO, "n n"], 64 | Ak: Float[BCOO, "m n"], 65 | qk: Float[Array, " n"], 66 | bk: Float[Array, " m"], 67 | target_x: Float[Array, " n"], 68 | target_y: Float[Array, " m"], 69 | target_s: Float[Array, " m"], 70 | qcp_problem_structure: QCPStructureCPU, 71 | data: QCPProbData, 72 | Pcoo_csc_perm: Float[np.ndarray, "..."], 73 | Acoo_csc_perm: Float[np.ndarray, "..."], 74 | clarabel_solver, 75 | num_iter: int = 100, 76 | step_size = 1e-5 77 | ): 78 | curr_iter = 0 79 | losses = [] 80 | 81 | while curr_iter < num_iter: 82 | 83 | solution = clarabel_solver.solve() 84 | 85 | xk = jnp.array(solution.x) 86 | yk = jnp.array(solution.z) 87 | sk = jnp.array(solution.s) 88 | 89 | qcp = HostQCP(Pk, Ak, qk, bk, xk, yk, sk, qcp_problem_structure) 90 | 91 | loss, *new_data = make_step(qcp, target_x, target_y, target_s, 92 | Pk.data, Ak.data, qk, bk, step_size) 93 | losses.append(loss) 94 | 95 | Pk_data, Ak_data, qk, bk = new_data 96 | Pk.data, Ak.data = Pk_data, Ak_data 97 | data.Pupper_csc.data = np.asarray(Pk.data, copy=True)[Pcoo_csc_perm] 98 | data.Acsc.data = np.asarray(Ak.data, copy=True)[Acoo_csc_perm] 99 | data.q = np.asarray(qk, copy=True) 100 | data.b = np.asarray(bk, copy=True) 101 | 102 | solver.update(P=data.Pupper_csc, q=data.q, A=data.Acsc, b=data.b) 103 | 104 | curr_iter += 1 105 | 106 | return losses 107 | 108 | if __name__ == "__main__": 109 | np.random.seed(28) 110 | 111 | # SMALL 112 | m = 20 113 | n = 10 114 | # MEDIUM-ish 115 | # m = 200 116 | # n = 100 117 | # LARGE-ish 118 | # m = 2_000 119 | # n = 1_000 120 | # target_problem = prob_generator.generate_least_squares_eq(m=m, n=n) 121 | target_problem = prob_generator.generate_pow_projection_problem(n=33) 122 | prob_data_cpu = QCPProbData(target_problem) 123 | 124 | Pupper_coo_to_csc_order = np.lexsort((prob_data_cpu.Pupper_coo.row, 125 | prob_data_cpu.Pupper_coo.col)) 126 | A_coo_to_csc_order = np.lexsort((prob_data_cpu.Acoo.row, 127 | prob_data_cpu.Acoo.col)) 128 | 129 | cones = prob_data_cpu.clarabel_cones 130 | settings = clarabel.DefaultSettings() 131 | settings.verbose = False 132 | settings.presolve_enable = False 133 | 134 | solver = clarabel.DefaultSolver(prob_data_cpu.Pupper_csc, 135 | prob_data_cpu.q, 136 | prob_data_cpu.Acsc, 137 | prob_data_cpu.b, 138 | cones, 139 | settings) 140 | 141 | start_solve = 
time.perf_counter() 142 | solution = solver.solve() 143 | end_solve = time.perf_counter() 144 | print(f"Clarabel solve took: {end_solve - start_solve} seconds") 145 | 146 | target_x = jnp.array(solution.x) 147 | target_y = jnp.array(solution.z) 148 | target_s = jnp.array(solution.s) 149 | 150 | P = scoo_to_bcoo(prob_data_cpu.Pupper_coo) 151 | A = scoo_to_bcoo(prob_data_cpu.Acoo) 152 | q = prob_data_cpu.q 153 | b = prob_data_cpu.b 154 | scs_cones = prob_data_cpu.scs_cones 155 | problem_structure = QCPStructureCPU(P, A, scs_cones) 156 | qcp_initial = HostQCP(P, A, q, b, 157 | target_x, target_y, target_s, 158 | problem_structure) 159 | fake_target_x = 1e-3 * jnp.arange(jnp.size(q), dtype=q.dtype) 160 | fake_target_y = 1e-3 * jnp.arange(jnp.size(b), dtype=b.dtype) 161 | fake_target_s = 1e-3 * jnp.arange(jnp.size(b), dtype=b.dtype) 162 | 163 | start_time = time.perf_counter() 164 | result = make_step(qcp_initial, fake_target_x, fake_target_y, 165 | fake_target_s, P.data, A.data, q, b, step_size=1e-5) 166 | result[0].block_until_ready() 167 | end_time = time.perf_counter() 168 | # NOTE(quill): well, technically VJP + loss + step computations 169 | print("diffqcp VJP compile + compute took: ", end_time - start_time) 170 | # patdb.debug() 171 | # --- test compiled solve --- 172 | 173 | start_time = time.perf_counter() 174 | result = make_step(qcp_initial, fake_target_x, fake_target_y, 175 | fake_target_s, P.data, A.data, q, b, step_size=1e-5) 176 | result[0].block_until_ready() 177 | end_time = time.perf_counter() 178 | print("Compiled diffqcp VJP compute took: ", end_time - start_time) 179 | 180 | # --- --- 181 | 182 | # initial_problem = prob_generator.generate_least_squares_eq(m=m, n=n) 183 | initial_problem = prob_generator.generate_pow_projection_problem(n=33) 184 | prob_data_cpu = QCPProbData(initial_problem) 185 | 186 | cones = prob_data_cpu.clarabel_cones 187 | settings = clarabel.DefaultSettings() 188 | settings.verbose = False 189 | settings.presolve_enable = False 190 | 191 | solver = clarabel.DefaultSolver(prob_data_cpu.Pupper_csc, 192 | prob_data_cpu.q, 193 | prob_data_cpu.Acsc, 194 | prob_data_cpu.b, 195 | cones, 196 | settings) 197 | 198 | num_iter = 1000 199 | 200 | start_time = time.perf_counter() 201 | losses = grad_desc(Pk=scoo_to_bcoo(prob_data_cpu.Pupper_coo), 202 | Ak=scoo_to_bcoo(prob_data_cpu.Acoo), 203 | qk = jnp.array(prob_data_cpu.q), 204 | bk = jnp.array(prob_data_cpu.b), 205 | target_x=target_x, 206 | target_y=target_y, 207 | target_s=target_s, 208 | qcp_problem_structure=problem_structure, 209 | data=prob_data_cpu, 210 | Pcoo_csc_perm=Pupper_coo_to_csc_order, 211 | Acoo_csc_perm=A_coo_to_csc_order, 212 | clarabel_solver=solver, num_iter=num_iter) 213 | losses[0].block_until_ready() 214 | end_time = time.perf_counter() 215 | print(f"The learning loop time was {end_time - start_time} seconds") 216 | print(f"Avg. iteration (solve + VJP) time: {(end_time - start_time) / num_iter}") 217 | losses = jnp.stack(losses) 218 | losses = np.asarray(losses) 219 | 220 | plt.figure(figsize=(8, 6)) 221 | plt.plot(range(num_iter), losses, label="Objective Trajectory") 222 | plt.xlabel("num. 
iterations") 223 | plt.ylabel("Objective function") 224 | plt.legend() 225 | plt.title(label="diffqcp") 226 | results_dir = os.path.join(os.path.dirname(__file__), "results") 227 | if prob_data_cpu.n > 99: 228 | # output_path = os.path.join(results_dir, "diffqcp_cpu_probability_large.svg") 229 | output_path = os.path.join(results_dir, "diffqcp_cpu_pow_large.svg") 230 | else: 231 | # output_path = os.path.join(results_dir, "diffqcp_cpu_probability_small.svg") 232 | output_path = os.path.join(results_dir, "diffqcp_cpu_pow_small.svg") 233 | plt.savefig(output_path, format="svg") 234 | plt.close() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# diffqcp: Differentiating through conic quadratic programs

2 | 3 | `diffqcp` is a [JAX](https://docs.jax.dev/en/latest/) library to form the derivative of the solution map to a conic quadratic program (CQP) with respect to the CQP problem data as an abstract linear operator and to compute Jacobian-vector products (JVPs) and vector-Jacobian products (VJPs) with this operator. 4 | The implementation is based on the derivations in our paper (see below) and computes 5 | these products implicitly via projections onto cones and sparse linear system solves. 6 | Our approach therefore differs from libraries that compute JVPs and VJPs by unrolling algorithm iterates. 7 | We directly exploit the underlying structure of CQPs. 8 | 9 | **Features include**: 10 | - Hardware acclerated: JVPs and VJPs can be computed on CPUs, GPUs, and (theoretically) TPUs. 11 | - Support for many canonical classes of convex optimization problems including 12 | - linear programs (LPs), 13 | - quadratic programs (QPs), 14 | - second-order cone programs (SOCPs), 15 | - and semidefinite programs (SDPs). 16 | - Support for convex optimization problems constrained to the product of exponential 17 | and power cones (as well as their duals). 18 | 19 | # Conic quadratic programs 20 | 21 | A conic quadratic program is given by the primal and dual problems 22 | 23 | ```math 24 | \begin{equation*} 25 | \begin{array}{lll} 26 | \text{(P)} \quad &\text{minimize} \; & (1/2)x^T P x + q^T x \\ 27 | &\text{subject to} & Ax + s = b \\ 28 | & & s \in \mathcal{K}, 29 | \end{array} 30 | \qquad 31 | \begin{array}{lll} 32 | \text{(D)} \quad &\text{maximize} \; & -(1/2)x^T P x -b^T y \\ 33 | &\text{subject to} & Px + A^T y = -q \\ 34 | & & y \in \mathcal{K}^*, 35 | \end{array} 36 | \end{equation*} 37 | ``` 38 | where $`x \in \mathbf{R}^n`$ is the *primal* variable, $`y \in \mathbf{R}^m`$ is the *dual* variable, and $`s \in \mathbf{R}^m`$ is the primal *slack* variable. The problem data are $`P\in \mathbf{S}_+^{n}`$, $`A \in \mathbf{R}^{m \times n}`$, $`q \in \mathbf{R}^n`$, and $`b \in \mathbf{R}^m`$. We assume that $`\mathcal K \subseteq \mathbf{R}^m`$ is a nonempty, closed, convex cone with dual cone $`\mathcal{K}^*`$. 39 | 40 | `diffqcp` currently supports CQPs whose cone is the Cartesian product of the zero cone, the positive orthant, second-order cones, positive semidefinite cones, 41 | exponential cones, dual exponential cones, power cones, and dual power cones. 42 | For more information about these cones, see the appendix of our paper. 43 | 44 | # Usage 45 | 46 | `diffqcp` is meant to be used as a CVXPYlayers backend --- it is not designed to be a stand-alone 47 | library. 48 | Nonetheless, here is how it use it. 49 | (Note that while we'll specify different CPU and a GPU configurations, 50 | all modules are CPU and GPU compatible--we just recommend the following 51 | as JAX's `BCSR` arrays do have CUDA backends for their `mv` operations while the `BCOO` arrays do not.) 52 | 53 | For both of the following problems, we'll use the following objects: 54 | 55 | ```python 56 | import cvxpy as cvx 57 | 58 | problem = cvx.Problem(...) 59 | prob_data, _, _ = problem.get_problem_data(cvx.CLARABEL, solver_opts={'use_quad_obj': True}) 60 | scs_cones = cvx.reductions.solvers.conic_solvers.scs_conif.dims_to_solver_dict(prob_data["dims"]) 61 | 62 | x, y, s = ... # canonicalized solutions to `problem` 63 | ``` 64 | 65 | ## Optimal CPU approach 66 | 67 | If computing JVPs and VJPs on a CPU, we recommend using the `equinox.Module`s `HostQCP` and `QCPStructureCPU` as demonstrated in the following pseudo-example. 
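(Here `scs_cones`, `x`, `y`, and `s` are the objects constructed in the snippet above.)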
68 | 69 | ```python 70 | from diffqcp import HostQCP, QCPStructureCPU 71 | from jax.experimental.sparse import BCOO 72 | from jaxtyping import Array 73 | 74 | P: BCOO = ... # Only the upper triangular part of the CQP matrix P 75 | A: BCOO = ... 76 | q: Array = ... 77 | b: Array = ... 78 | 79 | problem_structure = QCPStructureCPU(P, A, scs_cones) 80 | qcp = HostQCP(P, A, q, b, x, y, s, problem_structure) 81 | 82 | # Compute JVPs 83 | 84 | dP: BCOO ... # Same sparsity pattern as `P` 85 | dA: BCOO = ... # Same sparsity pattern as `A` 86 | db: Array = ... 87 | dq: Array = ... 88 | 89 | dx, dy, ds = qcp.jvp(dP, dA, dq, db) 90 | 91 | # Compute VJPs 92 | # `dP`, `dA` will be BCOO arrays, `dq`, `db` just Arrays 93 | dP, dA, dq, db = qcp.vjp(f1(x), f2(y), f3(s)) 94 | ``` 95 | 96 | ## Optimal GPU approach 97 | 98 | If computing JVPs and VJPs on a GPU, we recommend using the `equinox.Module`s `QCPStructureGPU` and `DeviceQCP`. 99 | 100 | ```python 101 | from diffqcp import DeviceQCP, QCPStructureGPU 102 | from jax.experimental.sparse import BCSR 103 | from jaxtyping import Array 104 | 105 | P: BCSR = ... # The entirety of the CQP matrix P 106 | A: BCSR = ... 107 | q: Array = ... 108 | b: Array = ... 109 | 110 | problem_structure = QCPStructureGPU(P, A, scs_cones) 111 | qcp = DeviceQCP(P, A, q, b, x, y, s, problem_structure) 112 | 113 | # Compute JVPs 114 | 115 | dP: BCSR ... # Same sparsity pattern as `P` 116 | dA: BCSR = ... # Same sparsity pattern as `A` 117 | db: Array = ... 118 | dq: Array = ... 119 | 120 | dx, dy, ds = qcp.jvp(dP, dA, dq, db) 121 | 122 | # Compute VJPs 123 | # `dP`, `dA` will be BCSR arrays, `dq`, `db` just Arrays 124 | dP, dA, dq, db = qcp.vjp(f1(x), f2(y), f3(s)) 125 | ``` 126 | 127 | ## Selecting solvers 128 | 129 | As detailed in our paper, the JVPs and VJPs are computed via a linear system solve. 130 | For this solve, `diffqcp` provides three options: 131 | - LSMR via `lineax`, an indirect method that does not materialize the coefficient matrix. 132 | - LU via `lineax`, a direct method that materializes the dense coefficient matrix. 133 | - A direct method via `nvmath-python` / `cuDSS` that materializes the dense coefficient matrix. 134 | 135 | The default solve method is `lineax`'s LSMR, aleit it is not packaged in a released 136 | `lineax` version, so `lineax `must be installed from source (*e.g.*, 137 | `uv add "lineax @ git+https://github.com/patrick-kidger/lineax.git"`). 138 | To switch between the solvers, provide `jax-lsmr`, `jax-lu`, or `nvmath-direct` (as strings) to the optional 139 | `solve_method` parameter of an `AbstractQCP`'s `jvp` and `vjp` methods. 140 | 141 | **Future direction:** 142 | 1. We're currently debugging why the direct solve methods yield exploding gradients. 143 | 2. We're currently working on materializing the coefficient matrix as a sparse array, not a dense matrix. The `lineax` LU method would still require forming the dense matrix, 144 | but the cuDSS backed-solve already accepts sparse arrays in CSR layout. 145 | 146 | # Installation 147 | 148 | | Platform | Instructions | 149 | |-----------------|-----------------------------------------| 150 | | CPU | `pip install diffqcp` | 151 | | NVIDIA GPU | `pip install "diffqcp[gpu]"` | 152 | 153 | Note that `diffqcp[gpu]` is currently packaged with version 12 of CUDA. Moreover, 154 | if your system supports version 13 of CUDA, install the CPU version of `diffqcp` 155 | and then `pip install -U jax[cuda13]`. 
Optionally, if you want access to the cuDSS 156 | solvers, also `pip install "cupy-cuda13x` and `nvmath-python[cu12]`. (Although 157 | note that we're unsure how `nvmath-python[cu12]` will interact with the version 158 | 13s of the other packages.) 159 | 160 | # Citation 161 | 162 | 163 | [arXiv:2508.17522 [math.OC]](https://arxiv.org/abs/2508.17522) 164 | ``` 165 | @misc{healey2025differentiatingquadraticconeprogram, 166 | title={Differentiating Through a Quadratic Cone Program}, 167 | author={Quill Healey and Parth Nobel and Stephen Boyd}, 168 | year={2025}, 169 | eprint={2508.17522}, 170 | archivePrefix={arXiv}, 171 | primaryClass={math.OC}, 172 | url={https://arxiv.org/abs/2508.17522}, 173 | } 174 | ``` 175 | 176 | # Next steps 177 | 178 | `diffqcp` is still in development! WIP features and improvements include: 179 | - Batched problem computations. 180 | - Not forming dense $F$ when using direct solver methods. 181 | - Re-incorporate the LSMR solver when `lineax` has a new release. 182 | - Consider JAX's [`spsolve`](https://docs.jax.dev/en/latest/_autosummary/jax.experimental.sparse.linalg.spsolve.html#jax.experimental.sparse.linalg.spsolve). 183 | - Provide options to linear system solvers. 184 | - Better performance benchmarking / regression testing. 185 | - Migration of tests from our [torch branch](https://github.com/cvxgrp/diffqcp/tree/torch-implementation). 186 | 187 | ## See also 188 | 189 | **Core dependencies** (`diffqcp` makes essential use of the following libraries) 190 | - [Equinox](https://github.com/patrick-kidger/equinox): Neural networks and everything not already in core JAX (via callable `PyTree`s). 191 | - [Lineax](https://github.com/patrick-kidger/lineax): Linear solvers. 192 | 193 | **Related** 194 | - [CVXPYlayers](https://github.com/cvxpy/cvxpylayers): Construct differentiable convex optimization layers using [CVXPY](https://github.com/cvxpy/cvxpy/). (`diffqcp` is a backend for CVXPYlayers.) 195 | - [CuClarabel](https://github.com/oxfordcontrol/Clarabel.jl/tree/CuClarabel): The GPU implemenation of the second-order CQP solver, Clarabel. 196 | - [SCS](https://github.com/cvxgrp/scs): A first-order CQP solver that has an optional GPU-accelerated backend. 197 | - [diffcp](https://github.com/cvxgrp/diffcp): A (Python with C-bindings) library for differentiating through (linear) cone programs. 198 | -------------------------------------------------------------------------------- /experiments/heterogeneous_experiment2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Experiment solving problems on the CPU and computing VJPs on the GPU. 
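Each iteration, Clarabel re-solves the current problem on the CPU, `DeviceQCP.vjp`
(using the `jax-lsmr` solve method) differentiates the squared-distance loss to the
target solution, and the problem data (P, A, q, b) are updated with a gradient step
before being written back to the solver via `solver.update`.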
3 | """ 4 | import time 5 | import os 6 | from dataclasses import dataclass 7 | import numpy as np 8 | import jax 9 | jax.config.update("jax_enable_x64", True) 10 | import jax.numpy as jnp 11 | import jax.numpy.linalg as la 12 | from jaxtyping import Float, Array 13 | import equinox as eqx 14 | from jax.experimental.sparse import BCSR 15 | import clarabel 16 | from scipy.sparse import (spmatrix, sparray, 17 | csc_matrix, csc_array, triu) 18 | import patdb 19 | import matplotlib.pyplot as plt 20 | 21 | from diffqcp import DeviceQCP, QCPStructureGPU 22 | import experiments.cvx_problem_generator as prob_generator 23 | from tests.helpers import QCPProbData, scsr_to_bcsr 24 | 25 | type SP = spmatrix | sparray 26 | type SCSC = csc_matrix | csc_array 27 | 28 | 29 | def compute_loss(target_x, target_y, target_s, x, y, s): 30 | return (0.5 * la.norm(x - target_x)**2 + 0.5 * la.norm(y - target_y)**2 31 | + 0.5 * la.norm(s - target_s)**2) 32 | 33 | 34 | @eqx.filter_jit 35 | @eqx.debug.assert_max_traces(max_traces=1) 36 | def make_step( 37 | qcp: DeviceQCP, 38 | target_x: Float[Array, " n"], 39 | target_y: Float[Array, " m"], 40 | target_s: Float[Array, " m"], 41 | Pdata: Float[Array, "..."], 42 | Adata: Float[Array, "..."], 43 | q: Float[Array, " n"], 44 | b: Float[Array, " m"], 45 | step_size: float 46 | ) -> tuple[Float[Array, ""], Float[Array, "..."], Float[Array, "..."], 47 | Float[Array, " n"], Float[Array, " m"]]: 48 | loss = compute_loss(target_x, target_y, target_s, qcp.x, qcp.y, qcp.s) 49 | dP, dA, dq, db = qcp.vjp(qcp.x - target_x, 50 | qcp.y - target_y, 51 | qcp.s - target_s, 52 | solve_method="jax-lsmr") 53 | new_Pdata = Pdata - step_size * dP.data 54 | new_Adata = Adata - step_size * dA.data 55 | new_q = q - step_size * dq 56 | new_b = b - step_size * db 57 | return (loss, new_Pdata, new_Adata, new_q, new_b) 58 | 59 | 60 | def grad_desc( 61 | Pk: Float[BCSR, "n n"], 62 | Ak: Float[BCSR, "m n"], 63 | qk: Float[Array, " n"], 64 | bk: Float[Array, " m"], 65 | target_x: Float[Array, " n"], 66 | target_y: Float[Array, " m"], 67 | target_s: Float[Array, " m"], 68 | qcp_problem_structure: QCPStructureGPU, 69 | data: QCPProbData, 70 | Pcoo_csc_perm: Float[np.ndarray, "..."], 71 | Acoo_csc_perm: Float[np.ndarray, "..."], 72 | clarabel_solver, 73 | num_iter: int = 100, 74 | step_size = 1e-5 75 | ): 76 | curr_iter = 0 77 | losses = [] 78 | 79 | while curr_iter < num_iter: 80 | 81 | solution = clarabel_solver.solve() 82 | 83 | xk = jnp.array(solution.x) 84 | yk = jnp.array(solution.z) 85 | sk = jnp.array(solution.s) 86 | 87 | qcp = DeviceQCP(Pk, Ak, qk, bk, xk, yk, sk, qcp_problem_structure) 88 | 89 | loss, *new_data = make_step(qcp, target_x, target_y, target_s, 90 | Pk.data, Ak.data, qk, bk, step_size) 91 | losses.append(loss) 92 | 93 | Pk_data, Ak_data, qk, bk = new_data 94 | Pk.data, Ak.data = Pk_data, Ak_data 95 | # need to grap uppper part of P only 96 | data.Pcsc.data = np.asarray(Pk.data)[Pcoo_csc_perm] 97 | data.Pupper_csc = triu(data.Pcsr, format="csc") 98 | data.Acsc.data = np.asarray(Ak.data)[Acoo_csc_perm] 99 | data.q = np.asarray(qk) 100 | data.b = np.asarray(bk) 101 | 102 | solver.update(P=data.Pupper_csc, q=data.q, A=data.Acsc, b=data.b) 103 | 104 | curr_iter += 1 105 | 106 | return losses 107 | 108 | if __name__ == "__main__": 109 | np.random.seed(28) 110 | 111 | # SMALL 112 | m = 20 113 | n = 10 114 | # MEDIUM-ish 115 | # m = 200 116 | # n = 100 117 | # LARGE-ish 118 | # m = 2_000 119 | # n = 1_000 120 | # target_problem = prob_generator.generate_least_squares_eq(m=m, n=n) 121 | # 
target_problem = prob_generator.generate_LS_problem(m=m, n=n) 122 | target_problem = prob_generator.generate_group_lasso_logistic(m=m, n=m) 123 | prob_data_cpu = QCPProbData(target_problem) 124 | 125 | # ensure validity of the following ordering permutations. 126 | np.testing.assert_allclose(prob_data_cpu.Pcoo.data, 127 | prob_data_cpu.Pcsr.data) 128 | 129 | np.testing.assert_allclose(prob_data_cpu.Acoo.data, 130 | prob_data_cpu.Acsr.data) 131 | 132 | P_coo_to_csc_order = np.lexsort((prob_data_cpu.Pcoo.row, 133 | prob_data_cpu.Pupper_coo.col)) 134 | A_coo_to_csc_order = np.lexsort((prob_data_cpu.Acoo.row, 135 | prob_data_cpu.Acoo.col)) 136 | 137 | cones = prob_data_cpu.clarabel_cones 138 | settings = clarabel.DefaultSettings() 139 | settings.verbose = False 140 | 141 | solver = clarabel.DefaultSolver(prob_data_cpu.Pupper_csc, 142 | prob_data_cpu.q, 143 | prob_data_cpu.Acsc, 144 | prob_data_cpu.b, 145 | cones, 146 | settings) 147 | 148 | start_solve = time.perf_counter() 149 | solution = solver.solve() 150 | end_solve = time.perf_counter() 151 | print(f"Clarabel solve took: {end_solve - start_solve} seconds") 152 | 153 | target_x = jnp.array(solution.x) 154 | target_y = jnp.array(solution.z) 155 | target_s = jnp.array(solution.s) 156 | 157 | P = scsr_to_bcsr(prob_data_cpu.Pcsr) 158 | A = scsr_to_bcsr(prob_data_cpu.Acsr) 159 | q = prob_data_cpu.q 160 | b = prob_data_cpu.b 161 | scs_cones = prob_data_cpu.scs_cones 162 | problem_structure = QCPStructureGPU(P, A, scs_cones) 163 | qcp_initial = DeviceQCP(P, A, q, b, 164 | target_x, target_y, target_s, 165 | problem_structure) 166 | fake_target_x = 1e-3 * jnp.arange(jnp.size(q), dtype=q.dtype) 167 | fake_target_y = 1e-3 * jnp.arange(jnp.size(b), dtype=b.dtype) 168 | fake_target_s = 1e-3 * jnp.arange(jnp.size(b), dtype=b.dtype) 169 | 170 | start_time = time.perf_counter() 171 | result = make_step(qcp_initial, fake_target_x, fake_target_y, 172 | fake_target_s, P.data, A.data, q, b, step_size=1e-5) 173 | result[0].block_until_ready() 174 | end_time = time.perf_counter() 175 | # NOTE(quill): well, technically VJP + loss + step computations 176 | print("diffqcp VJP compile + compute took: ", end_time - start_time) 177 | 178 | # --- test compiled solve --- 179 | 180 | start_time = time.perf_counter() 181 | # with jax.profiler.trace("/home/quill/diffqcp/tmp/indirect-trace", create_perfetto_link=True): 182 | result = make_step(qcp_initial, fake_target_x, fake_target_y, 183 | fake_target_s, P.data, A.data, q, b, step_size=1e-5) 184 | result[0].block_until_ready() 185 | end_time = time.perf_counter() 186 | print("Compiled diffqcp VJP compute took: ", end_time - start_time) 187 | 188 | # --- --- 189 | 190 | # initial_problem = prob_generator.generate_least_squares_eq(m=m, n=n) 191 | # initial_problem = prob_generator.generate_LS_problem(m=m, n=n) 192 | initial_problem = prob_generator.generate_group_lasso_logistic(m=m, n=m) 193 | prob_data_cpu = QCPProbData(initial_problem) 194 | 195 | cones = prob_data_cpu.clarabel_cones 196 | settings = clarabel.DefaultSettings() 197 | settings.verbose = False 198 | settings.presolve_enable = False 199 | 200 | solver = clarabel.DefaultSolver(prob_data_cpu.Pupper_csc, 201 | prob_data_cpu.q, 202 | prob_data_cpu.Acsc, 203 | prob_data_cpu.b, 204 | cones, 205 | settings) 206 | 207 | num_iter = 100 208 | 209 | start_time = time.perf_counter() 210 | losses = grad_desc(Pk=scsr_to_bcsr(prob_data_cpu.Pcsr), 211 | Ak=scsr_to_bcsr(prob_data_cpu.Acsr), 212 | qk = jnp.array(prob_data_cpu.q), 213 | bk = jnp.array(prob_data_cpu.b), 214 | 
target_x=target_x, 215 | target_y=target_y, 216 | target_s=target_s, 217 | qcp_problem_structure=problem_structure, 218 | data=prob_data_cpu, 219 | Pcoo_csc_perm=P_coo_to_csc_order, 220 | Acoo_csc_perm=A_coo_to_csc_order, 221 | clarabel_solver=solver, num_iter=num_iter) 222 | losses[0].block_until_ready() 223 | end_time = time.perf_counter() 224 | print(f"The learning loop time was {end_time - start_time} seconds") 225 | print(f"Avg. iteration (solve + VJP) time: {(end_time - start_time) / num_iter}") 226 | losses = jnp.stack(losses) 227 | losses = np.asarray(losses) 228 | 229 | print("starting loss: ", losses[0]) 230 | print("final loss: ", losses[-1]) 231 | 232 | plt.figure(figsize=(8, 6)) 233 | plt.plot(range(num_iter), losses, label="Objective Trajectory") 234 | plt.xlabel("num. iterations") 235 | plt.ylabel("Objective function") 236 | plt.legend() 237 | plt.title(label="diffqcp") 238 | results_dir = os.path.join(os.path.dirname(__file__), "results") 239 | if prob_data_cpu.n > 99: 240 | output_path = os.path.join(results_dir, "diffqcp_logistic_lasso_large.svg") 241 | else: 242 | output_path = os.path.join(results_dir, "diffqcp_logistic_lasso_small.svg") 243 | plt.savefig(output_path, format="svg") 244 | plt.close() -------------------------------------------------------------------------------- /experiments/direct_solve_experiment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Experiment solving problems on the CPU and computing VJPs on the GPU. 3 | """ 4 | import time 5 | import os 6 | import numpy as np 7 | import jax 8 | jax.config.update("jax_enable_x64", True) 9 | import jax.numpy as jnp 10 | import jax.numpy.linalg as la 11 | from jaxtyping import Float, Array 12 | import equinox as eqx 13 | from jax.experimental.sparse import BCSR 14 | import clarabel 15 | from scipy.sparse import (spmatrix, sparray, 16 | csc_matrix, csc_array, triu) 17 | import patdb 18 | import matplotlib.pyplot as plt 19 | 20 | from diffqcp import DeviceQCP, QCPStructureGPU 21 | import experiments.cvx_problem_generator as prob_generator 22 | from tests.helpers import QCPProbData, scsr_to_bcsr 23 | 24 | type SP = spmatrix | sparray 25 | type SCSC = csc_matrix | csc_array 26 | 27 | 28 | def compute_loss(target_x, target_y, target_s, x, y, s): 29 | return (0.5 * la.norm(x - target_x)**2 + 0.5 * la.norm(y - target_y)**2 30 | + 0.5 * la.norm(s - target_s)**2) 31 | 32 | 33 | def _update_data( 34 | dP, dA, dq, db, Pdata, Adata, q, b, step_size 35 | ): 36 | new_Pdata = Pdata - step_size * dP.data 37 | new_Adata = Adata - step_size * dA.data 38 | new_q = q - step_size * dq 39 | new_b = b - step_size * db 40 | return new_Pdata, new_Adata, new_q, new_b 41 | 42 | 43 | @eqx.filter_jit 44 | def make_step( 45 | qcp: DeviceQCP, 46 | target_x: Float[Array, " n"], 47 | target_y: Float[Array, " m"], 48 | target_s: Float[Array, " m"], 49 | Pdata: Float[Array, "..."], 50 | Adata: Float[Array, "..."], 51 | q: Float[Array, " n"], 52 | b: Float[Array, " m"], 53 | step_size: float 54 | ) -> tuple[Float[Array, ""], Float[Array, "..."], Float[Array, "..."], 55 | Float[Array, " n"], Float[Array, " m"]]: 56 | loss = eqx.filter_jit(compute_loss)(target_x, target_y, target_s, qcp.x, qcp.y, qcp.s) 57 | dP, dA, dq, db = qcp.vjp(qcp.x - target_x, 58 | qcp.y - target_y, 59 | qcp.s - target_s, 60 | solve_method="jax-lsmr") 61 | updated_data = eqx.filter_jit(_update_data)(dP, dA, dq, db, Pdata, Adata, q, b, step_size) 62 | return loss, *updated_data 63 | 64 | 65 | def grad_desc( 66 | Pk: 
Float[BCSR, "n n"], 67 | Ak: Float[BCSR, "m n"], 68 | qk: Float[Array, " n"], 69 | bk: Float[Array, " m"], 70 | target_x: Float[Array, " n"], 71 | target_y: Float[Array, " m"], 72 | target_s: Float[Array, " m"], 73 | qcp_problem_structure: QCPStructureGPU, 74 | data: QCPProbData, 75 | Pcoo_csc_perm: Float[np.ndarray, "..."], 76 | Acoo_csc_perm: Float[np.ndarray, "..."], 77 | clarabel_solver, 78 | num_iter: int = 100, 79 | step_size = 1e-5 80 | ): 81 | curr_iter = 0 82 | losses = [] 83 | 84 | while curr_iter < num_iter: 85 | 86 | solution = clarabel_solver.solve() 87 | 88 | xk = jnp.array(solution.x) 89 | yk = jnp.array(solution.z) 90 | sk = jnp.array(solution.s) 91 | 92 | qcp = DeviceQCP(Pk, Ak, qk, bk, xk, yk, sk, qcp_problem_structure) 93 | 94 | loss, *new_data = make_step(qcp, target_x, target_y, target_s, 95 | Pk.data, Ak.data, qk, bk, step_size) 96 | losses.append(loss) 97 | 98 | Pk_data, Ak_data, qk, bk = new_data 99 | Pk.data, Ak.data = Pk_data, Ak_data 100 | # need to grap uppper part of P only 101 | data.Pcsc.data = np.asarray(Pk.data)[Pcoo_csc_perm] 102 | data.Pupper_csc = triu(data.Pcsr, format="csc") 103 | data.Acsc.data = np.asarray(Ak.data)[Acoo_csc_perm] 104 | data.q = np.asarray(qk) 105 | data.b = np.asarray(bk) 106 | 107 | solver.update(P=data.Pupper_csc, q=data.q, A=data.Acsc, b=data.b) 108 | 109 | curr_iter += 1 110 | 111 | return losses 112 | 113 | if __name__ == "__main__": 114 | np.random.seed(28) 115 | 116 | # SMALL 117 | # m = 20 118 | # n = 10 119 | # MEDIUM-ish 120 | m = 200 121 | n = 100 122 | # LARGE-ish 123 | # m = 2_000 124 | # n = 1_000 125 | target_problem = prob_generator.generate_least_squares_eq(m=m, n=n) 126 | # target_problem = prob_generator.generate_LS_problem(m=m, n=n) 127 | prob_data_cpu = QCPProbData(target_problem) 128 | 129 | # ensure validity of the following ordering permutations. 
130 | np.testing.assert_allclose(prob_data_cpu.Pcoo.data, 131 | prob_data_cpu.Pcsr.data) 132 | 133 | np.testing.assert_allclose(prob_data_cpu.Acoo.data, 134 | prob_data_cpu.Acsr.data) 135 | 136 | P_coo_to_csc_order = np.lexsort((prob_data_cpu.Pcoo.row, 137 | prob_data_cpu.Pupper_coo.col)) 138 | A_coo_to_csc_order = np.lexsort((prob_data_cpu.Acoo.row, 139 | prob_data_cpu.Acoo.col)) 140 | 141 | cones = prob_data_cpu.clarabel_cones 142 | settings = clarabel.DefaultSettings() 143 | settings.verbose = False 144 | # settings.presolve_enable = False 145 | 146 | solver = clarabel.DefaultSolver(prob_data_cpu.Pupper_csc, 147 | prob_data_cpu.q, 148 | prob_data_cpu.Acsc, 149 | prob_data_cpu.b, 150 | cones, 151 | settings) 152 | 153 | start_solve = time.perf_counter() 154 | solution = solver.solve() 155 | end_solve = time.perf_counter() 156 | print(f"Clarabel solve took: {end_solve - start_solve} seconds") 157 | 158 | target_x = jnp.array(solution.x) 159 | target_y = jnp.array(solution.z) 160 | target_s = jnp.array(solution.s) 161 | 162 | P = scsr_to_bcsr(prob_data_cpu.Pcsr) 163 | A = scsr_to_bcsr(prob_data_cpu.Acsr) 164 | q = prob_data_cpu.q 165 | b = prob_data_cpu.b 166 | scs_cones = prob_data_cpu.scs_cones 167 | problem_structure = QCPStructureGPU(P, A, scs_cones) 168 | qcp_initial = DeviceQCP(P, A, q, b, 169 | target_x, target_y, target_s, 170 | problem_structure) 171 | fake_target_x = 1e-3 * jnp.arange(jnp.size(q), dtype=q.dtype) 172 | fake_target_y = 1e-3 * jnp.arange(jnp.size(b), dtype=b.dtype) 173 | fake_target_s = 1e-3 * jnp.arange(jnp.size(b), dtype=b.dtype) 174 | 175 | start_time = time.perf_counter() 176 | result = make_step(qcp_initial, fake_target_x, fake_target_y, 177 | fake_target_s, P.data, A.data, q, b, step_size=1e-5) 178 | result[0].block_until_ready() 179 | end_time = time.perf_counter() 180 | # NOTE(quill): well, technically VJP + loss + step computations 181 | print("diffqcp VJP compile + compute took: ", end_time - start_time) 182 | 183 | # --- test compiled solve --- 184 | 185 | start_time = time.perf_counter() 186 | # with jax.profiler.trace("/home/quill/diffqcp/tmp/jax-trace", create_perfetto_link=True): 187 | result = make_step(qcp_initial, fake_target_x, fake_target_y, 188 | fake_target_s, P.data, A.data, q, b, step_size=1e-5) 189 | result[0].block_until_ready() 190 | end_time = time.perf_counter() 191 | print("Compiled diffqcp VJP compute took: ", end_time - start_time) 192 | 193 | # --- --- 194 | 195 | initial_problem = prob_generator.generate_least_squares_eq(m=m, n=n) 196 | # initial_problem = prob_generator.generate_LS_problem(m=m, n=n) 197 | prob_data_cpu = QCPProbData(initial_problem) 198 | 199 | cones = prob_data_cpu.clarabel_cones 200 | settings = clarabel.DefaultSettings() 201 | settings.verbose = False 202 | settings.presolve_enable = False 203 | 204 | solver = clarabel.DefaultSolver(prob_data_cpu.Pupper_csc, 205 | prob_data_cpu.q, 206 | prob_data_cpu.Acsc, 207 | prob_data_cpu.b, 208 | cones, 209 | settings) 210 | 211 | num_iter = 100 212 | 213 | start_time = time.perf_counter() 214 | losses = grad_desc(Pk=scsr_to_bcsr(prob_data_cpu.Pcsr), 215 | Ak=scsr_to_bcsr(prob_data_cpu.Acsr), 216 | qk = jnp.array(prob_data_cpu.q), 217 | bk = jnp.array(prob_data_cpu.b), 218 | target_x=target_x, 219 | target_y=target_y, 220 | target_s=target_s, 221 | qcp_problem_structure=problem_structure, 222 | data=prob_data_cpu, 223 | Pcoo_csc_perm=P_coo_to_csc_order, 224 | Acoo_csc_perm=A_coo_to_csc_order, 225 | clarabel_solver=solver, num_iter=num_iter) 226 | 
losses[0].block_until_ready() 227 | end_time = time.perf_counter() 228 | print(f"The learning loop time was {end_time - start_time} seconds") 229 | print(f"Avg. iteration (solve + VJP) time: {(end_time - start_time) / num_iter}") 230 | losses = jnp.stack(losses) 231 | losses = np.asarray(losses) 232 | 233 | print("starting loss: ", losses[0]) 234 | print("final loss: ", losses[-1]) 235 | 236 | plt.figure(figsize=(8, 6)) 237 | plt.plot(range(num_iter), losses, label="Objective Trajectory") 238 | plt.xlabel("num. iterations") 239 | plt.ylabel("Objective function") 240 | plt.legend() 241 | plt.title(label="diffqcp") 242 | results_dir = os.path.join(os.path.dirname(__file__), "results") 243 | if prob_data_cpu.n > 99: 244 | output_path = os.path.join(results_dir, "lsmr_dense_ls_100_iterates_direct_11_10.svg") 245 | else: 246 | output_path = os.path.join(results_dir, "dsolve_probability_small.svg") 247 | plt.savefig(output_path, format="svg") 248 | plt.close() -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2017 Steven Diamond 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /diffqcp/problem_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | __all__ = [ 4 | "QCPStructureCPU", 5 | "QCPStructureGPU", 6 | "QCPStructureLayers", 7 | "ObjMatrixGPU", 8 | "ObjMatrixCPU" 9 | ] 10 | 11 | from abc import abstractmethod 12 | from typing import TYPE_CHECKING 13 | 14 | from jax import ShapeDtypeStruct 15 | import jax.numpy as jnp 16 | from jax.experimental.sparse import BCOO, BCSR 17 | import equinox as eqx 18 | import lineax as lx 19 | from lineax import AbstractLinearOperator 20 | from jaxtyping import Float, Integer, Bool, Array 21 | 22 | if TYPE_CHECKING: 23 | from cvxpy.reductions.dcp2cone.cone_matrix_stuffing import ParamConeProg 24 | 25 | from diffqcp.cones.canonical import ProductConeProjector 26 | from diffqcp._helpers import _coo_to_csr_transpose_map, _TransposeCSRInfo 27 | 28 | class QCPStructure(eqx.Module): 29 | 30 | # The following are needed for `_form_atoms` in `AbstractQCP`. 31 | n: eqx.AbstractVar[int] 32 | m: eqx.AbstractVar[int] 33 | N: eqx.AbstractVar[int] 34 | cone_projector: eqx.AbstractVar[ProductConeProjector] 35 | 36 | @abstractmethod 37 | def obj_matrix_init(self): 38 | pass 39 | 40 | @abstractmethod 41 | def constr_matrix_init(self): 42 | pass 43 | 44 | 45 | class QCPStructureCPU(QCPStructure): 46 | """ 47 | `P` is assumed to be the upper triangular part of the matrix in the quadratic form. 48 | """ 49 | 50 | n: int 51 | m: int 52 | N: int 53 | cone_projector: ProductConeProjector 54 | is_batched: bool 55 | 56 | P_nonzero_rows: Integer[Array, "..."] 57 | P_nonzero_cols: Integer[Array, "..."] 58 | P_diag_mask: Bool[Array, "..."] 59 | P_diag_indices: Integer[Array, "..."] 60 | 61 | A_nonzero_rows: Integer[Array, "..."] 62 | A_nonzero_cols: Integer[Array, "..."] 63 | 64 | def __init__( 65 | self, 66 | P: Float[BCOO, "*batch n n"], 67 | A: Float[BCOO, "*batch n n"], 68 | cone_dims: dict[str, int | list[int] | list[float]], 69 | onto_dual: bool = True 70 | ): 71 | 72 | # NOTE(quill): checks on `cone_dims` done in `ProductConeProjector.__init__` 73 | self.cone_projector = ProductConeProjector(cone_dims, onto_dual=onto_dual) 74 | 75 | if not isinstance(P, BCOO): 76 | raise ValueError("The objective matrix `P` must be a `BCOO` JAX matrix," 77 | + f" but the provided `P` is a {type(P)}.") 78 | 79 | if P.n_batch == 0: 80 | self.is_batched = False 81 | self.obj_matrix_init(P) 82 | elif P.n_batch == 1: 83 | self.is_batched = True 84 | # Extract information from first matrix in the batch. 85 | # Strict requirement is that all matrices in the batch share 86 | # the same sparsity structure (holds via DPP, also maybe required by JAX?) 87 | self.obj_matrix_init(P[0]) 88 | else: 89 | raise ValueError("The objective matrix `P` must have at most one batch dimension," 90 | + f" but the provided BCOO matrix has {P.n_batch} dimensions.") 91 | 92 | if not isinstance(A, BCOO): 93 | raise ValueError("The objective matrix `A` must be a `BCOO` JAX matrix," 94 | + f" but the provided `A` is a {type(A)}.") 95 | 96 | # NOTE(quill): could theoretically allow mismatch and broadcast 97 | # (Just to keep in mind for the future; not needed now.) 98 | if A.n_batch != P.n_batch: 99 | raise ValueError(f"The objective matrix `P` has {P.n_batch} dimensions" 100 | + f" while the constraint matrix `A` has {A.n_batch}" 101 | + " dimensions. 
The batch dimensionality of `P` and `A`" 102 | + " must match.") 103 | if self.is_batched: 104 | self.constr_matrix_init(A[0]) 105 | else: 106 | self.constr_matrix_init(A) 107 | 108 | self.N = self.n + self.m + 1 109 | 110 | def obj_matrix_init(self, P: Float[BCOO, "n n"]): 111 | # TODO(quill): checks on P being upper triangular. 112 | # (Might as well do since this structure is formed once.) 113 | self.n = jnp.shape(P)[0] 114 | self.P_nonzero_rows = P.indices[:, 0] 115 | self.P_nonzero_cols = P.indices[:, 1] 116 | self.P_diag_mask = P.indices[:, 0] == P.indices[:, 1] 117 | self.P_diag_indices = P.indices[:, 0][self.P_diag_mask] 118 | 119 | def constr_matrix_init(self, A: Float[BCOO, "m n"]): 120 | self.m = jnp.shape(A)[0] 121 | self.A_nonzero_rows = A.indices[:, 0] 122 | self.A_nonzero_cols = A.indices[:, 1] 123 | 124 | def form_obj(self, P_like: Float[BCOO, "n n"]) -> ObjMatrixCPU: 125 | diag_values = P_like.data[self.P_diag_mask] 126 | diag = jnp.zeros(self.n) 127 | diag = diag.at[self.P_diag_indices].set(diag_values) 128 | return ObjMatrixCPU(P_like, P_like.T, diag) 129 | 130 | 131 | class QCPStructureGPU(QCPStructure): 132 | """ 133 | P is assumed to be the full matrix 134 | """ 135 | 136 | n: int 137 | m: int 138 | N: int 139 | cone_projector: ProductConeProjector 140 | is_batched: bool 141 | 142 | P_csr_indices: Integer[Array, "..."] 143 | P_csr_indptr: Integer[Array, "..."] 144 | P_nonzero_rows: Integer[Array, "..."] 145 | P_nonzero_cols: Integer[Array, "..."] 146 | 147 | A_csr_indices: Integer[Array, "..."] 148 | A_csr_indptr: Integer[Array, "..."] 149 | A_nonzero_rows: Integer[Array, "..."] 150 | A_nonzero_cols: Integer[Array, "..."] 151 | A_transpose_info: _TransposeCSRInfo 152 | 153 | def __init__( 154 | self, 155 | P: Float[BCSR, "*batch n n"], 156 | A: Float[BCSR, "*batch m n"], 157 | cone_dims: dict[str, int | list[int] | list[float]], 158 | onto_dual: bool = True 159 | ): 160 | 161 | # NOTE(quill): checks on `cone_dims` done in `ProductConeProjector.__init__` 162 | self.cone_projector = ProductConeProjector(cone_dims, onto_dual=onto_dual) 163 | 164 | if not isinstance(P, BCSR): 165 | raise ValueError("The objective matrix `P` must be a `BCSR` JAX matrix," 166 | + f" but the provided `P` is a {type(P)}.") 167 | # check if batched 168 | if P.n_batch == 0: 169 | self.is_batched = False 170 | self.obj_matrix_init(P) 171 | elif P.n_batch == 1: 172 | self.is_batched = True 173 | # NOTE(quill): see note in `QCPStructureCPU` 174 | self.obj_matrix_init(P[0]) 175 | else: 176 | raise ValueError("The objective matrix `P` must have at most one batch dimension," 177 | + f" but the provided BCSR matrix has {P.n_batch} dimensions.") 178 | 179 | if not isinstance(A, BCSR): 180 | raise ValueError("The objective matrix `A` must be a `BCSR` JAX matrix," 181 | + f" but the provided `A` is a {type(A)}.") 182 | 183 | # NOTE(quill): see note in `QCPStructureCPU` 184 | if A.n_batch != P.n_batch: 185 | raise ValueError(f"The objective matrix `P` has {P.n_batch} dimensions" 186 | + f" while the constraint matrix `A` has {A.n_batch}" 187 | + " dimensions. 
The batch dimensionality of `P` and `A`" 188 | + " must match.") 189 | 190 | if self.is_batched: 191 | self.constr_matrix_init(A[0]) 192 | else: 193 | self.constr_matrix_init(A) 194 | 195 | self.N = self.n + self.m + 1 196 | 197 | def obj_matrix_init(self, P: Float[BCSR, "n n"]): 198 | self.n = jnp.shape(P)[0] 199 | P_coo = P.to_bcoo() 200 | # NOTE(quill): the following assumption is needed for the following 201 | # manipulation to result in accurate metadata. 202 | # If this error occurs more frequently than not, then it will probably 203 | # be worth canonicalizing the data matrices by default. 204 | # NOTE(quill): must use `allclose` since `!=` compares if same data in memory. 205 | if not jnp.allclose(P_coo.data, P.data): 206 | raise ValueError("The ordering of the data in `P_coo` and `P`" 207 | + " (a BCSR matrix) does not match." 208 | + " Please try to coerce `P` into canonical form.") 209 | 210 | self.P_csr_indices = P.indices 211 | self.P_csr_indptr = P.indptr 212 | 213 | self.P_nonzero_rows = P_coo.indices[:, 0] 214 | self.P_nonzero_cols = P_coo.indices[:, 1] 215 | 216 | def constr_matrix_init(self, A: Float[BCSR, "m n"]): 217 | self.m = jnp.shape(A)[0] 218 | A_coo = A.to_bcoo() 219 | # NOTE(quill): see note in `obj_matrix_init` 220 | if not jnp.allclose(A_coo.data, A.data): 221 | raise ValueError("The ordering of the data in `A_coo` and `A`" 222 | + " (a BCSR matrix) does not match." 223 | + " Please try to coerce `A` into canonical form.") 224 | 225 | self.A_csr_indices = A.indices 226 | self.A_csr_indptr = A.indptr 227 | 228 | self.A_nonzero_rows = A_coo.indices[:, 0] 229 | self.A_nonzero_cols = A_coo.indices[:, 1] 230 | 231 | # Create metadata for cheap transposes 232 | self.A_transpose_info = _coo_to_csr_transpose_map(A_coo) 233 | 234 | def form_A_transpose(self, A_like: Float[BCSR, "m n"]) -> Float[BCSR, "n m"]: 235 | transposed_data = A_like.data[self.A_transpose_info.sorting_perm] 236 | return BCSR((transposed_data, 237 | self.A_transpose_info.indices, 238 | self.A_transpose_info.indptr), 239 | shape=(self.n, self.m)) 240 | 241 | 242 | class QCPStructureLayers(QCPStructure): 243 | """Meant to be used with CVXPYlayers.""" 244 | 245 | n: int 246 | m: int 247 | N: int 248 | cone_projector: ProductConeProjector 249 | is_batched: bool 250 | 251 | def __init__( 252 | self, 253 | prob: ParamConeProg, 254 | cone_dims: dict[str, int | list[int] | list[float]], 255 | onto_dual: bool = True 256 | ): 257 | 258 | self.cone_projector = ProductConeProjector(cone_dims, onto_dual=onto_dual) 259 | 260 | # # Now we need to obtain 261 | # constraint_structure = 262 | 263 | 264 | type ObjMatrix = ObjMatrixCPU | ObjMatrixGPU 265 | 266 | 267 | class ObjMatrixCPU(AbstractLinearOperator): 268 | P: Float[BCOO, "n n"] 269 | PT: Float[BCOO, "n n"] 270 | diag: Float[BCOO, " n"] 271 | in_struc: ShapeDtypeStruct 272 | 273 | def __init__( 274 | self, 275 | P: Float[BCOO, "n n"], 276 | PT: Float[BCOO, "n n"], 277 | diag: Float[BCOO, " n"] 278 | ): 279 | self.P, self.PT, self.diag = P, PT, diag 280 | n = jnp.shape(P)[0] 281 | self.in_struc = ShapeDtypeStruct(shape=(n,), 282 | dtype=P.data.dtype) 283 | 284 | def mv(self, vector): 285 | return self.P @ vector + self.PT @ vector - self.diag*vector 286 | 287 | def transpose(self): 288 | return self 289 | 290 | def as_matrix(self): 291 | raise NotImplementedError(f"{self.__class__.__name__}'s `as_matrix` method is" 292 | + " not yet implemented.") 293 | 294 | def in_structure(self): 295 | pass 296 | 297 | def out_structure(self): 298 | return 
self.in_structure() 299 | 300 | class ObjMatrixGPU(AbstractLinearOperator): 301 | P: Float[BCSR, "n n"] 302 | in_struc: ShapeDtypeStruct 303 | 304 | def __init__( 305 | self, 306 | P: Float[BCSR, "n n"], 307 | ): 308 | self.P = P 309 | n = jnp.shape(P)[0] 310 | self.in_struc = ShapeDtypeStruct(shape=(n,), 311 | dtype=P.data.dtype) 312 | 313 | def mv(self, vector): 314 | return self.P @ vector 315 | 316 | def transpose(self): 317 | return self 318 | 319 | def as_matrix(self): 320 | raise NotImplementedError(f"{self.__class__.__name__}'s `as_matrix` method is" 321 | + " not yet implemented.") 322 | 323 | def in_structure(self): 324 | pass 325 | 326 | def out_structure(self): 327 | return self.in_structure() 328 | 329 | @lx.is_symmetric.register(ObjMatrixCPU) 330 | def _(op): 331 | return True 332 | 333 | @lx.is_symmetric.register(ObjMatrixGPU) 334 | def _(op): 335 | return True -------------------------------------------------------------------------------- /diffqcp/cones/pow.py: -------------------------------------------------------------------------------- 1 | """Subroutines for projecting onto power cone and computing JVPs and VJPs with the derivative of the projection. 2 | """ 3 | from typing import TYPE_CHECKING 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import equinox as eqx 8 | import lineax as lx 9 | from jaxtyping import Array, Float, Bool 10 | 11 | from .abstract_projector import AbstractConeProjector 12 | 13 | if jax.config.jax_enable_x64: 14 | TOL = 1e-12 15 | else: 16 | TOL = 1e-6 17 | 18 | MAX_ITER = 20 19 | 20 | def _pow_calc_xi( 21 | ri: Float[Array, ""], 22 | x: Float[Array, ""], 23 | abs_z: Float[Array, ""], 24 | alpha: Float[Array, ""] 25 | ) -> Float[Array, ""]: 26 | """x_i from eq 4. from Hien paper""" 27 | x = 0.5 * (x + jnp.sqrt(x*x + 4. * alpha * (abs_z - ri) * ri)) 28 | return jnp.maximum(x, TOL) 29 | 30 | 31 | def _gi( 32 | ri: Float[Array, ""], 33 | xi: Float[Array, ""], 34 | abs_z: Float[Array, ""], 35 | alpha: Float[Array, ""] 36 | ) -> Float[Array, ""]: 37 | """gi from diffqcp paper.""" 38 | return 2. * _pow_calc_xi(ri, xi, abs_z, alpha) - xi 39 | 40 | 41 | def _pow_calc_f( 42 | ri: Float[Array, ""], 43 | xi: Float[Array, ""], 44 | yi: Float[Array, ""], 45 | alpha: Float[Array, ""] 46 | ) -> Float[Array, ""]: 47 | """Phi from Hien paper.""" 48 | return xi**alpha * yi**(1.-alpha) - ri 49 | 50 | 51 | def _pow_calc_dxi_dr( 52 | ri: Float[Array, ""], 53 | xi: Float[Array, ""], 54 | x: Float[Array, ""], 55 | abs_z: Float[Array, ""], 56 | alpha: Float[Array, ""] 57 | ) -> Float[Array, ""]: 58 | """ 59 | `xi` is an iterate toward the projection of `x` or `y` in `(x, y, z)` toward 60 | the first element or second element, respectively, in `proj(v)`. 
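    As a reference for the expression returned below: ignoring the TOL clamp in
    `_pow_calc_xi`, the iterate satisfies x_i(r) = (x + sqrt(x**2 + 4*alpha*(|z| - r)*r)) / 2.
    Differentiating with respect to r and using sqrt(...) = 2*x_i - x gives

        dx_i/dr = alpha * (|z| - 2*r) / (2*x_i - x),

    which is exactly the quantity computed here.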
61 | """ 62 | return alpha * (abs_z - 2.0 * ri) / (2.0 * xi - x) 63 | 64 | 65 | def _pow_calc_fp( 66 | xi: Float[Array, ""], 67 | yi: Float[Array, ""], 68 | dxidri: Float[Array, ""], 69 | dyidri: Float[Array, ""], 70 | alpha: Float[Array, ""] 71 | ) -> Float[Array, ""]: 72 | alphac = 1 - alpha 73 | # return xi**alpha + yi**alphac * (alpha * dxidri / xi + alphac * dyidri / yi) - 1 74 | return (xi**alpha) * (yi**alphac) * (alpha * dxidri / xi + alphac * dyidri / yi) - 1.0 75 | 76 | 77 | def _in_cone( 78 | x: Float[Array, ""], 79 | y: Float[Array, ""], 80 | abs_z: Float[Array, ""], 81 | alpha: Float[Array, ""] 82 | ) -> bool: 83 | return jnp.logical_and(x >= 0, 84 | jnp.logical_and(y >= 0, 85 | TOL + x**alpha * y**(1-alpha) >= abs_z)) 86 | 87 | 88 | def _in_polar_cone( 89 | u: Float[Array, ""], 90 | v: Float[Array, ""], 91 | abs_w: Float[Array, ""], 92 | alpha: Float[Array, ""] 93 | ) -> bool: 94 | return jnp.logical_and(u <= 0, 95 | jnp.logical_and(v <= 0, 96 | TOL + jnp.pow(-u, alpha) * jnp.pow(-v, 1. - alpha) >= 97 | abs_w * alpha**alpha + jnp.pow(1. - alpha, 1. - alpha))) 98 | 99 | 100 | def _proj_dproj( 101 | v: Float[Array, " 3"], 102 | alpha: Float[Array, ""] 103 | ) -> tuple[Float[Array, " 3"], Float[Array, "3 3"]]: 104 | x, y, z = v 105 | abs_z = jnp.abs(z) 106 | 107 | def identity_case(): 108 | return ( 109 | v, jnp.eye(3, dtype=x.dtype) 110 | ) 111 | 112 | def zero_case(): 113 | return ( 114 | jnp.zeros_like(v), jnp.zeros((3, 3), dtype=x.dtype) 115 | ) 116 | 117 | def z_zero_case(): 118 | J = jnp.zeros((3, 3), dtype=v.dtype) 119 | J = J.at[0, 0].set(0.5 * (jnp.sign(x) + 1.0)) 120 | J = J.at[1, 1].set(0.5 * (jnp.sign(y) + 1.0)) 121 | 122 | def case1(): # (x > 0 and y0 < 0 and a_device > 0.5) or (y0 > 0 and x < 0 and alpha < 0.5) 123 | return 1.0 124 | 125 | def case2(): # (x > 0 and y < 0 and alpha < 0.5) or (y > 0 and x < 0 and alpha > 0.5) 126 | return 0.0 127 | 128 | def case3(): # a_device == 0.5 and x0 > 0 and y0 < 0 129 | return x / (2 * jnp.abs(y) + x) 130 | 131 | def case4(): 132 | return y / (2 * jnp.abs(x) + y) 133 | 134 | cond1 = ((x > 0) & (y < 0) & (alpha > 0.5)) | ((y > 0) & (x < 0) & (alpha < 0.5)) 135 | cond2 = ((x > 0) & (y < 0) & (alpha < 0.5)) | ((y > 0) & (x < 0) & (alpha > 0.5)) 136 | cond3 = (alpha == 0.5) & (x > 0) & (y < 0) 137 | 138 | J22 = jax.lax.cond( 139 | cond1, case1, 140 | lambda: jax.lax.cond( 141 | cond2, case2, 142 | lambda: jax.lax.cond( 143 | cond3, case3, 144 | case4 145 | ) 146 | ) 147 | ) 148 | 149 | J = J.at[2, 2].set(J22) 150 | proj_v = jnp.array([jnp.maximum(x, 0), jnp.maximum(y, 0), 0.0], dtype=v.dtype) 151 | return proj_v, J 152 | 153 | 154 | def solve_case(): 155 | 156 | def _solve_while_body(loop_state): 157 | # NOTE(quill): we're purposefully using both `i` and `j`. 158 | # The former (which is in the function names) is denoting 159 | # an element in a vector while the latter is being used to denote 160 | # an interation count. 
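            # Rough sketch of one sweep of this loop (see the calls below): from the
            # current root estimate r_j in [0, |z|], recover x_j and y_j via
            # `_pow_calc_xi`, evaluate the residual f(r) = x(r)**alpha * y(r)**(1 - alpha) - r
            # and its derivative, take a Newton step on r, and clip the update back
            # into [0, |z|]. The loop terminates once |f| <= TOL or after MAX_ITER sweeps.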
161 | loop_state["xj"] = _pow_calc_xi(loop_state["rj"], x, abs_z, alpha) 162 | loop_state["yj"] = _pow_calc_xi(loop_state["rj"], y, abs_z, 1.0 - alpha) 163 | fj = _pow_calc_f(loop_state["rj"], loop_state["xj"], loop_state["yj"], alpha) 164 | 165 | dxdr = _pow_calc_dxi_dr(loop_state["rj"], loop_state["xj"], x, abs_z, alpha) 166 | dydr = _pow_calc_dxi_dr(loop_state["rj"], loop_state["yj"], y, abs_z, 1.-alpha) 167 | fp = _pow_calc_fp(loop_state["xj"], loop_state["yj"], dxdr, dydr, alpha) 168 | 169 | loop_state["rj"] = jnp.maximum(loop_state["rj"] - fj / fp, 0) 170 | loop_state["rj"] = jnp.minimum(loop_state["rj"], abs_z) 171 | 172 | loop_state["itn"] += 1 173 | loop_state["istop"] = jax.lax.select(loop_state["itn"] > MAX_ITER, 2, loop_state["istop"]) 174 | loop_state["istop"] = jax.lax.select(jnp.abs(fj) <= TOL, 1, loop_state["istop"]) 175 | 176 | return loop_state 177 | 178 | def condfun(loop_state): 179 | return loop_state["istop"] == 0 180 | 181 | loop_state = { 182 | "xj": 0, 183 | "yj": 0, 184 | "rj": abs_z / 2, 185 | "istop": 0, 186 | "itn": 0 187 | } 188 | 189 | loop_state = jax.lax.while_loop(condfun, _solve_while_body, loop_state) 190 | 191 | r_star = loop_state["rj"] 192 | x_star = loop_state["xj"] 193 | y_star = loop_state["yj"] 194 | z_star = jax.lax.cond(z < 0, lambda: -r_star, lambda: r_star) 195 | proj_v = jnp.array([x_star, y_star, z_star]) 196 | a = alpha 197 | aa = a * a 198 | ac = 1 - alpha 199 | acac = ac * ac 200 | 201 | two_r = 2 * r_star 202 | sign_z = jnp.sign(z) 203 | gx = _gi(r_star, x, abs_z, a) 204 | gy = _gi(r_star, y, abs_z, a) 205 | frac_x = (a * x) / gx 206 | frac_y = (ac * y) / gy 207 | T = - (frac_x + frac_y) 208 | L = 2 * abs_z - two_r 209 | L = L / (abs_z + (abs_z - two_r) * (frac_x + frac_y)) 210 | 211 | gxgy = gx * gy 212 | rL = r_star * L 213 | J = jnp.zeros((3, 3), dtype=x.dtype) 214 | J = J.at[0, 0].set(0.5 + x / (2 * gx) + (aa * (abs_z - two_r) * rL) / (gx * gx)) 215 | J = J.at[1, 1].set(0.5 + y / (2 * gy) + (acac * (abs_z - two_r) * rL) / (gy * gy)) 216 | J = J.at[2, 2].set(r_star / abs_z + (r_star / abs_z) * T * L) 217 | J = J.at[0, 1].set(rL * acac * (abs_z - two_r) / gxgy) 218 | J = J.at[1, 0].set(J[0, 1]) 219 | J = J.at[0, 2].set(sign_z * a * rL / gx) 220 | J = J.at[2, 0].set(J[0, 2]) 221 | J = J.at[1, 2].set(sign_z * ac * rL / gy) 222 | J = J.at[2, 1].set(J[1, 2]) 223 | 224 | return proj_v, J 225 | 226 | return jax.lax.cond(_in_cone(x, y, abs_z, alpha), 227 | identity_case, 228 | lambda: jax.lax.cond( 229 | _in_polar_cone(x, y, abs_z, alpha), 230 | zero_case, 231 | lambda: jax.lax.cond( 232 | abs_z <= TOL, 233 | z_zero_case, 234 | solve_case 235 | ))) 236 | 237 | 238 | def _pow_cone_jacobian_mv( 239 | dx: Float[Array, "num_cones 3"], 240 | jacobians: Float[Array, "num_cones 3 3"], 241 | is_dual: Bool[Array, " num_cones"], 242 | num_cones: int, 243 | ): 244 | # num cones could be 1. 
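    # For rows flagged in `is_dual`, the stored Jacobian was evaluated by
    # `PowerConeProjector.proj_dproj` at the sign-flipped point, so by Moreau's
    # decomposition (Pi_{K*}(v) = v + Pi_K(-v), hence DPi_{K*}(v) = I - DPi_K(-v))
    # applying the dual-cone derivative to dx reduces to dx - J dx; primal rows
    # use J dx directly.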
245 | dx_batch = jnp.reshape(dx, (num_cones, 3)) 246 | Jdx = eqx.filter_vmap(lambda jac, y: jac @ y, 247 | in_axes=(0, 0), out_axes=0)(jacobians, dx_batch) 248 | mv_dual = dx_batch - Jdx 249 | mv = jnp.where(is_dual[:, None], mv_dual, Jdx) 250 | return jnp.ravel(mv) 251 | 252 | 253 | class _PowerConeJacobianOperator(lx.AbstractLinearOperator): 254 | 255 | jacobians: Float[Array, "*num_batches num_cones 3 3"] 256 | is_dual: Bool[Array, " num_cones"] 257 | num_cones: int = eqx.field(static=True) 258 | 259 | def __init__( 260 | self, 261 | jacobians: Float[Array, "*num_batches num_cones 3 3"], 262 | is_dual: Bool[Array, " num_cones"], 263 | ): 264 | self.jacobians = jacobians 265 | self.is_dual = is_dual 266 | self.num_cones = jnp.size(is_dual) 267 | ndim = jnp.ndim(jacobians) 268 | if ndim not in [3, 4]: 269 | raise ValueError("The `jacobians` argument provided to the `_PowerConeJacobianOperator` " 270 | f"is {ndim}D, but it must be 3D or 4D.") 271 | 272 | def mv(self, dx: Float[Array, "*batch num_cones*3"]): 273 | ndim = jnp.ndim(dx) 274 | if ndim == 1: 275 | if jnp.ndim(self.jacobians) == 4: 276 | raise ValueError("Batched Power cone Jacobians cannot be applied to a 1D input.") 277 | 278 | return _pow_cone_jacobian_mv(dx, self.jacobians, self.is_dual, self.num_cones) 279 | elif ndim == 2: 280 | return eqx.filter_vmap(_pow_cone_jacobian_mv, 281 | in_axes=(0, 0, None, None), 282 | out_axes=0)(dx, self.jacobians, self.is_dual, self.num_cones) 283 | else: 284 | raise ValueError("The `_PowerConeJacobianOperator` can only be applied to 1D or 2D inputs " 285 | f"but the provided vector is {ndim}D.") 286 | 287 | def as_matrix(self): 288 | raise NotImplementedError("Power Cone Jacobian `as_matrix` not implemented.") 289 | 290 | def transpose(self): 291 | return self 292 | 293 | def in_structure(self): 294 | ndim = jnp.ndim(self.jacobians) 295 | shape = jnp.shape(self.jacobians) 296 | dtype = self.jacobians.dtype 297 | 298 | if ndim == 3: 299 | # non-batched case 300 | return jax.ShapeDtypeStruct(shape=(shape[0] * 3,), 301 | dtype=dtype) 302 | elif ndim == 4: 303 | # batched case 304 | return jax.ShapeDtypeStruct(shape=(shape[0], shape[1] * 3), 305 | dtype=dtype) 306 | 307 | def out_structure(self): 308 | return self.in_structure() 309 | 310 | @lx.is_symmetric.register(_PowerConeJacobianOperator) 311 | def _(op): 312 | return True 313 | 314 | class PowerConeProjector(AbstractConeProjector): 315 | 316 | # NOTE(quill): while similar, this implementation was a bit more challenging than 317 | # the exponential cone projector implementation as the `cone_dims` dictionary 318 | # returned by CVXPY has different keys for the exponential cone and its dual, whereas 319 | # primal vs dual power cone is encoded within the list of `alphas`. 
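    # For example (illustrative values only): `alphas=[0.3, -0.7]` describes one
    # primal power cone with alpha = 0.3 and one dual power cone with alpha = 0.7;
    # with `onto_dual=True` the roles are swapped. The sign/`is_dual` bookkeeping
    # in `__init__` below implements exactly this convention.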
320 | 321 | alphas: Float[Array, " num_cones"] 322 | num_cones: int = eqx.field(static=True) 323 | alphas_abs: Float[Array, " num_cones"] 324 | signs: Float[Array, " num_cones"] 325 | is_dual: Bool[Array, " num_cones"] 326 | 327 | def __init__(self, alphas: list[float], onto_dual: bool): 328 | 329 | self.alphas = jnp.array(alphas) 330 | self.num_cones = jnp.size(self.alphas) 331 | self.is_dual = self.alphas < 0 332 | self.signs = jnp.where(self.alphas < 0, -1.0, 1.0) 333 | if onto_dual: 334 | self.signs = -1.0 * self.signs 335 | self.is_dual = jnp.logical_not(self.is_dual) 336 | self.alphas_abs = jnp.abs(self.alphas) 337 | 338 | def proj_dproj(self, x): 339 | batch = jnp.reshape(x, (self.num_cones, 3)) 340 | # negate points being projected onto dual 341 | batch = batch * self.signs[:, None] 342 | 343 | proj_primal, jacs = eqx.filter_vmap(_proj_dproj, in_axes=(0, 0), out_axes=(0, 0))(batch, self.alphas_abs) 344 | 345 | # via Moreau: Pi_K^*(v) = v + Pi_K(-v) 346 | proj_dual = batch + proj_primal 347 | 348 | proj = jnp.where(self.is_dual[:, None], proj_dual, proj_primal) 349 | 350 | return jnp.ravel(proj), _PowerConeJacobianOperator(jacs, self.is_dual) -------------------------------------------------------------------------------- /experiments/heterogeneous_experiment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Experiment solving problems on the GPU and computing VJPs on the CPU. 3 | """ 4 | from juliacall import Main as jl 5 | # jl.seval('import Pkg; Pkg.develop(url="https://github.com/oxfordcontrol/Clarabel.jl.git")') 6 | jl.seval('using Clarabel, LinearAlgebra, SparseArrays') 7 | # jl.seval('Pkg.add("CUDA")') 8 | jl.seval('using CUDA, CUDA.CUSPARSE') 9 | 10 | type CuVector = jl.CUDA.Cuvector 11 | type CuSparseMatrixCSR = jl.CUDA.CUSPARSE.CuSparseMatrixCSR 12 | 13 | import time 14 | import os 15 | from dataclasses import dataclass, field 16 | import numpy as np 17 | import jax 18 | # TODO(quill): set JAX flags 19 | jax.config.update("jax_enable_x64", True) 20 | jax.config.update("jax_platform_name", "cpu") 21 | import jax.numpy as jnp 22 | import jax.numpy.linalg as la 23 | from jaxtyping import Float, Array 24 | import cupy as cp 25 | from cupy.sparse import csr_matrix 26 | import equinox as eqx 27 | import scipy.sparse as sparse 28 | import patdb 29 | from jax.experimental.sparse import BCOO 30 | 31 | from diffqcp.qcp import HostQCP, QCPStructureCPU 32 | import experiments.cvx_problem_generator as prob_generator 33 | from tests.helpers import QCPProbData, scoo_to_bcoo 34 | import matplotlib.pyplot as plt 35 | 36 | def JuliaCuVector2CuPyArray(jl_arr) -> cp.ndarray: 37 | """Taken from https://github.com/cvxgrp/CuClarabel/blob/main/src/python/jl2py.py. 
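    Wraps the Julia-owned device buffer in a CuPy array without copying
    (`UnownedMemory` with `owner=None`), so the returned array is only valid
    for as long as the Julia vector it aliases stays alive on the GPU.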
38 | """ 39 | # Get the device pointer from Julia 40 | pDevice = jl.Int(jl.pointer(jl_arr)) 41 | 42 | # Get array length and element type 43 | span = jl.size(jl_arr) 44 | dtype = jl.eltype(jl_arr) 45 | 46 | # Map Julia type to CuPy dtype 47 | if dtype == jl.Float64: 48 | dtype = cp.float64 49 | else: 50 | dtype = cp.float32 51 | 52 | # Compute memory size in bytes (assuming 1D vector) 53 | size_bytes = int(span[0] * cp.dtype(dtype).itemsize) 54 | 55 | # Create CuPy memory view from the Julia pointer 56 | mem = cp.cuda.UnownedMemory(pDevice, size_bytes, owner=None) 57 | memptr = cp.cuda.MemoryPointer(mem, 0) 58 | 59 | # Wrap into CuPy ndarray 60 | arr = cp.ndarray(shape=span, dtype=dtype, memptr=memptr) 61 | return arr 62 | 63 | 64 | def cpu_csr_to_cupy_csr(mat: sparse.csr_matrix) -> csr_matrix: 65 | # Ensure all arrays are 1-D 66 | data = cp.array(mat.data) 67 | indices = cp.array(mat.indices) 68 | indptr = cp.array(mat.indptr) 69 | return csr_matrix((data, indices, indptr), shape=mat.shape) 70 | 71 | 72 | def cupy_csr_to_julia_csr(mat: csr_matrix) -> CuSparseMatrixCSR: 73 | shape = mat.shape 74 | m, n = shape[0], shape[1] 75 | nnz = mat.nnz 76 | if nnz == 0: 77 | mat_jl = jl.CuSparseMatrixCSR(jl.spzeros(m, n)) 78 | else: 79 | data_ptr = int(mat.data.data.ptr) 80 | indices_ptr = int(mat.indices.data.ptr) 81 | indptr_ptr = int(mat.indptr.data.ptr) 82 | mat_jl = jl.Clarabel.cupy_to_cucsrmat(jl.Float64, data_ptr, indices_ptr, indptr_ptr, m, n, nnz) 83 | return mat_jl 84 | 85 | 86 | @dataclass 87 | class SolverData: 88 | """ 89 | NOTE(quill): Never use the indices of `Pcp` or `Acp` after creating Julia ptrs. 90 | """ 91 | Pcp: csr_matrix 92 | Acp: csr_matrix 93 | qcp: cp.ndarray # q cupy, not QCP 94 | bcp: cp.ndarray 95 | 96 | Pjl: CuSparseMatrixCSR = field(init=False) 97 | Ajl: CuSparseMatrixCSR = field(init=False) 98 | qjl: CuVector = field(init=False) 99 | bjl: CuVector = field(init=False) 100 | 101 | def __post_init__(self): 102 | self.Pjl = cupy_csr_to_julia_csr(self.Pcp) 103 | self.Ajl = cupy_csr_to_julia_csr(self.Acp) 104 | self.qjl = jl.Clarabel.cupy_to_cuvector(jl.Float64, int(self.qcp.data.ptr), self.qcp.size) 105 | self.bjl = jl.Clarabel.cupy_to_cuvector(jl.Float64, int(self.bcp.data.ptr), self.bcp.size) 106 | 107 | 108 | def compute_loss(target_x, target_y, target_s, x, y, s): 109 | return (0.5 * la.norm(x - target_x)**2 + 0.5 * la.norm(y - target_y)**2 110 | + 0.5 * la.norm(s - target_s)**2) 111 | 112 | 113 | @eqx.filter_jit 114 | @eqx.debug.assert_max_traces(max_traces=1) 115 | def make_step( 116 | qcp, 117 | target_x, 118 | target_y, 119 | target_s, 120 | step_size 121 | ) -> tuple[Float[Array, ""], Float[BCOO, "n n"], Float[BCOO, "m n"], 122 | Float[Array, " n"], Float[Array, " m"]]: 123 | loss = compute_loss(target_x, target_y, target_s, qcp.x, qcp.y, qcp.s) 124 | dP, dA, dq, db = qcp.vjp(qcp.x - target_x, 125 | qcp.y - target_y, 126 | qcp.s - target_s) 127 | dP.data *= -step_size 128 | dA.data *= -step_size 129 | dq *= -step_size 130 | db *= -step_size 131 | 132 | return (loss, dP, dA, dq, db) 133 | 134 | def grad_desc( 135 | Pk: Float[BCOO, "n n"], 136 | Ak: Float[BCOO, "m n"], 137 | qk: Float[Array, " n"], 138 | bk: Float[Array, " m"], 139 | target_x: Float[Array, " n"], 140 | target_y: Float[Array, " m"], 141 | target_s: Float[Array, " m"], 142 | qcp_problem_structure: QCPStructureCPU, 143 | solver_data: SolverData, 144 | cuclarabel_solver, 145 | num_iter: int = 100, 146 | step_size = 1e-5 147 | ): 148 | curr_iter = 0 149 | losses = [] 150 | 151 | while curr_iter < 
num_iter: 152 | 153 | jl.Clarabel.solve_b(cuclarabel_solver) 154 | 155 | xkcp = JuliaCuVector2CuPyArray(jl.solver.solution.x) 156 | xk = cp.asnumpy(xkcp) 157 | ykcp = JuliaCuVector2CuPyArray(jl.solver.solution.z) 158 | yk = cp.asnumpy(ykcp) 159 | skcp = JuliaCuVector2CuPyArray(jl.solver.solution.s) 160 | sk = cp.asnumpy(skcp) 161 | 162 | qcp = HostQCP(Pk, Ak, qk, bk, xk, yk, sk, qcp_problem_structure) 163 | loss, *dtheta_steps = make_step(qcp, target_x, target_y, target_s, step_size) 164 | losses.append(loss) 165 | 166 | dP_step, dA_step, dq_step, db_step = dtheta_steps 167 | Pk.data += dP_step.data 168 | Ak.data += dA_step.data 169 | qk += dq_step 170 | bk += db_step 171 | # update solver data 172 | solver_data.Pcp.data += cp.asarray(dP_step.data) 173 | solver_data.Acp.data += cp.asarray(dA_step.data) 174 | solver_data.qcp += cp.asarray(dq_step) 175 | solver_data.bcp += cp.asarray(db_step) 176 | 177 | # Also update solver 178 | jl.Clarabel.update_P_b(cuclarabel_solver, solver_data.Pjl) 179 | jl.Clarabel.update_A_b(cuclarabel_solver, solver_data.Ajl) 180 | jl.Clarabel.update_q_b(cuclarabel_solver, solver_data.qjl) 181 | jl.Clarabel.update_b_b(cuclarabel_solver, solver_data.bjl) 182 | 183 | curr_iter += 1 184 | 185 | return losses 186 | 187 | if __name__ == "__main__": 188 | 189 | np.random.seed(28) 190 | 191 | # m = 20 192 | # n = 10 193 | m = 2_000 194 | n = 1_000 195 | # problem = prob_generator.generate_group_lasso(n=n, m=m) #TODO(quill): CuClarabel failing for this problem 196 | start_time = time.perf_counter() 197 | target_problem = prob_generator.generate_least_squares_eq(m=m, n=n) 198 | # target_problem = prob_generator.generate_LS_problem(m=m, n=n) 199 | prob_data_cpu = QCPProbData(target_problem) 200 | end_time = time.perf_counter() 201 | print("Time to generate the target problem," 202 | + " canonicalize it, and solve it on the CPU:" 203 | + f" {end_time - start_time} seconds") 204 | 205 | # === Obtain target vectors + warm up GPU and JIT compile CuClarabel and diffqcp === 206 | 207 | solver_data = SolverData(cpu_csr_to_cupy_csr(prob_data_cpu.Pcsr), 208 | cpu_csr_to_cupy_csr(prob_data_cpu.Acsr), 209 | cp.array(prob_data_cpu.q), 210 | cp.array(prob_data_cpu.b)) 211 | 212 | # Create Julia cone variables 213 | jl.zero_cone = prob_data_cpu.scs_cones["z"] 214 | jl.nonneg_cone = prob_data_cpu.scs_cones["l"] 215 | jl.soc = prob_data_cpu.scs_cones["q"] 216 | # Now use Julia variables in Julia code 217 | jl.seval(""" 218 | cones = Dict( 219 | "f" => zero_cone, 220 | "l" => nonneg_cone, 221 | "q" => soc 222 | ) 223 | settings = Clarabel.Settings(direct_solve_method = :cudss) 224 | settings.verbose = false 225 | """) 226 | # create CuClarabel solver 227 | jl.solver = jl.Clarabel.Solver(solver_data.Pjl, solver_data.qjl, 228 | solver_data.Ajl, solver_data.bjl, 229 | jl.cones, jl.settings) 230 | start_solve = time.perf_counter() 231 | jl.Clarabel.solve_b(jl.solver) # solve new problem w/o creating memory 232 | cp.cuda.Device().synchronize() 233 | end_solve = time.perf_counter() 234 | print(f"CuClarabel compile + solve took: {end_solve - start_solve} seconds") 235 | 236 | xcp = JuliaCuVector2CuPyArray(jl.solver.solution.x) 237 | ycp = JuliaCuVector2CuPyArray(jl.solver.solution.z) 238 | scp = JuliaCuVector2CuPyArray(jl.solver.solution.s) 239 | 240 | x_target = jnp.array(cp.asnumpy(xcp)) 241 | y_target = jnp.array(cp.asnumpy(ycp)) 242 | s_target = jnp.array(cp.asnumpy(scp)) 243 | 244 | # --- Time compiled speedup --- 245 | start_solve = time.perf_counter() 246 | jl.Clarabel.solve_b(jl.solver) # 
solve new problem w/o creating memory 247 | cp.cuda.Device().synchronize() 248 | end_solve = time.perf_counter() 249 | print(f"Compiled CuClarabel solve took: {end_solve - start_solve} seconds") 250 | # --- --- 251 | 252 | P = scoo_to_bcoo(prob_data_cpu.Pupper_coo) 253 | A = scoo_to_bcoo(prob_data_cpu.Acoo) 254 | q = prob_data_cpu.q 255 | b = prob_data_cpu.b 256 | scs_cones = prob_data_cpu.scs_cones 257 | problem_structure = QCPStructureCPU(P, A, scs_cones) 258 | qcp_initial = HostQCP(P, A, q, b, 259 | x_target, y_target, s_target, 260 | problem_structure) 261 | fake_target_x = 1e-3 * jnp.arange(jnp.size(q), dtype=q.dtype) 262 | fake_target_y = 1e-3 * jnp.arange(jnp.size(b), dtype=b.dtype) 263 | fake_target_s = 1e-3 * jnp.arange(jnp.size(b), dtype=b.dtype) 264 | 265 | start_time = time.perf_counter() 266 | result = make_step(qcp_initial, fake_target_x, fake_target_y, 267 | fake_target_s, step_size=1e-5) 268 | result[0].block_until_ready() 269 | end_time = time.perf_counter() 270 | # NOTE(quill): well, technically VJP + loss + step computations 271 | print("diffqcp VJP compile + compute took: ", end_time - start_time) 272 | 273 | # --- test compiled solve --- 274 | 275 | start_time = time.perf_counter() 276 | result = make_step(qcp_initial, fake_target_x, fake_target_y, 277 | fake_target_s, step_size=1e-5) 278 | result[0].block_until_ready() 279 | end_time = time.perf_counter() 280 | print("Compiled diffqcp VJP compute took: ", end_time - start_time) 281 | print("The result is on the: ", result[0].device) 282 | 283 | # --- --- 284 | 285 | # === Finished initialization === 286 | 287 | # === Now get problem we'll actually use for LL === 288 | 289 | start_time = time.perf_counter() 290 | initial_problem = prob_generator.generate_least_squares_eq(m=m, n=n) 291 | # initial_problem = prob_generator.generate_LS_problem(m=m, n=n) 292 | prob_data_cpu = QCPProbData(initial_problem) 293 | end_time = time.perf_counter() 294 | print("Time to generate the initial (starting point) problem," 295 | + f" canonicalize it, and solve it on the cpu is: {end_time - start_time} seconds") 296 | print(f"Canonicalized n is: {prob_data_cpu.n}") 297 | print(f"Canonicalized m is: {prob_data_cpu.m}") 298 | 299 | # Put new data on GPU and create CuPy <-> Julia linking 300 | 301 | solver_data = SolverData(cpu_csr_to_cupy_csr(prob_data_cpu.Pcsr), 302 | cpu_csr_to_cupy_csr(prob_data_cpu.Acsr), 303 | cp.array(prob_data_cpu.q), 304 | cp.array(prob_data_cpu.b)) 305 | 306 | # Because problem is DPP-compliant, now just update existing solver object 307 | jl.Clarabel.update_P_b(jl.solver, solver_data.Pjl) 308 | jl.Clarabel.update_A_b(jl.solver, solver_data.Ajl) 309 | jl.Clarabel.update_q_b(jl.solver, solver_data.qjl) 310 | jl.Clarabel.update_b_b(jl.solver, solver_data.bjl) 311 | 312 | num_iter = 100 313 | 314 | start_time = time.perf_counter() 315 | losses = grad_desc(Pk=scoo_to_bcoo(prob_data_cpu.Pupper_coo), 316 | Ak=scoo_to_bcoo(prob_data_cpu.Acoo), 317 | qk = jnp.array(prob_data_cpu.q), 318 | bk = jnp.array(prob_data_cpu.b), 319 | target_x=x_target, 320 | target_y=y_target, 321 | target_s=s_target, 322 | qcp_problem_structure=problem_structure, 323 | solver_data=solver_data, 324 | cuclarabel_solver=jl.solver, 325 | num_iter=num_iter) 326 | cp.cuda.Device().synchronize() 327 | losses[0].block_until_ready() 328 | end_time = time.perf_counter() 329 | print(f"The learning loop time was {end_time - start_time} seconds") 330 | print(f"Avg. 
iteration (solve + VJP) time: {(end_time - start_time) / num_iter}") 331 | losses = jnp.stack(losses) 332 | losses = np.asarray(losses) 333 | 334 | plt.figure(figsize=(8, 6)) 335 | plt.plot(range(num_iter), losses, label="Objective Trajectory") 336 | plt.xlabel("num. iterations") 337 | plt.ylabel("Objective function") 338 | plt.legend() 339 | plt.title(label="diffqcp") 340 | results_dir = os.path.join(os.path.dirname(__file__), "results") 341 | if prob_data_cpu.n > 999: 342 | output_path = os.path.join(results_dir, "hetero_probability_large.svg") 343 | else: 344 | output_path = os.path.join(results_dir, "hetero_probability_small.svg") 345 | plt.savefig(output_path, format="svg") 346 | plt.close() -------------------------------------------------------------------------------- /experiments/gpu_experiment.py: -------------------------------------------------------------------------------- 1 | """ 2 | I'd like it to be known that this experiment was NOT fun to create. 3 | 4 | Some sharp bits: 5 | - Julia 1-based indexing (https://www.reddit.com/r/Julia/comments/o90ejj/some_may_hate_it_some_may_love_it/), 6 | which makes it hard to point everyone at same data 7 | - for jax and cupy arrays to be equivalent (if not forcing cupy to be float32), you better 8 | bet setting the "jax_enable_x64" flag. 9 | 10 | """ 11 | 12 | from juliacall import Main as jl 13 | # jl.seval('import Pkg; Pkg.develop(url="https://github.com/oxfordcontrol/Clarabel.jl.git")') 14 | jl.seval('using Clarabel, LinearAlgebra, SparseArrays') 15 | # jl.seval('Pkg.add("CUDA")') 16 | jl.seval('using CUDA, CUDA.CUSPARSE') 17 | 18 | type CuVector = jl.CUDA.Cuvector 19 | type CuSparseMatrixCSR = jl.CUDA.CUSPARSE.CuSparseMatrixCSR 20 | 21 | import time 22 | import os 23 | from dataclasses import dataclass, field 24 | import numpy as np 25 | import jax 26 | # TODO(quill): set JAX flags 27 | jax.config.update("jax_enable_x64", True) 28 | import jax.numpy as jnp 29 | import jax.numpy.linalg as la 30 | from jaxtyping import Float, Array 31 | import cupy as cp 32 | from cupy import from_dlpack as cp_from_dlpack 33 | from cupy.sparse import csr_matrix 34 | import equinox as eqx 35 | import scipy.sparse as sparse 36 | import patdb 37 | from jax.experimental.sparse import BCSR 38 | 39 | from diffqcp.qcp import DeviceQCP, QCPStructureGPU 40 | import experiments.cvx_problem_generator as prob_generator 41 | from tests.helpers import QCPProbData, scsr_to_bcsr 42 | import matplotlib.pyplot as plt 43 | 44 | # what auxillary objects can I create to store the CuPy <-> Julia objects? 45 | 46 | # will need helpers to do CuPy CSR <-> JAX BCSR 47 | # put in this function how to handle a 0 matrix 48 | 49 | def JuliaCuVector2CuPyArray(jl_arr): 50 | """Taken from https://github.com/cvxgrp/CuClarabel/blob/main/src/python/jl2py.py. 
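    Builds a zero-copy CuPy view over the Julia-owned device pointer. Note the
    dtype mapping below: anything other than Float64 is treated as Float32, and
    the view remains tied to the lifetime of the underlying Julia array.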
51 | """ 52 | # Get the device pointer from Julia 53 | pDevice = jl.Int(jl.pointer(jl_arr)) 54 | 55 | # Get array length and element type 56 | span = jl.size(jl_arr) 57 | dtype = jl.eltype(jl_arr) 58 | 59 | # Map Julia type to CuPy dtype 60 | if dtype == jl.Float64: 61 | dtype = cp.float64 62 | else: 63 | dtype = cp.float32 64 | 65 | # Compute memory size in bytes (assuming 1D vector) 66 | size_bytes = int(span[0] * cp.dtype(dtype).itemsize) 67 | 68 | # Create CuPy memory view from the Julia pointer 69 | mem = cp.cuda.UnownedMemory(pDevice, size_bytes, owner=None) 70 | memptr = cp.cuda.MemoryPointer(mem, 0) 71 | 72 | # Wrap into CuPy ndarray 73 | arr = cp.ndarray(shape=span, dtype=dtype, memptr=memptr) 74 | return arr 75 | 76 | 77 | def cpu_csr_to_cupy_csr(mat: sparse.csr_matrix) -> csr_matrix: 78 | # Ensure all arrays are 1-D 79 | data = cp.array(mat.data) 80 | indices = cp.array(mat.indices) 81 | indptr = cp.array(mat.indptr) 82 | return csr_matrix((data, indices, indptr), shape=mat.shape) 83 | 84 | 85 | def cupy_csr_to_julia_csr(mat: csr_matrix) -> CuSparseMatrixCSR: 86 | shape = mat.shape 87 | m, n = shape[0], shape[1] 88 | nnz = mat.nnz 89 | if nnz == 0: 90 | mat_jl = jl.CuSparseMatrixCSR(jl.spzeros(m, n)) 91 | else: 92 | data_ptr = int(mat.data.data.ptr) 93 | indices_ptr = int(mat.indices.data.ptr) 94 | indptr_ptr = int(mat.indptr.data.ptr) 95 | mat_jl = jl.Clarabel.cupy_to_cucsrmat(jl.Float64, data_ptr, indices_ptr, indptr_ptr, m, n, nnz) 96 | return mat_jl 97 | 98 | 99 | def cupy_csr_to_jax_bcsr(mat: csr_matrix) -> BCSR: 100 | shape = mat.shape 101 | m, n = shape[0], shape[1] 102 | nnz = mat.nnz 103 | if nnz == 0: 104 | mat_jax = BCSR.fromdense(jnp.zeros(shape=shape, dtype=mat.dtype)) 105 | else: 106 | data = jax.dlpack.from_dlpack(mat.data) 107 | indices = jax.dlpack.from_dlpack(mat.indices) 108 | indptr = jax.dlpack.from_dlpack(mat.indptr) 109 | mat_jax = BCSR((data, indices, indptr), shape=mat.shape) 110 | return mat_jax 111 | 112 | @dataclass 113 | class SolverData: 114 | """ 115 | NOTE(quill): Never use the indices of `Pcp` or `Acp` after creating Julia ptrs. 
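    Presumably the Julia CSR wrappers created in `__post_init__` alias the same
    device buffers and shift the index arrays in place for 1-based indexing
    (cf. the "sharp bits" note at the top of this file and the
    "indices ... have been corrupted" comment further down), so after
    construction only the `data` views of `Pcp`/`Acp` remain safe to read or
    update from Python.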
116 | """ 117 | Pcp: csr_matrix 118 | Acp: csr_matrix 119 | qcp: cp.ndarray # q cupy, not QCP 120 | bcp: cp.ndarray 121 | 122 | Pjl: CuSparseMatrixCSR = field(init=False) 123 | Ajl: CuSparseMatrixCSR = field(init=False) 124 | qjl: CuVector = field(init=False) 125 | bjl: CuVector = field(init=False) 126 | 127 | def __post_init__(self): 128 | self.Pjl = cupy_csr_to_julia_csr(self.Pcp) 129 | self.Ajl = cupy_csr_to_julia_csr(self.Acp) 130 | self.qjl = jl.Clarabel.cupy_to_cuvector(jl.Float64, int(self.qcp.data.ptr), self.qcp.size) 131 | self.bjl = jl.Clarabel.cupy_to_cuvector(jl.Float64, int(self.bcp.data.ptr), self.bcp.size) 132 | 133 | 134 | def compute_loss(target_x, target_y, target_s, x, y, s): 135 | return (0.5 * la.norm(x - target_x)**2 + 0.5 * la.norm(y - target_y)**2 136 | + 0.5 * la.norm(s - target_s)**2) 137 | 138 | 139 | @eqx.filter_jit 140 | @eqx.debug.assert_max_traces(max_traces=1) 141 | def make_step( 142 | qcp, 143 | target_x, 144 | target_y, 145 | target_s, 146 | step_size 147 | ) -> tuple[Float[Array, ""], Float[BCSR, "n n"], Float[BCSR, "m n"], 148 | Float[Array, " n"], Float[Array, " m"]]: 149 | loss = compute_loss(target_x, target_y, target_s, qcp.x, qcp.y, qcp.s) 150 | dP, dA, dq, db = qcp.vjp(qcp.x - target_x, 151 | qcp.y - target_y, 152 | qcp.s - target_s) 153 | dP.data *= -step_size 154 | dA.data *= -step_size 155 | dq *= -step_size 156 | db *= -step_size 157 | 158 | return (loss, dP, dA, dq, db) 159 | 160 | 161 | def grad_desc( 162 | Pk: Float[BCSR, "n n"], 163 | Ak: Float[BCSR, "m n"], 164 | solver_data: SolverData, 165 | target_x: Float[Array, " n"], 166 | target_y: Float[Array, " m"], 167 | target_s: Float[Array, " m"], 168 | cuclarabel_solver, 169 | qcp_problem_structure: QCPStructureGPU, 170 | num_iter: int = 100, 171 | step_size: float = 1e-5, 172 | ) -> list[Float[Array, ""]]: 173 | 174 | curr_iter = 0 175 | losses = [] 176 | 177 | while curr_iter < num_iter: 178 | 179 | jl.Clarabel.solve_b(cuclarabel_solver) 180 | 181 | Pk.data = jax.dlpack.from_dlpack(solver_data.Pcp.data) 182 | Ak.data = jax.dlpack.from_dlpack(solver_data.Acp.data) 183 | qk = jax.dlpack.from_dlpack(solver_data.qcp) 184 | bk = jax.dlpack.from_dlpack(solver_data.bcp) 185 | 186 | xk = jax.dlpack.from_dlpack(JuliaCuVector2CuPyArray(jl.solver.solution.x)) 187 | yk = jax.dlpack.from_dlpack(JuliaCuVector2CuPyArray(jl.solver.solution.z)) 188 | sk = jax.dlpack.from_dlpack(JuliaCuVector2CuPyArray(jl.solver.solution.s)) 189 | 190 | qcp = DeviceQCP(Pk, Ak, qk, bk, xk, yk, sk, qcp_problem_structure) 191 | loss, *dtheta_steps = make_step(qcp, target_x, target_y, target_s, step_size) 192 | losses.append(loss) 193 | 194 | dP_step, dA_step, dq_step, db_step = dtheta_steps 195 | solver_data.Pcp.data += cp_from_dlpack(dP_step.data) 196 | solver_data.Acp.data += cp_from_dlpack(dA_step.data) 197 | solver_data.qcp += cp_from_dlpack(dq_step) 198 | solver_data.bcp += cp_from_dlpack(db_step) 199 | 200 | # Also update solver 201 | jl.Clarabel.update_P_b(cuclarabel_solver, solver_data.Pjl) 202 | jl.Clarabel.update_A_b(cuclarabel_solver, solver_data.Ajl) 203 | jl.Clarabel.update_q_b(cuclarabel_solver, solver_data.qjl) 204 | jl.Clarabel.update_b_b(cuclarabel_solver, solver_data.bjl) 205 | 206 | curr_iter += 1 207 | 208 | return losses 209 | 210 | 211 | if __name__ == "__main__": 212 | 213 | np.random.seed(28) 214 | 215 | m = 20 216 | n = 10 217 | # m = 2_000 218 | # n = 1_000 219 | start_time = time.perf_counter() 220 | target_problem = prob_generator.generate_least_squares_eq(m=m, n=n) 221 | prob_data_cpu = 
QCPProbData(target_problem) 222 | end_time = time.perf_counter() 223 | print("Time to generate the target problem," 224 | + " canonicalize it, and solve it on the CPU:" 225 | + f" {end_time - start_time} seconds") 226 | 227 | # === Obtain target vectors + warm up GPU and JIT compile CuClarabel and diffqcp === 228 | 229 | solver_data = SolverData(cpu_csr_to_cupy_csr(prob_data_cpu.Pcsr), 230 | cpu_csr_to_cupy_csr(prob_data_cpu.Acsr), 231 | cp.array(prob_data_cpu.q), 232 | cp.array(prob_data_cpu.b)) 233 | 234 | # Create Julia cone variables 235 | jl.zero_cone = prob_data_cpu.scs_cones["z"] 236 | jl.nonneg_cone = prob_data_cpu.scs_cones["l"] 237 | jl.soc = prob_data_cpu.scs_cones["q"] 238 | # Now use Julia variables in Julia code 239 | jl.seval(""" 240 | cones = Dict( 241 | "f" => zero_cone, 242 | "l" => nonneg_cone, 243 | "q" => soc 244 | ) 245 | settings = Clarabel.Settings(direct_solve_method = :cudss) 246 | settings.verbose = false 247 | """) 248 | # create CuClarabel solver 249 | jl.solver = jl.Clarabel.Solver(solver_data.Pjl, solver_data.qjl, 250 | solver_data.Ajl, solver_data.bjl, 251 | jl.cones, jl.settings) 252 | start_solve = time.perf_counter() 253 | jl.Clarabel.solve_b(jl.solver) # solve new problem w/o creating memory 254 | cp.cuda.Device().synchronize() 255 | end_solve = time.perf_counter() 256 | print(f"CuClarabel compile + solve took: {end_solve - start_solve} seconds") 257 | 258 | xcp = JuliaCuVector2CuPyArray(jl.solver.solution.x) 259 | ycp = JuliaCuVector2CuPyArray(jl.solver.solution.z) 260 | scp = JuliaCuVector2CuPyArray(jl.solver.solution.s) 261 | 262 | x_target = jax.dlpack.from_dlpack(cp.array(xcp, copy=True)) 263 | y_target = jax.dlpack.from_dlpack(cp.array(ycp, copy=True)) 264 | s_target = jax.dlpack.from_dlpack(cp.array(scp, copy=True)) 265 | 266 | # --- Time compiled speedup --- 267 | start_solve = time.perf_counter() 268 | jl.Clarabel.solve_b(jl.solver) # solve new problem w/o creating memory 269 | cp.cuda.Device().synchronize() 270 | end_solve = time.perf_counter() 271 | print(f"Compiled CuClarabel solve took: {end_solve - start_solve} seconds") 272 | # --- --- 273 | 274 | # NOTE(quill): go from host data since the indices of `Pcp` and `Acp` have been corrupted 275 | P = scsr_to_bcsr(prob_data_cpu.Pcsr) 276 | A = scsr_to_bcsr(prob_data_cpu.Acsr) 277 | q = jax.dlpack.from_dlpack(solver_data.qcp) 278 | b = jax.dlpack.from_dlpack(solver_data.bcp) 279 | 280 | # --- JIT compile `make_step` (so compile what's needed for `diffqcp`) --- 281 | 282 | problem_structure = QCPStructureGPU(P=P, A=A, cone_dims=prob_data_cpu.scs_cones) 283 | 284 | qcp_initial = DeviceQCP(P=P, A=A, q=q, b=b, 285 | x=x_target, y=y_target, s=s_target, 286 | problem_structure=problem_structure) 287 | 288 | fake_target_x = 1e-3 * jnp.arange(jnp.size(q), dtype=q.dtype) 289 | fake_target_y = 1e-3 * jnp.arange(jnp.size(b), dtype=b.dtype) 290 | fake_target_s = 1e-3 * jnp.arange(jnp.size(b), dtype=b.dtype) 291 | 292 | start_time = time.perf_counter() 293 | result = make_step(qcp_initial, fake_target_x, fake_target_y, 294 | fake_target_s, step_size=1e-5) 295 | result[0].block_until_ready() 296 | end_time = time.perf_counter() 297 | # NOTE(quill): well, technically VJP + loss + step computations 298 | print("diffqcp VJP compile + compute took: ", end_time - start_time) 299 | 300 | # --- test compiled solve --- 301 | 302 | start_time = time.perf_counter() 303 | result = make_step(qcp_initial, fake_target_x, fake_target_y, 304 | fake_target_s, step_size=1e-5) 305 | result[0].block_until_ready() 306 | end_time = 
time.perf_counter() 307 | print("Compiled diffqcp VJP compute took: ", end_time - start_time) 308 | 309 | # --- --- 310 | 311 | # === Finished initialization === 312 | 313 | # === Now get problem we'll actually use for LL === 314 | 315 | start_time = time.perf_counter() 316 | initial_problem = prob_generator.generate_least_squares_eq(m=m, n=n) 317 | # initial_problem = prob_generator.generate_LS_problem(m=m, n=n) 318 | prob_data_cpu = QCPProbData(initial_problem) 319 | end_time = time.perf_counter() 320 | print("Time to generate the initial (starting point) problem," 321 | + f" canonicalize it, and solve it on the cpu is: {end_time - start_time} seconds") 322 | print(f"Canonicalized n is: {prob_data_cpu.n}") 323 | print(f"Canonicalized m is: {prob_data_cpu.m}") 324 | 325 | # Put new data on GPU and create CuPy <-> Julia linking 326 | 327 | solver_data = SolverData(cpu_csr_to_cupy_csr(prob_data_cpu.Pcsr), 328 | cpu_csr_to_cupy_csr(prob_data_cpu.Acsr), 329 | cp.array(prob_data_cpu.q), 330 | cp.array(prob_data_cpu.b)) 331 | 332 | # Because problem is DPP-compliant, now just update existing solver object 333 | jl.Clarabel.update_P_b(jl.solver, solver_data.Pjl) 334 | jl.Clarabel.update_A_b(jl.solver, solver_data.Ajl) 335 | jl.Clarabel.update_q_b(jl.solver, solver_data.qjl) 336 | jl.Clarabel.update_b_b(jl.solver, solver_data.bjl) 337 | 338 | # Now let's create data for JAX to use (that are 0-based, :eyeroll) 339 | Pk = BCSR((jax.dlpack.from_dlpack(solver_data.Pcp.data), 340 | prob_data_cpu.Pcsr.indices, 341 | prob_data_cpu.Pcsr.indptr), shape=solver_data.Pcp.shape) 342 | Ak = BCSR((jax.dlpack.from_dlpack(solver_data.Acp.data), 343 | prob_data_cpu.Acsr.indices, 344 | prob_data_cpu.Acsr.indptr), shape=solver_data.Acp.shape) 345 | 346 | # num_iter = 100 347 | num_iter = 25 348 | cp.cuda.Device().synchronize() 349 | start_time = time.perf_counter() 350 | losses = grad_desc(Pk=Pk, Ak=Ak, solver_data=solver_data, 351 | target_x=x_target, target_y=y_target, target_s=s_target, 352 | cuclarabel_solver=jl.solver, qcp_problem_structure=problem_structure, 353 | num_iter=num_iter) 354 | losses[0].block_until_ready() 355 | end_time = time.perf_counter() 356 | print(f"The learning loop time was {end_time - start_time} seconds") 357 | print(f"Avg. iteration (solve + VJP) time: {(end_time - start_time) / num_iter}") 358 | losses = jnp.stack(losses) 359 | losses = np.asarray(losses) 360 | 361 | plt.figure(figsize=(8, 6)) 362 | plt.plot(range(num_iter), losses, label="Objective Trajectory") 363 | plt.xlabel("num. 
iterations") 364 | plt.ylabel("Objective function") 365 | plt.legend() 366 | plt.title(label="diffqcp") 367 | results_dir = os.path.join(os.path.dirname(__file__), "results") 368 | if prob_data_cpu.n > 999: 369 | output_path = os.path.join(results_dir, "diffqcp_probability_large.svg") 370 | else: 371 | output_path = os.path.join(results_dir, "diffqcp_probability_small.svg") 372 | plt.savefig(output_path, format="svg") 373 | plt.close() 374 | 375 | -------------------------------------------------------------------------------- /tests/test_cone_projectors.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import numpy as np 4 | import cvxpy as cvx 5 | from jax import vmap, jit 6 | import equinox as eqx 7 | import jax.numpy as jnp 8 | import jax.random as jr 9 | 10 | import diffqcp.cones.canonical as cone_lib 11 | from diffqcp.cones.exp import in_exp, in_exp_dual, ExponentialConeProjector 12 | from .helpers import tree_allclose 13 | 14 | def _test_dproj_finite_diffs( 15 | projection_func: Callable, key_func, dim: int, num_batches: int = 0 16 | ): 17 | if num_batches > 0: 18 | x = jr.normal(key_func(), (num_batches, dim)) 19 | dx = jr.normal(key_func(), (num_batches, dim)) 20 | # NOTE(quill): `jit`ing the following slows the check down 21 | # since this is called in a loop, so we end up `jit`ing multiple times. 22 | # Just doing it here to ensure it works. 23 | _projector = jit(vmap(projection_func)) 24 | else: 25 | x = jr.normal(key_func(), dim) 26 | dx = jr.normal(key_func(), dim) 27 | _projector = jit(projection_func) 28 | 29 | dx = 1e-5 * dx 30 | 31 | proj_x, dproj_x = _projector(x) 32 | proj_x_plus_dx, _ = _projector(x + dx) 33 | 34 | dproj_x_fd = proj_x_plus_dx - proj_x 35 | dproj_x_dx = dproj_x.mv(dx) 36 | assert dproj_x_dx is not None 37 | assert dproj_x_fd is not None 38 | assert tree_allclose(dproj_x_dx, dproj_x_fd) 39 | 40 | 41 | def test_zero_projector(getkey): 42 | n = 100 43 | num_batches = 10 44 | 45 | for dual in [True, False]: 46 | 47 | _zero_projector = cone_lib.ZeroConeProjector(onto_dual=dual) 48 | zero_projector = jit(_zero_projector) 49 | batched_zero_projector = jit(vmap(_zero_projector)) 50 | 51 | for _ in range(15): 52 | 53 | x = jr.normal(getkey(), n) 54 | 55 | proj_x, _ = zero_projector(x) 56 | truth = jnp.zeros_like(x) if not dual else x 57 | assert tree_allclose(truth, proj_x) 58 | _test_dproj_finite_diffs(zero_projector, getkey, dim=n, num_batches=0) 59 | 60 | # --- batched --- 61 | x = jr.normal(getkey(), (num_batches, n)) 62 | proj_x, _ = batched_zero_projector(x) 63 | truth = jnp.zeros_like(x) if not dual else x 64 | assert tree_allclose(truth, proj_x) 65 | _test_dproj_finite_diffs(_zero_projector, getkey, dim=n, num_batches=num_batches) 66 | 67 | 68 | def test_nonnegative_projector(getkey): 69 | n = 100 70 | num_batches = 10 71 | 72 | _nn_projector = cone_lib.NonnegativeConeProjector() 73 | nn_projector = jit(_nn_projector) 74 | batched_nn_projector = jit(vmap(_nn_projector)) 75 | 76 | for _ in range(15): 77 | 78 | x = jr.normal(getkey(), n) 79 | proj_x, _ = nn_projector(x) 80 | truth = jnp.maximum(x, 0) 81 | assert tree_allclose(truth, proj_x) 82 | _test_dproj_finite_diffs(nn_projector, getkey, dim=n, num_batches=0) 83 | 84 | x = jr.normal(getkey(), (num_batches, n)) 85 | proj_x, _ = batched_nn_projector(x) 86 | truth = jnp.maximum(x, 0) 87 | assert tree_allclose(truth, proj_x) 88 | _test_dproj_finite_diffs(_nn_projector, getkey, dim=n, num_batches=10) 89 | 90 | 91 | def 
_proj_soc_via_cvxpy(x: np.ndarray) -> np.ndarray: 92 | n = x.size 93 | z = cvx.Variable(n) 94 | objective = cvx.Minimize(cvx.sum_squares(z - x)) 95 | constraints = [cvx.norm(z[1:], 2) <= z[0]] 96 | prob = cvx.Problem(objective, constraints) 97 | prob.solve(solver=cvx.SCS, eps=1e-10) 98 | return z.value 99 | 100 | 101 | def test_soc_private_projector(getkey): 102 | n = 100 103 | num_batches = 10 104 | 105 | _soc_projector = cone_lib._SecondOrderConeProjector(dim=n) 106 | soc_projector = eqx.filter_jit(_soc_projector) 107 | batched_soc_projector = jit(vmap(_soc_projector)) 108 | 109 | for _ in range(15): 110 | x_jnp = jr.normal(getkey(), n) 111 | x_np = np.array(x_jnp) 112 | proj_x_solver = jnp.array(_proj_soc_via_cvxpy(x_np)) 113 | 114 | proj_x, _ = soc_projector(x_jnp) 115 | assert tree_allclose(proj_x, proj_x_solver) 116 | _test_dproj_finite_diffs(soc_projector, getkey, dim=n, num_batches=0) 117 | 118 | # --- batched --- 119 | x_jnp = jr.normal(getkey(), (num_batches, n)) 120 | x_np = np.array(x_jnp) 121 | proj_x, _ = batched_soc_projector(x_jnp) 122 | for i in range(num_batches): 123 | proj_x_solver = jnp.array(_proj_soc_via_cvxpy(x_np[i, :])) 124 | assert tree_allclose(proj_x[i, :], proj_x_solver) 125 | 126 | _test_dproj_finite_diffs(_soc_projector, getkey, dim=n, num_batches=num_batches) 127 | 128 | 129 | def _test_soc_projector(dims, num_batches, keyfunc): 130 | total_dim = sum(dims) 131 | 132 | _soc_projector = cone_lib.SecondOrderConeProjector(dims=dims) 133 | soc_projector = eqx.filter_jit(_soc_projector) 134 | batched_soc_projector = eqx.filter_jit(eqx.filter_vmap(_soc_projector)) 135 | 136 | for _ in range(15): 137 | 138 | x_jnp = jr.normal(keyfunc(), total_dim) 139 | x_np = np.array(x_jnp) 140 | start = 0 141 | solns = [] 142 | for dim in dims: 143 | end = start + dim 144 | solns.append(jnp.array(_proj_soc_via_cvxpy(x_np[start:end]))) 145 | start = end 146 | proj_x_solver = jnp.concatenate(solns) 147 | proj_x, _ = soc_projector(x_jnp) 148 | assert tree_allclose(proj_x, proj_x_solver) 149 | _test_dproj_finite_diffs(soc_projector, keyfunc, dim=total_dim, num_batches=0) 150 | 151 | # --- batched --- 152 | x_jnp = jr.normal(keyfunc(), (num_batches, total_dim)) 153 | x_np = np.array(x_jnp) 154 | proj_x, _ = batched_soc_projector(x_jnp) 155 | for i in range(num_batches): 156 | start = 0 157 | solns = [] 158 | for dim in dims: 159 | end = start + dim 160 | solns.append(jnp.array(_proj_soc_via_cvxpy(x_np[i, start:end]))) 161 | start = end 162 | proj_x_solver = jnp.concatenate(solns) 163 | assert tree_allclose(proj_x[i, :], proj_x_solver) 164 | 165 | _test_dproj_finite_diffs(_soc_projector, keyfunc, dim=total_dim, num_batches=num_batches) 166 | 167 | 168 | def test_soc_projector_simple(getkey): 169 | dims = [10, 15, 30] 170 | num_batches = 10 171 | _test_soc_projector(dims, num_batches, getkey) 172 | 173 | 174 | def test_soc_projector_hard(getkey): 175 | dims = [5, 5, 5, 3, 3, 4, 5, 2, 2] 176 | num_batches = 10 177 | _test_soc_projector(dims, num_batches, getkey) 178 | 179 | 180 | def _proj_psd_via_cvxpy(x: np.ndarray) -> np.ndarray: 181 | """Project vectorized symmetric matrix x onto the PSD cone using CVXPY.""" 182 | size = cone_lib.symm_dim_to_size(x.size) 183 | X = np.zeros((size, size), dtype=x.dtype) 184 | idxs = np.triu_indices(size) 185 | sqrt2 = np.sqrt(2.0) 186 | X[idxs] = x / sqrt2 187 | X = X + X.T 188 | diag = np.arange(size) 189 | X[diag, diag] /= sqrt2 190 | 191 | z = cvx.Variable((size, size), PSD=True) 192 | objective = cvx.Minimize(cvx.sum_squares(z - X)) 193 | prob = 
cvx.Problem(objective) 194 | prob.solve(solver="SCS", eps=1e-10) 195 | Z_val = z.value 196 | vec = Z_val[idxs] 197 | off_diag = idxs[0] != idxs[1] 198 | vec[off_diag] *= sqrt2 199 | return vec 200 | 201 | 202 | def _test_psd_projector(sizes, num_batches, keyfunc): 203 | total_size = sum([cone_lib.symm_size_to_dim(s) for s in sizes]) 204 | 205 | _psd_projector = cone_lib.PSDConeProjector(sizes=sizes) 206 | psd_projector = eqx.filter_jit(_psd_projector) 207 | batched_psd_projector = eqx.filter_jit(eqx.filter_vmap(_psd_projector)) 208 | 209 | for _ in range(10): 210 | x_jnp = jr.normal(keyfunc(), total_size) 211 | x_np = np.array(x_jnp) 212 | start = 0 213 | solns = [] 214 | for size in sizes: 215 | end = start + cone_lib.symm_size_to_dim(size) 216 | solns.append(jnp.array(_proj_psd_via_cvxpy(x_np[start:end]))) 217 | start = end 218 | proj_x_solver = jnp.concatenate(solns) 219 | proj_x, _ = psd_projector(x_jnp) 220 | assert tree_allclose(proj_x, proj_x_solver) 221 | 222 | # --- batched --- 223 | x_jnp = jr.normal(keyfunc(), (num_batches, total_size)) 224 | x_np = np.array(x_jnp) 225 | proj_x, _ = batched_psd_projector(x_jnp) 226 | for i in range(num_batches): 227 | start = 0 228 | solns = [] 229 | for size in sizes: 230 | end = start + cone_lib.symm_size_to_dim(size) 231 | solns.append(jnp.array(_proj_psd_via_cvxpy(x_np[i, start:end]))) 232 | start = end 233 | proj_x_solver = jnp.concatenate(solns) 234 | assert tree_allclose(proj_x[i, :], proj_x_solver) 235 | 236 | 237 | def test_psd_projector_simple(getkey): 238 | sizes = [3, 4, 10] 239 | num_batches = 5 240 | _test_psd_projector(sizes, num_batches, getkey) 241 | 242 | def test_psd_projector_hard(getkey): 243 | sizes = [2, 3, 3, 4, 4, 2] 244 | num_batches = 5 245 | _test_psd_projector(sizes, num_batches, getkey) 246 | 247 | def _proj_pow_via_cvxpy(x: np.ndarray, alphas: list[float]) -> np.ndarray: 248 | """Project x onto the product of 3D power cones with given alphas using CVXPY.""" 249 | n = len(x) 250 | assert n % 3 == 0 251 | num_cones = n // 3 252 | var = cvx.Variable(n) 253 | constraints = [] 254 | for i in range(num_cones): 255 | constraints.append(cvx.PowCone3D(var[3*i], var[3*i+1], var[3*i+2], alphas[i])) 256 | objective = cvx.Minimize(cvx.sum_squares(var - x)) 257 | prob = cvx.Problem(objective, constraints) 258 | prob.solve(solver="SCS", eps=1e-10) 259 | return np.array(var.value) 260 | 261 | def test_proj_pow(): 262 | np.random.seed(0) 263 | n = 3 264 | alphas = np.random.uniform(low=0, high=1, size=15) 265 | for alpha in alphas: 266 | x = np.random.randn(n) 267 | proj_cvx = _proj_pow_via_cvxpy(x, [alpha]) 268 | projector = cone_lib.PowerConeProjector([alpha], onto_dual=False) 269 | # this is not efficient since recompiling; just doing for testing. 
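        # A non-recompiling alternative (a sketch only, not exercised here; `xs` is a
        # hypothetical list of the sampled points): build a single projector over all
        # 15 alphas and project the concatenated vectors in one call, e.g.
        #   projector = cone_lib.PowerConeProjector(list(alphas), onto_dual=False)
        #   proj_all, _ = eqx.filter_jit(projector)(jnp.concatenate(xs))
        # then compare slices of `proj_all` against the per-cone CVXPY projections.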
270 | proj_jax, _ = eqx.filter_jit(projector)(jnp.array(x)) 271 | proj_jax = np.array(proj_jax) 272 | print("proj_jax: ", proj_jax) 273 | print("proj_cvx: ", proj_cvx) 274 | assert np.allclose(proj_jax, proj_cvx, atol=1e-6, rtol=1e-7) 275 | 276 | # def test_proj_pow_diffcpish(): 277 | # # TODO(quill): test itself needs fixing 278 | # np.random.seed(0) 279 | # alphas1 = np.random.uniform(low=0.01, high=1, size=15) 280 | # alphas2 = np.random.uniform(low=0.01, high=1, size=15) 281 | # alphas3 = np.random.uniform(low=0.01, high=1, size=15) 282 | # for i in range(alphas1.shape[0]): 283 | # x = np.random.randn(9) 284 | # # primal 285 | # proj_cvx = _proj_pow_via_cvxpy(x, [alphas1[i], alphas2[i], alphas3[i]]) 286 | # projector = cone_lib.PowerConeProjector([alphas1[i], alphas2[i], alphas3[i]], onto_dual=False) 287 | # proj_jax, _ = projector(jnp.array(x)) 288 | # assert np.allclose(np.array(proj_jax), proj_cvx, atol=1e-4, rtol=1e-7) 289 | # # dual 290 | # proj_dual = cone_lib.PowerConeProjector([-alphas1[i], -alphas2[i], -alphas3[i]], onto_dual=False) 291 | # proj_cvx_dual = _proj_pow_via_cvxpy(-x, [-alphas1[i], -alphas2[i], -alphas3[i]]) 292 | # proj_jax_dual, _ = proj_dual(jnp.array(x)) 293 | # # Moreau: Pi_K^*(v) = v + Pi_K(-v) 294 | # assert np.allclose(np.array(proj_jax_dual), x + proj_cvx_dual, atol=1e-4) 295 | 296 | def test_proj_pow_specific(): 297 | n = 3 298 | x = np.array([1., 2., 3.]) 299 | alpha = 0.6 300 | proj_cvx = _proj_pow_via_cvxpy(x, [alpha]) 301 | projector = cone_lib.PowerConeProjector([alpha], onto_dual=False) 302 | proj_jax, _ = projector(jnp.array(x)) 303 | proj_jax = np.array(proj_jax) 304 | print("proj_jax: ", proj_jax) 305 | print("proj_cvx:", proj_cvx) 306 | assert np.allclose(np.array(proj_jax), proj_cvx, atol=1e-6, rtol=1e-7) 307 | 308 | 309 | def test_product_projector(getkey): 310 | """assumes that the other tests in this file pass.""" 311 | zero_dim = 15 312 | nn_dim = 23 313 | soc_dims = [5, 5, 5, 3, 3, 4, 5, 2, 2] 314 | soc_total_dim = sum(soc_dims) 315 | total_dim = zero_dim + nn_dim + soc_total_dim 316 | num_batches = 10 317 | cones = { 318 | cone_lib.ZERO : zero_dim, 319 | cone_lib.NONNEGATIVE: nn_dim, 320 | cone_lib.SOC : soc_dims 321 | } 322 | 323 | _nn_projector = cone_lib.NonnegativeConeProjector() 324 | nn_projector = eqx.filter_jit(_nn_projector) 325 | batched_nn_projector = eqx.filter_jit(eqx.filter_vmap(_nn_projector)) 326 | 327 | _soc_projector = cone_lib.SecondOrderConeProjector(dims=soc_dims) 328 | soc_projector = eqx.filter_jit(_soc_projector) 329 | batched_soc_projector = eqx.filter_jit(eqx.filter_vmap(_soc_projector)) 330 | 331 | for dual in [True, False]: 332 | 333 | _zero_projector = cone_lib.ZeroConeProjector(onto_dual=dual) 334 | zero_projector = eqx.filter_jit(_zero_projector) 335 | batched_zero_projector = eqx.filter_jit(eqx.filter_vmap(_zero_projector)) 336 | 337 | _cone_projector = cone_lib.ProductConeProjector(cones, onto_dual=dual) 338 | cone_projector = eqx.filter_jit(_cone_projector) 339 | batched_cone_projector = eqx.filter_jit(eqx.filter_vmap(_cone_projector)) 340 | 341 | for _ in range(15): 342 | x = jr.normal(getkey(), total_dim) 343 | proj_x, _ = cone_projector(x) 344 | proj_x_zero, _ = zero_projector(x[0:zero_dim]) 345 | proj_x_nn, _ = nn_projector(x[zero_dim:zero_dim+nn_dim]) 346 | proj_x_soc, _ = soc_projector(x[zero_dim+nn_dim:zero_dim+nn_dim+soc_total_dim]) 347 | proj_x_handmade = jnp.concatenate([proj_x_zero, 348 | proj_x_nn, 349 | proj_x_soc]) 350 | assert tree_allclose(proj_x, proj_x_handmade) 351 | 
_test_dproj_finite_diffs(cone_projector, getkey, dim=total_dim, num_batches=0) 352 | 353 | # --- batched --- 354 | x = jr.normal(getkey(), (num_batches, total_dim)) 355 | proj_x, _ = batched_cone_projector(x) 356 | proj_x_zero, _ = batched_zero_projector(x[:, 0:zero_dim]) 357 | proj_x_nn, _ = batched_nn_projector(x[:, zero_dim:zero_dim+nn_dim]) 358 | proj_x_soc, _ = batched_soc_projector(x[:, zero_dim+nn_dim:zero_dim+nn_dim+soc_total_dim]) 359 | proj_x_handmade = jnp.concatenate([proj_x_zero, 360 | proj_x_nn, 361 | proj_x_soc], axis=-1) 362 | assert tree_allclose(proj_x, proj_x_handmade) 363 | _test_dproj_finite_diffs(cone_projector, getkey, dim=total_dim, num_batches=num_batches) 364 | 365 | 366 | def test_in_exp(getkey): 367 | in_vecs = [[0., 0., 1.], [-1., 0., 0.], [1., 1., 5.]] 368 | for vec in in_vecs: 369 | assert in_exp(jnp.array(vec)) 370 | not_in_vecs = [[1., 0., 0.], [-1., -1., 1.], [-1., 0., -1.]] 371 | for vec in not_in_vecs: 372 | assert not in_exp(jnp.array(vec)) 373 | 374 | 375 | def test_in_exp_dual(getkey): 376 | in_vecs = [[0., 1., 1.], [-1., 1., 5.]] 377 | not_in_vecs = [[0., -1., 1.], [0., 1., -1.]] 378 | for vec in in_vecs: 379 | arr = jnp.array(vec) 380 | assert in_exp_dual(arr) 381 | for vec in not_in_vecs: 382 | arr = jnp.array(vec) 383 | assert not in_exp_dual(vec) 384 | 385 | 386 | def test_proj_exp_scs(getkey): 387 | """test values ported from scs/test/problems/test_exp_cone.h 388 | """ 389 | vs = [jnp.array([1.0, 2.0, 3.0]), 390 | jnp.array([0.14814832, 1.04294573, 0.67905585]), 391 | jnp.array([-0.78301134, 1.82790084, -1.05417044]), 392 | jnp.array([1.3282585, -0.43277314, 1.7468072]), 393 | jnp.array([0.67905585, 0.14814832, 1.04294573]), 394 | jnp.array([0.50210027, 0.12314491, -1.77568921])] 395 | 396 | num_cones = len(vs) 397 | 398 | vp_true = [jnp.array([0.8899428, 1.94041881, 3.06957226]), 399 | jnp.array([-0.02001571, 0.8709169, 0.85112944]), 400 | jnp.array([-1.17415616, 0.9567094, 0.280399]), 401 | jnp.array([0.53160512, 0.2804836, 1.86652094]), 402 | jnp.array([0.38322814, 0.27086569, 1.11482228]), 403 | jnp.array([0.0, 0.0, 0.0])] 404 | vd_true = [jnp.array([-0., 2., 3.]), 405 | jnp.array([-0., 1.04294573, 0.67905585]), 406 | jnp.array([-0.68541419, 1.85424082, 0.01685653]), 407 | jnp.array([-0.02277033, -0.12164823, 1.75085347]), 408 | jnp.array([-0., 0.14814832, 1.04294573]), 409 | jnp.array([-0., 0.12314491, -0.])] 410 | 411 | primal_projector = ExponentialConeProjector(1, onto_dual=False) 412 | dual_projector = ExponentialConeProjector(1, onto_dual=True) 413 | 414 | import diffcp._diffcp as _diffcp 415 | from diffcp.cones import parse_cone_dict_cpp 416 | cones = [("ep", 1)] 417 | cones = parse_cone_dict_cpp(cones) 418 | 419 | for i in range(len(vs)): 420 | print(f"=== trial {i} ===") 421 | v = vs[i] 422 | vp, Jp = jit(primal_projector)(v) 423 | vd, Jd = jit(dual_projector)(v) 424 | assert jnp.allclose(vp, vp_true[i]) 425 | assert jnp.allclose(vd, vd_true[i]) 426 | _test_dproj_finite_diffs(primal_projector, getkey, dim=3) 427 | J_diffcp = _diffcp.dprojection(np.array(v), cones, False) 428 | e1, e2, e3 = np.array([1., 0., 0.]), np.array([0., 1., 0.]), np.array([0., 0., 1.]) 429 | col1 = J_diffcp.matvec(e1) 430 | col2 = J_diffcp.matvec(e2) 431 | col3 = J_diffcp.matvec(e3) 432 | J_materialized_diffcp = np.column_stack([col1, col2, col3]) 433 | assert np.allclose(J_materialized_diffcp, np.array(Jp.jacobians[0, ...])) 434 | J_diffcp = _diffcp.dprojection(np.array(v), cones, True) 435 | col1 = J_diffcp.matvec(e1) 436 | col2 = J_diffcp.matvec(e2) 437 | 
col3 = J_diffcp.matvec(e3) 438 | J_materialized_diffcp = np.column_stack([col1, col2, col3]) 439 | assert np.allclose(J_materialized_diffcp, np.array(Jd.jacobians[0, ...])) 440 | _test_dproj_finite_diffs(dual_projector, getkey, dim=3) 441 | 442 | # Now test batched 443 | vps, _ = vmap(primal_projector)(jnp.array(vs)) 444 | vds, _ = vmap(dual_projector)(jnp.array(vs)) 445 | 446 | for i in range(len(vs)): 447 | assert jnp.allclose(vps[i, :], vp_true[i]) 448 | assert jnp.allclose(vds[i, :], vd_true[i]) 449 | 450 | # now test with num_cones > 1 for single projector 451 | vs = jnp.concatenate(vs) 452 | vp_true = jnp.concatenate(vp_true) 453 | vd_true = jnp.concatenate(vd_true) 454 | 455 | primal_projector = ExponentialConeProjector(num_cones, onto_dual=False) 456 | dual_projector = ExponentialConeProjector(num_cones, onto_dual=True) 457 | 458 | vps, _ = primal_projector(vs) 459 | vds, _ = dual_projector(vs) 460 | 461 | assert jnp.allclose(vps, vp_true) 462 | assert jnp.allclose(vds, vd_true) 463 | _test_dproj_finite_diffs(primal_projector, getkey, dim=3*num_cones) 464 | _test_dproj_finite_diffs(dual_projector, getkey, dim=3*num_cones) -------------------------------------------------------------------------------- /diffqcp/qcp.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Callable 3 | import functools as ft 4 | import jax 5 | from jax import eval_shape 6 | import jax.numpy as jnp 7 | import equinox as eqx 8 | import lineax as lx 9 | from lineax import AbstractLinearOperator, IdentityLinearOperator, linear_solve 10 | from jaxtyping import Float, Array 11 | from jax.experimental.sparse import BCOO, BCSR 12 | try: 13 | from cupy import from_dlpack as cp_from_dlpack 14 | from cupyx.scipy.sparse import csr_matrix 15 | from nvmath.sparse.advanced import DirectSolver, DirectSolverAlgType 16 | except ImportError: 17 | cp_from_dlpack = None 18 | csr_matrix = None 19 | DirectSolver = None 20 | DirectSolverAlgType = None 21 | 22 | from diffqcp.problem_data import (QCPStructureCPU, QCPStructureGPU, 23 | QCPStructure, ObjMatrixCPU, ObjMatrixGPU, ObjMatrix) 24 | from diffqcp.linops import _BlockLinearOperator 25 | from diffqcp.qcp_derivs import (_DuQ, _d_data_Q, _d_data_Q_adjoint_cpu, _d_data_Q_adjoint_gpu) 26 | 27 | class AbstractQCP(eqx.Module): 28 | """Quadratic Cone Program. 29 | 30 | Represents a (solved) quadratic convex cone program given 31 | by the primal-dual problems 32 | 33 | (P) minimize (1/2)x^T P x + q^T x 34 | subject to Ax + s = b 35 | s in K 36 | 37 | (D) minimize -(1/2)x^T P x - b^T y 38 | subject to A^T y + q = 0 39 | y in K^*, 40 | 41 | where P, A, q, b are mutable problem data, K and K^* are 42 | immutable problem data, and (x, y, s) are the optimization 43 | variables. 
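    For derivative computations, subclasses differentiate the residual map
    R(z) = Q(Pi(z)) - Pi(z) + z at z = (x, y - s, 1), where Pi projects the
    middle block onto K^* and fixes the last component at 1; `_form_atoms`
    assembles DR(z) = DQ(Pi(z)) DPi(z) - DPi(z) + I, which is the operator the
    JVP (and, analogously, VJP) routines solve linear systems with.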
44 | """ 45 | 46 | P: eqx.AbstractVar[ObjMatrix] 47 | A: eqx.AbstractVar[BCSR | BCOO] 48 | AT: eqx.AbstractVar[BCSR | BCOO] 49 | q: eqx.AbstractVar[Array] 50 | b: eqx.AbstractVar[Array] 51 | x: eqx.AbstractVar[Array] 52 | y: eqx.AbstractVar[Array] 53 | s: eqx.AbstractVar[Array] 54 | problem_structure: eqx.AbstractVar[QCPStructure] 55 | 56 | def _form_atoms(self) -> tuple[Float[Array, " n+m+1"], AbstractLinearOperator, AbstractLinearOperator]: 57 | proj_kstar_v, dproj_kstar_v = self.problem_structure.cone_projector(self.y - self.s) 58 | pi_z = jnp.concatenate([self.x, proj_kstar_v, jnp.array([1.0], dtype=self.x.dtype)]) 59 | dpi_z = _BlockLinearOperator([IdentityLinearOperator(eval_shape(lambda: self.x)), 60 | dproj_kstar_v, 61 | IdentityLinearOperator(eval_shape(lambda: jnp.array([1.0])))]) 62 | Px = self.P.mv(self.x) 63 | xTPx = self.x @ Px 64 | AT = self.AT 65 | # NOTE(quill): seems hard to avoid the `DzQ` bit of the variable name. 66 | # NOTE(quill): Note that we're skipping the step of extracting the first n components of 67 | # `pi_z` and just using `P @ pi_z[:n] = P @ x`. 68 | DzQ_pi_z = _DuQ(P=self.P, Px=Px, xTPx=xTPx, A=self.A, AT=AT, q=self.q, 69 | b=self.b, x=self.x, tau=jnp.array(1.0, dtype=self.x.dtype), 70 | n=self.problem_structure.n, m=self.problem_structure.m) 71 | # NOTE(quill): we use that z_N (as defined in paper) is always 1.0, thus don't 72 | # include that division. 73 | F = (DzQ_pi_z @ dpi_z) - dpi_z + IdentityLinearOperator(eval_shape(lambda: pi_z)) 74 | 75 | return (pi_z, F, dproj_kstar_v) 76 | 77 | @eqx.filter_jit 78 | def _jvp_direct_solve_get_F(self, F) -> Float[Array, "N N"]: 79 | """For prototyping purposes; obviously not efficient. 80 | """ 81 | 82 | def _get_dense_mat(mat: lx.AbstractLinearOperator): 83 | mv = lambda vec: mat.mv(vec) 84 | mm = jax.vmap(mv, in_axes=1, out_axes=1) 85 | return mm(jnp.eye(self.problem_structure.N)) 86 | 87 | return _get_dense_mat(F) 88 | 89 | def _jvp_common( 90 | self, 91 | dP: ObjMatrix, 92 | dA: Float[BCOO | BCSR, "m n"], 93 | dAT: Float[BCOO | BCSR, "n m"], 94 | dq: Float[Array, " n"], 95 | db: Float[Array, " m"], 96 | solve_method: str = "jax-lsmr" 97 | ) -> tuple[Float[Array, " n"], Float[Array, " m"], Float[Array, " m"]]: 98 | pi_z, F, dproj_kstar_v = self._form_atoms() 99 | 100 | n, m = self.problem_structure.n, self.problem_structure.m 101 | pi_z_n, pi_z_m, pi_z_N = pi_z[:n], pi_z[n:n+m], pi_z[-1] 102 | d_data_N = _d_data_Q(x=pi_z_n, y=pi_z_m, tau=pi_z_N, dP=dP, 103 | dA=dA, dAT=dAT, dq=dq, db=db) 104 | 105 | def zero_case(): 106 | return jnp.zeros_like(d_data_N) 107 | 108 | def nonzero_case(): 109 | if solve_method == "jax-lsmr": 110 | try: 111 | from lineax import LSMR 112 | except ImportError: 113 | raise ValueError("In your current environment the LSMR solve is not available.") 114 | soln = linear_solve(F, -d_data_N, solver=LSMR(rtol=1e-8, atol=1e-8)) 115 | return soln.value 116 | else: 117 | F_dense = self._jvp_direct_solve_get_F(F) 118 | soln = linear_solve(lx.MatrixLinearOperator(F_dense), -d_data_N) 119 | return soln.value 120 | 121 | dz = jax.lax.cond(jnp.allclose(d_data_N, 0), 122 | zero_case, 123 | nonzero_case) 124 | 125 | dz_n, dz_m, dz_N = dz[:n], dz[n:n+m], dz[-1] 126 | dx = dz_n - self.x * dz_N 127 | dproj_k_star_v_dz_m = dproj_kstar_v.mv(dz_m) 128 | dy = dproj_k_star_v_dz_m - self.y * dz_N 129 | ds = dproj_k_star_v_dz_m - dz_m - self.s * dz_N 130 | return dx, dy, ds 131 | 132 | @abstractmethod 133 | def jvp( 134 | self, 135 | dP: Float[BCOO | BCSR, "n n"], 136 | dA: Float[BCOO | BCSR, "m n"], 137 | dq: 
Float[Array, " n"], 138 | db: Float[Array, " m"], 139 | solve_method: str = "jax-lsmr" 140 | ) -> tuple[Float[Array, " n"], Float[Array, " m"], Float[Array, " m"]]: 141 | """Apply the derivative of the QCP's solution map to an input perturbation. 142 | """ 143 | raise NotImplementedError 144 | 145 | @eqx.filter_jit 146 | def _vjp_direct_solve_get_FT(self, F) -> Float[Array, "N N"]: 147 | """For prototyping purposes; obviously not efficient. 148 | 149 | NOTE(quill): same innards as `_jvp_direct_solve_get_F`, but keeping separate for now 150 | since how we efficiently materialize the operators may vary? 151 | """ 152 | 153 | def _get_dense_mat(mat: lx.AbstractLinearOperator): 154 | mv = lambda vec: mat.mv(vec) 155 | mm = jax.vmap(mv, in_axes=1, out_axes=1) 156 | return mm(jnp.eye(self.problem_structure.N)) 157 | 158 | return _get_dense_mat(F.T) 159 | 160 | def _vjp_common( 161 | self, 162 | dx: Float[Array, " n"], 163 | dy: Float[Array, " m"], 164 | ds: Float[Array, " m"], 165 | produce_output: Callable, 166 | solve_method: str = "jax-lsmr" 167 | ) -> tuple[ 168 | Float[BCOO | BCSR, "n n"], Float[BCOO | BCSR, "m n"], 169 | Float[Array, " n"], Float[Array, " m"]]: 170 | n, m = self.problem_structure.n, self.problem_structure.m 171 | pi_z, F, dproj_kstar_v = self._form_atoms() 172 | dz = jnp.concatenate([dx, 173 | dproj_kstar_v.mv(dy + ds) - ds, 174 | - jnp.array([self.x @ dx + self.y @ dy + self.s @ ds])] 175 | ) 176 | 177 | def zero_case(): 178 | return jnp.zeros_like(dz) 179 | 180 | def nonzero_case(): 181 | if solve_method == "jax-lsmr": 182 | try: 183 | from lineax import LSMR 184 | except ImportError: 185 | raise ValueError("In your current environment the LSMR solve is not available.") 186 | soln = linear_solve(F.T, -dz, solver=LSMR(rtol=1e-8, atol=1e-8)) 187 | return soln.value 188 | else: 189 | FT = self._vjp_direct_solve_get_FT(F) 190 | soln = linear_solve(lx.MatrixLinearOperator(FT), -dz) 191 | return soln.value 192 | 193 | d_data_N = jax.lax.cond(jnp.allclose(dz, 0), 194 | zero_case, 195 | nonzero_case) 196 | 197 | pi_z_n = pi_z[:n] 198 | pi_z_m = pi_z[n:n+m] 199 | pi_z_N = pi_z[-1] 200 | d_data_N_n = d_data_N[:n] 201 | d_data_N_m = d_data_N[n:n+m] 202 | d_data_N_N = d_data_N[-1] 203 | 204 | return produce_output(x=pi_z_n, y=pi_z_m, tau=pi_z_N, 205 | w1=d_data_N_n, w2=d_data_N_m, w3=d_data_N_N) 206 | 207 | @abstractmethod 208 | def vjp( 209 | self, 210 | dx: Float[Array, " n"], 211 | dy: Float[Array, " m"], 212 | ds: Float[Array, " m"], 213 | solve_method: str = "jax-lsmr" 214 | ) -> tuple[ 215 | Float[BCOO | BCSR, "n n"], Float[BCOO | BCSR, "m n"], 216 | Float[Array, " n"], Float[Array, " m"]]: 217 | """Apply the adjoint of the derivative of the QCP's solution map to a solution perturbation. 218 | """ 219 | raise NotImplementedError 220 | 221 | 222 | class HostQCP(AbstractQCP): 223 | """QCP whose subroutines are optimized to run on host (CPU). 224 | """ 225 | P: ObjMatrixCPU 226 | A: Float[BCOO, "m n"] 227 | AT: Float[BCOO, "n m"] 228 | q: Float[Array, " n"] 229 | b: Float[Array, " m"] 230 | x: Float[Array, " n"] 231 | y: Float[Array, " m"] 232 | s: Float[Array, " m"] 233 | 234 | problem_structure: QCPStructureCPU 235 | 236 | def __init__( 237 | self, 238 | P: Float[BCOO, "n n"], 239 | A: Float[BCOO, "m n"], 240 | q: Float[Array, " n"], 241 | b: Float[Array, " m"], 242 | x: Float[Array, " n"], 243 | y: Float[Array, " m"], 244 | s: Float[Array, " m"], 245 | problem_structure: QCPStructureCPU 246 | ): 247 | """**Arguments:** 248 | - `P`: BCOO, shape (n, n). 
The quadratic objective matrix. Must be symmetric and provided in sparse BCOO format.
249 | Only the upper triangular part is required and used for efficiency.
250 | - `A`: BCOO, shape (m, n). The constraint matrix in sparse BCOO format.
251 | - `q`: ndarray, shape (n,). The linear objective vector.
252 | - `b`: ndarray, shape (m,). The constraint vector.
253 | - `x`: ndarray, shape (n,). The primal solution vector.
254 | - `y`: ndarray, shape (m,). The dual solution vector.
255 | - `s`: ndarray, shape (m,). The primal slack variable.
256 | - `problem_structure`: QCPStructureCPU. Structure object containing metadata about the problem, including sparsity patterns (such as the nonzero row and column indices for P and A), and cone information.
257 | 
258 | **Notes:**
259 | - The sparsity structure of `P` and `A` must match that described in `problem_structure`.
260 | - `P` should only contain the upper triangular part of the matrix.
261 | - All arrays should be on the host (CPU) and compatible with JAX operations.
262 | """
263 | self.A, self.q, self.b = A, q, b
264 | self.AT = A.T
265 | self.x, self.y, self.s = x, y, s
266 | self.problem_structure = problem_structure
267 | self.P = self.problem_structure.form_obj(P)
268 | 
269 | def jvp(
270 | self,
271 | dP: Float[BCOO, "n n"],
272 | dA: Float[BCOO, "m n"],
273 | dq: Float[Array, " n"],
274 | db: Float[Array, " m"],
275 | solve_method: str = "jax-lsmr"
276 | ) -> tuple[Float[Array, " n"], Float[Array, " m"], Float[Array, " m"]]:
277 | """Apply the derivative of the QCP's solution map to an input perturbation.
278 | 
279 | Specifically, an implementation of the method given in section 3.1 of the paper.
280 | 
281 | **Arguments:**
282 | - `dP`: should have the same sparsity structure as `P`. *Note* that
283 | this means it should only contain the upper triangular part of `dP`.
284 | - `dA`: should have the same sparsity structure as `A`.
285 | - `dq`
286 | - `db`
287 | 
288 | **Returns:**
289 | 
290 | A 3-tuple containing the perturbations to the solution: `(dx, dy, ds)`.
291 | """
292 | # NOTE(quill): this implementation is identical to `DeviceQCP`'s implementation
293 | # minus the `dAT = dA.T`.
294 | # Can this be consolidated / does it indicate incorrect design decision/execution?
295 | # => NOTE(quill): I've attempted to address this annoyance with `_jvp_common`.
296 | dAT = dA.T
297 | dP = self.problem_structure.form_obj(dP)
298 | # need to wrap dP.
299 | return self._jvp_common(dP=dP, dA=dA, dAT=dAT, dq=dq, db=db, solve_method=solve_method)
300 | 
301 | def vjp(
302 | self,
303 | dx: Float[Array, " n"],
304 | dy: Float[Array, " m"],
305 | ds: Float[Array, " m"],
306 | solve_method: str = "jax-lsmr"
307 | ) -> tuple[
308 | Float[BCSR, "n n"], Float[BCSR, "m n"],
309 | Float[Array, " n"], Float[Array, " m"]]:
310 | """Apply the adjoint of the derivative of the QCP's solution map to a solution perturbation.
311 | 
312 | Specifically, an implementation of the method given in section 3.2 of the paper.
313 | 
314 | **Arguments:**
315 | - `dx`: A perturbation to the primal solution.
316 | - `dy`: A perturbation to the dual solution.
317 | - `ds`: A perturbation to the primal slack solution.
318 | 
319 | **Returns:**
320 | 
321 | A four-tuple containing the perturbations to the objective matrix, constraint matrix,
322 | linear cost function vector, and constraint vector. Note that these perturbation matrices
323 | will have the same sparsity patterns as their corresponding problem matrices. (So, importantly,
324 | the first matrix will only contain the upper triangular part of the true perturbation to the
325 | objective matrix.)
326 | """
327 | # NOTE(quill): This is a similar note to the one I left in this class's `jvp`. That is, this
328 | # implementation is identical to `DeviceQCP`'s `vjp` minus the function call at the very bottom.
329 | # Can this be consolidated / does it indicate incorrect design decision/execution?
330 | 
331 | partial_d_data_Q_adjoint_cpu = ft.partial(_d_data_Q_adjoint_cpu,
332 | P_rows=self.problem_structure.P_nonzero_rows,
333 | P_cols=self.problem_structure.P_nonzero_cols,
334 | A_rows=self.problem_structure.A_nonzero_rows,
335 | A_cols=self.problem_structure.A_nonzero_cols,
336 | n=self.problem_structure.n,
337 | m=self.problem_structure.m)
338 | 
339 | return self._vjp_common(dx=dx, dy=dy, ds=ds,
340 | produce_output=partial_d_data_Q_adjoint_cpu,
341 | solve_method=solve_method)
342 | 
343 | 
344 | class DeviceQCP(AbstractQCP):
345 | """QCP whose subroutines are optimized to run on device (GPU).
346 | """
347 | # NOTE(quill): when we allow for batched problem data, will need
348 | # to wrap `P` in an `AbstractLinearOperator` to dictate how the `mv`
349 | # operation should behave.
350 | P: ObjMatrixGPU
351 | A: Float[BCSR, "m n"]
352 | AT: Float[BCSR, "n m"]
353 | q: Float[Array, " n"]
354 | b: Float[Array, " m"]
355 | x: Float[Array, " n"]
356 | y: Float[Array, " m"]
357 | s: Float[Array, " m"]
358 | 
359 | problem_structure: QCPStructureGPU
360 | 
361 | def __init__(
362 | self,
363 | P: Float[BCSR, "n n"],
364 | A: Float[BCSR, "m n"],
365 | q: Float[Array, " n"],
366 | b: Float[Array, " m"],
367 | x: Float[Array, " n"],
368 | y: Float[Array, " m"],
369 | s: Float[Array, " m"],
370 | problem_structure: QCPStructureGPU
371 | ):
372 | """**Arguments:**
373 | - `P`: BCSR, shape (n, n). The quadratic objective matrix in sparse BCSR format.
374 | Must be symmetric. For device execution, the full matrix (not just upper triangular) is required.
375 | - `A`: BCSR, shape (m, n). The constraint matrix in sparse BCSR format.
376 | - `q`: ndarray, shape (n,). The linear objective vector.
377 | - `b`: ndarray, shape (m,). The constraint vector.
378 | - `x`: ndarray, shape (n,). The primal solution vector.
379 | - `y`: ndarray, shape (m,). The dual solution vector.
380 | - `s`: ndarray, shape (m,). The primal slack variable.
381 | - `problem_structure`: QCPStructureGPU. Structure object containing metadata about the problem, including sparsity patterns
382 | (such as the nonzero row and column indices for P and A), and cone information.
383 | 
384 | **Notes:**
385 | - The sparsity structure of `P` and `A` must match that described in `problem_structure`.
386 | - `P` should contain the full symmetric matrix (not just upper triangular).
387 | - All arrays should be on the device (GPU) and compatible with JAX operations.
388 | """ 389 | self.problem_structure = problem_structure 390 | self.P = ObjMatrixGPU(P) 391 | self.A, self.q, self.b = A, q, b 392 | self.AT = self.problem_structure.form_A_transpose(self.A) 393 | self.x, self.y, self.s = x, y, s 394 | 395 | @eqx.filter_jit 396 | def _jvp_nvmath_form_atoms( 397 | self, 398 | dP: ObjMatrix, 399 | dA: Float[BCSR, "m n"], 400 | dAT: Float[BCSR, "n m"], 401 | dq: Float[Array, " n"], 402 | db: Float[Array, " m"] 403 | ) -> tuple[Float[Array, " N"], AbstractLinearOperator, AbstractLinearOperator]: 404 | n = self.problem_structure.n 405 | m = self.problem_structure.m 406 | pi_z, F, dproj_k_star_v = self._form_atoms() 407 | pi_z_n, pi_z_m, pi_z_N = pi_z[:n], pi_z[n:n+m], pi_z[-1] 408 | d_data_N = _d_data_Q(x=pi_z_n, y=pi_z_m, tau=pi_z_N, dP=dP, 409 | dA=dA, dAT=dAT, dq=dq, db=db) 410 | 411 | return -d_data_N, F, dproj_k_star_v 412 | 413 | def _jvp_nvmath_actual_solve(self, F, d_data_N_minus): 414 | # NOTE(quill): separating this out for timing purposes. 415 | # return nvmath.sparse.advanced.direct_solver(F, d_data_N_minus) 416 | 417 | with DirectSolver( 418 | F, 419 | d_data_N_minus 420 | ) as solver: 421 | 422 | config = solver.plan_config 423 | config.reordering_algorithm = DirectSolverAlgType.ALG_1 424 | 425 | solver.plan() 426 | solver.factorize() 427 | x = solver.solve() 428 | 429 | return x 430 | 431 | def _jvp_nvmath_direct_solve(self, F, d_data_N_minus): 432 | F_cupy_csr = csr_matrix(cp_from_dlpack(F)) 433 | d_data_N_minus_cupy = cp_from_dlpack(d_data_N_minus) 434 | dz_cupy = self._jvp_nvmath_actual_solve(F_cupy_csr, d_data_N_minus_cupy) 435 | dz = jax.dlpack.from_dlpack(dz_cupy) 436 | return dz 437 | 438 | @eqx.filter_jit 439 | def _jvp_nvmath_get_output(self, dz, dproj_kstar_v): 440 | n = self.problem_structure.n 441 | m = self.problem_structure.m 442 | 443 | dz_n, dz_m, dz_N = dz[:n], dz[n:n+m], dz[-1] 444 | dx = dz_n - self.x * dz_N 445 | dproj_k_star_v_dz_m = dproj_kstar_v.mv(dz_m) 446 | dy = dproj_k_star_v_dz_m - self.y * dz_N 447 | ds = dproj_k_star_v_dz_m - dz_m - self.s * dz_N 448 | return dx, dy, ds 449 | 450 | def _jvp_nvmath( 451 | self, 452 | dP: ObjMatrix, 453 | dA: Float[BCSR, "m n"], 454 | dAT: Float[BCSR, "n m"], 455 | dq: Float[Array, " n"], 456 | db: Float[Array, " m"] 457 | ): 458 | d_data_N_minus, F, dproj_k_star_v = self._jvp_nvmath_form_atoms(dP, dA, dAT, dq, db) 459 | 460 | # `_jvp_direct_solve` cannot be jitted, so can use regular 461 | # Python control flow 462 | # TODO(quill): use a norm tolerance instead? 463 | if jnp.allclose(d_data_N_minus, 0): 464 | return jnp.zeros_like(d_data_N_minus) 465 | else: 466 | F = self._jvp_direct_solve_get_F(F) 467 | dz = self._jvp_nvmath_direct_solve(F, d_data_N_minus) 468 | 469 | return self._jvp_nvmath_get_output(dz, dproj_k_star_v) 470 | 471 | def jvp( 472 | self, 473 | dP: Float[BCSR, "n n"], 474 | dA: Float[BCSR, "m n"], 475 | dq: Float[Array, " n"], 476 | db: Float[Array, " m"], 477 | solve_method: str = "jax-lsmr" 478 | ) -> tuple[Float[Array, " n"], Float[Array, " m"], Float[Array, " m"]]: 479 | """Apply the derivative of the QCP's solution map to an input perturbation. 480 | 481 | Specifically, an implementation of the method given in section 3.1 of the paper. 482 | 483 | **Arguments:** 484 | - `dP` should have the same sparsity structure as `P`. *Note* that 485 | this means it should only contain the entirety of `dP`. 486 | (i.e., not just the upper triangular part.) 487 | - `dA` should have the same sparsity structure as `A`. 
488 | - `dq`
489 | - `db`
490 | 
491 | **Returns:**
492 | 
493 | A 3-tuple containing the perturbations to the solution: `(dx, dy, ds)`.
494 | """
495 | dP = ObjMatrixGPU(dP)
496 | dAT = eqx.filter_jit(self.problem_structure.form_A_transpose)(dA)
497 | if solve_method in ["jax-lsmr", "jax-lu"]:
498 | return self._jvp_common(dP=dP, dA=dA, dAT=dAT, dq=dq, db=db, solve_method=solve_method)
499 | elif solve_method == "nvmath-direct":
500 | if DirectSolver is None:
501 | raise ValueError("The `nvmath-direct` option can only be used when "
502 | "`nvmath-python` is installed. Also check that CuPy is "
503 | "installed.")
504 | return self._jvp_nvmath(dP=dP, dA=dA, dAT=dAT, dq=dq, db=db)
505 | else:
506 | raise ValueError(f"Solve method \"{solve_method}\" is not supported. "
507 | "The options are \"jax-lsmr\", \"nvmath-direct\", and "
508 | "\"jax-lu\".")
509 | 
510 | @eqx.filter_jit
511 | def _vjp_nvmath_form_atoms(
512 | self,
513 | dx: Float[Array, " n"],
514 | dy: Float[Array, " m"],
515 | ds: Float[Array, " m"]
516 | ):
517 | pi_z, F, dproj_kstar_v = self._form_atoms()
518 | dz = jnp.concatenate([dx,
519 | dproj_kstar_v.mv(dy + ds) - ds,
520 | - jnp.array([self.x @ dx + self.y @ dy + self.s @ ds])]
521 | )
522 | return -dz, F, pi_z
523 | 
524 | def _vjp_nvmath_actual_solve(self, FT, dz_minus):
525 | # NOTE(quill): separating this out for timing purposes.
526 | # return nvmath.sparse.advanced.direct_solver(FT, dz_minus)
527 | 
528 | with DirectSolver(
529 | FT,
530 | dz_minus
531 | ) as solver:
532 | 
533 | config = solver.plan_config
534 | config.reordering_algorithm = DirectSolverAlgType.ALG_1
535 | 
536 | solver.plan()
537 | solver.factorize()
538 | x = solver.solve()
539 | 
540 | return x
541 | 
542 | def _vjp_nvmath_direct_solve(self, FT, dz_minus):
543 | # FT is a JAX Array (<=> it is materialized.)
544 | 
545 | # === some tinkering with preconditioner ===
546 | 
547 | # prec_jax = jnp.diag((jnp.diag(jnp.transpose(FT) @ FT))**(-1))
548 | # prec = cp_from_dlpack(prec_jax)
549 | # FT_cupy_csr = csr_matrix(cp_from_dlpack(prec_jax @ FT))
550 | # dz_minus_cupy = cp_from_dlpack(dz_minus)
551 | # d_data_N_cupy = self._vjp_actual_solve(FT_cupy_csr, prec @ dz_minus_cupy)
552 | 
553 | # === ===
554 | 
555 | FT_cupy_csr = csr_matrix(cp_from_dlpack(FT))
556 | dz_minus_cupy = cp_from_dlpack(dz_minus)
557 | d_data_N_cupy = self._vjp_nvmath_actual_solve(FT_cupy_csr, dz_minus_cupy)
558 | d_data_N = jax.dlpack.from_dlpack(d_data_N_cupy)
559 | return d_data_N
560 | 
561 | @eqx.filter_jit()
562 | def _vjp_nvmath_get_output(
563 | self,
564 | pi_z,
565 | d_data_N
566 | ):
567 | n = self.problem_structure.n
568 | m = self.problem_structure.m
569 | 
570 | pi_z_n = pi_z[:n]
571 | pi_z_m = pi_z[n:n+m]
572 | pi_z_N = pi_z[-1]
573 | d_data_N_n = d_data_N[:n]
574 | d_data_N_m = d_data_N[n:n+m]
575 | d_data_N_N = d_data_N[-1]
576 | 
577 | return _d_data_Q_adjoint_gpu(
578 | x=pi_z_n,
579 | y=pi_z_m,
580 | tau=pi_z_N,
581 | w1=d_data_N_n,
582 | w2=d_data_N_m,
583 | w3=d_data_N_N,
584 | P_rows=self.problem_structure.P_nonzero_rows,
585 | P_cols=self.problem_structure.P_nonzero_cols,
586 | P_csr_indices=self.problem_structure.P_csr_indices,
587 | P_csr_indtpr=self.problem_structure.P_csr_indptr,
588 | A_rows=self.problem_structure.A_nonzero_rows,
589 | A_cols=self.problem_structure.A_nonzero_cols,
590 | A_csr_indices=self.problem_structure.A_csr_indices,
591 | A_csr_indtpr=self.problem_structure.A_csr_indptr,
592 | n=n,
593 | m=m
594 | )
595 | 
596 | def _vjp_nvmath(
597 | self,
598 | dx: Float[Array, " n"],
599 | dy: Float[Array, " m"],
600 | ds: Float[Array, " m"]
601 | ):
602 | dz_minus, F, pi_z = self._vjp_nvmath_form_atoms(dx, dy, ds)
603 | 
604 | # now check if 0 or not. `_vjp_nvmath` cannot be jitted, so we can
605 | # just use typical Python control flow
606 | if jnp.allclose(dz_minus, 0):
607 | d_data_N = jnp.zeros_like(dz_minus)
608 | else:
609 | # obtain FT
610 | FT = self._vjp_direct_solve_get_FT(F)
611 | d_data_N = self._vjp_nvmath_direct_solve(FT, dz_minus)
612 | 
613 | return self._vjp_nvmath_get_output(pi_z, d_data_N)
614 | 
615 | 
616 | def vjp(
617 | self,
618 | dx: Float[Array, " n"],
619 | dy: Float[Array, " m"],
620 | ds: Float[Array, " m"],
621 | solve_method: str = "jax-lu"
622 | ) -> tuple[
623 | Float[BCSR, "n n"], Float[BCSR, "m n"],
624 | Float[Array, " n"], Float[Array, " m"]]:
625 | """Apply the adjoint of the derivative of the QCP's solution map to a solution perturbation.
626 | 
627 | Specifically, an implementation of the method given in section 3.2 of the paper.
628 | 
629 | **Arguments:**
630 | - `dx`: A perturbation to the primal solution.
631 | - `dy`: A perturbation to the dual solution.
632 | - `ds`: A perturbation to the primal slack solution.
633 | - `solve_method` (str): The method used to solve the underlying linear system. Options are:
634 | - "jax-lsmr"
635 | - "jax-lu"
636 | - "nvmath-direct"
637 | 
638 | **Returns:**
639 | 
640 | A four-tuple containing the perturbations to the objective matrix, constraint matrix,
641 | linear cost function vector, and constraint vector. Note that these perturbation matrices
642 | will have the same sparsity patterns as their corresponding problem matrices.
643 | """ 644 | 645 | if solve_method in ["jax-lsmr", "jax-lu"]: 646 | partial_d_data_Q_adjoint_gpu = ft.partial(_d_data_Q_adjoint_gpu, 647 | P_rows=self.problem_structure.P_nonzero_rows, 648 | P_cols=self.problem_structure.P_nonzero_cols, 649 | P_csr_indices=self.problem_structure.P_csr_indices, 650 | P_csr_indtpr=self.problem_structure.P_csr_indptr, 651 | A_rows=self.problem_structure.A_nonzero_rows, 652 | A_cols=self.problem_structure.A_nonzero_cols, 653 | A_csr_indices=self.problem_structure.A_csr_indices, 654 | A_csr_indtpr=self.problem_structure.A_csr_indptr, 655 | n=self.problem_structure.n, 656 | m=self.problem_structure.m) 657 | 658 | return self._vjp_common(dx=dx, dy=dy, ds=ds, produce_output=partial_d_data_Q_adjoint_gpu, solve_method=solve_method) 659 | elif solve_method == "nvmath-direct": 660 | if DirectSolver is None: 661 | raise ValueError("The `nvmath-direct` option can only be used when " 662 | "`nvmath-python` is installed. Also check that CuPy is " 663 | "installed.") 664 | return self._vjp_nvmath(dx=dx, dy=dy, ds=ds) 665 | else: 666 | raise ValueError(f"Solve method \"{solve_method}\" is not specified. " 667 | " The options are \"jax-lsmr\", \"nvmath-direct\", and " 668 | "\"jax-lu\".") 669 | --------------------------------------------------------------------------------