├── .gitignore
├── README.md
├── benchmark.png
├── examples
│   ├── benchmark.py
│   └── minimal.py
├── pyproject.toml
└── sparse_solver
    ├── __init__.py
    └── sparse_solver.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
build/
sparse_solver.egg-info/
sparse_solver/__init__.py
tests/
venv/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch Sparse Solve

This small library provides a python class `SparseSolver` for using pytorch to
back-propagate through functions which involve solving against a sparse matrix
whose entries depend on the differentiation variables. For example, consider that we have
a function $\mathcal{L} : \mathbb{R}^d \rightarrow \mathbb{R}$ where $\mathcal{L}$ is defined as:

$$
\mathcal{L}(w) = \frac{1}{2} \left\| A(w)^{-1} b \right\|^2
$$

and $A: \mathbb{R}^d \rightarrow \mathbb{R}^{n \times n}$ is some
*sparse*-matrix function of the model parameters $w$.
Correspondingly in pytorch we might write:

```python
b = torch.ones(n, dtype=torch.double)
# x = A⁻¹ b
x = torch.linalg.solve(A,b)
# L(w) = ½‖x‖² = ½‖A(w)⁻¹ b‖²
L = torch.sum(x**2) / 2
```

For example, suppose that $A$ is defined to be the identity matrix plus the weighted graph
Laplacian for a sparse set of $d$ edges. In mathematical terms, we can write:

$$
A_{ij} = \begin{cases}
-w_e & \text{if } (i,j) \text{ or } (j,i) \text{ is the $e$-th edge} \\
1 - \sum\limits_{k\neq i} A_{ik} & \text{if } i = j \\
0 & \text{otherwise}
\end{cases}.
$$

### 😭 Dense baseline

Unfortunately, pytorch doesn't support sparse matrices well by default. So if we
were to build the matrix $A$ directly, we would have to build it as a dense matrix. For example,
assuming a sparse set of random edges, we might write something like:

```python
# size of problem
n = 10
# seed
torch.manual_seed(0)
# E is a #E by 2 list of edges (i,j) ∈ [0,n)²
E = torch.unique(torch.randint(0, n, (n*6, 2), dtype=torch.int64), dim=0)
# w is a #E vector of parameters
w = torch.ones(E.shape[0], dtype=torch.double, requires_grad=True)
# A = I + WeightedGraphLaplacian(E,w)
diag = torch.arange(n)
indices = torch.stack([torch.cat([diag,E[:,0],E[:,1],E[:,0],E[:,1]]),torch.cat([diag,E[:,1],E[:,0],E[:,0],E[:,1]])])
values = torch.cat([torch.ones(n, dtype=torch.double), -w, -w, w, w])

# Build A as dense matrix (default for torch)
A = torch.zeros(n, n, dtype=torch.double)
A.index_put_((indices[0], indices[1]), values, accumulate=True)
```

The forward pass is of course $O(n^2)$ just to construct `A`, but calling
`torch.linalg.solve(A, b)` is $O(n^3)$. The backward pass is similarly $O(n^3)$:

```python
# very slow for large n
L.backward()
dLdw = w.grad.clone().detach()
```

This default dense pytorch code will choke as $n$ increases.

### 🚀 Sparse

Fortunately, we can instead use `torch.sparse_coo_tensor` and `SparseSolver` to
construct and solve against $A$ in a sparse way while maintaining
differentiability.
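
Differentiability here comes from implicit differentiation of the solve: writing $x = A^{-1} b$ and letting $g = \partial\mathcal{L}/\partial x$ be the upstream gradient, the backward pass computes

$$
\frac{\partial \mathcal{L}}{\partial b} = A^{-\top} g,
\qquad
\frac{\partial \mathcal{L}}{\partial A} = -\left(A^{-\top} g\right)\, x^\top,
$$

where only the entries of $\partial\mathcal{L}/\partial A$ lying on $A$'s sparsity pattern are needed. For the symmetric matrices used here $A^{-\top} = A^{-1}$, so the backward pass can reuse the Cholesky factorization from the forward pass.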

```python
import sparse_solver
# build A as a torch.sparse_coo_tensor
A_sparse = torch.sparse_coo_tensor(indices, values, size=(n, n), dtype=torch.double).coalesce()
# x = A⁻¹ b
x = sparse_solver.SparseSolver.apply(A_sparse, b)
# L(w) = ½‖x‖² = ½‖A(w)⁻¹ b‖²
L = torch.sum(x**2) / 2
w.grad = None
# Efficient even for large n
L.backward()
dLdw_sparse = w.grad.clone().detach()
```

`SparseSolver` uses a CPU-based sparse Cholesky factorization and GPU forward/back-substitution in the forward pass, and caches the factorization for an efficient GPU backward pass. The precise asymptotic behavior depends on the sparsity pattern and how well the matrix can be permuted, but for common patterns it will be something like $O(n^p)$ where $1\leq p \leq 2$.

For examples like the one above, as $n$ increases the dense torch baseline indeed measures performance looking something like $n^{2.5}$,
while `SparseSolver` measures performance very close to $n^{1.0}$.

![](benchmark.png)

## Use

Install with pip:

    python -m pip install .

Run the tests:

    pytest

Run the minimal example above:

    python examples/minimal.py

Run the benchmark:

    python examples/benchmark.py


## To-do list

- [ ] Add a fuller example (e.g., "Fast Quasi-Harmonic Weights for Geometric Data Interpolation", or inverse design of a mass-spring cantilever)

You might also be interested in https://github.com/alecjacobson/indexed_sum

https://github.com/flaport/torch_sparse_solve appears to be similar, but supports batching and uses LU instead of Cholesky.

_Original code from Aravind Ramakrishnan._

--------------------------------------------------------------------------------
/benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alecjacobson/pytorch-sparse-solve/dcd62b25c55c41fc29eb4136ab693daaa19f0442/benchmark.png
--------------------------------------------------------------------------------
/examples/benchmark.py:
--------------------------------------------------------------------------------
import sparse_solver
import torch
import time

import matplotlib.pyplot as plt
import numpy as np


# is torch_sparse_solve available?
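# torch_sparse_solve (https://github.com/flaport/torch_sparse_solve) is an
# optional third-party baseline: it solves batched sparse systems with an LU
# factorization, and the benchmark simply skips that curve when it isn't installed.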
try:
    import torch_sparse_solve
    torch_sparse_solve_available = True
except ImportError:
    torch_sparse_solve_available = False

print(f"torch_sparse_solve available: {torch_sparse_solve_available}")

def x_dense(theta,b):
    n = b.shape[0]
    # A is defined so that:
    # ½ x^T A x = ½ x^T I x + ½ ∑_i theta_i (x_i - x_{i+1})^2
    #
    # A = I
    # A(i,i) += theta(i) + theta(i-1)
    # A(i,i+1) += -theta(i), A(i+1,i) += -theta(i)   (indices cyclic mod n)

    # Start with identity
    A = torch.eye(n, dtype=torch.double)

    # Compute indices for cyclic neighbor (i, i+1 mod n)
    i = torch.arange(n)
    j = (i + 1) % n

    # Subtract theta[i] from (i, j) and (j, i)
    A[i, j] -= theta
    A[j, i] -= theta

    i = torch.arange(n)
    j = (i + 1) % n
    A[i,i] += theta
    A[j,j] += theta


    x = torch.linalg.solve(A,b)
    return x

def A_sparse(theta,n):
    # Diagonal part: (i, i) → 1 + theta[i] + theta[i-1]
    diag_i = torch.arange(n)
    diag_j = torch.arange(n)
    diag_val = torch.ones(n, dtype=torch.double) + theta + theta.roll(1)

    # Off-diagonal part: (i, i+1) and (i+1, i) → -theta[i]
    i = torch.arange(n)
    j = (i + 1) % n

    # Stack (i,j) and (j,i)
    off_i = torch.cat([i, j])
    off_j = torch.cat([j, i])
    off_val = torch.cat([-theta, -theta])

    # Combine all
    indices = torch.stack([torch.cat([diag_i, off_i]),
                           torch.cat([diag_j, off_j])])
    values = torch.cat([diag_val, off_val])

    A = torch.sparse_coo_tensor(indices, values, size=(n, n), dtype=torch.double).coalesce()
    return A

def x_sparse(theta,b):
    n = b.shape[0]
    A = A_sparse(theta, n)
    x = sparse_solver.SparseSolver.apply(A, b)
    return x

def x_lu(theta,b):
    n = b.shape[0]
    A = A_sparse(theta, n)
    x = torch_sparse_solve.solve(A.unsqueeze(0), b.unsqueeze(1).unsqueeze(0))
    x = x.squeeze(0).squeeze(-1)
    return x

def loss(x):
    # loss function
    return torch.linalg.norm(x)**2


# Timing and data collection
ns = []
dense_times = []
sparse_times = []
lu_times = []

for n in (2 ** i for i in range(8, 14)):
    torch.manual_seed(0)
    b = torch.rand(n, requires_grad=False, dtype=torch.double)
    theta = torch.rand(n, requires_grad=True, dtype=torch.double)

    for _ in range(2):
        theta.grad = None
        start = time.time()
        f = loss(x_dense(theta, b))
        f.backward()
        dfdtheta_dense = theta.grad.clone().detach()
        t_dense = time.time() - start

    for _ in range(2):
        theta.grad = None
        start = time.time()
        f = loss(x_sparse(theta, b))
        f.backward()
        dfdtheta_sparse = theta.grad.clone().detach()
        t_sparse = time.time() - start

    assert torch.allclose(dfdtheta_dense, dfdtheta_sparse, atol=1e-6), f"Gradient mismatch for n={n}"

    if torch_sparse_solve_available:
        for _ in range(2):
            theta.grad = None
            start = time.time()
            f = loss(x_lu(theta, b))
            f.backward()
            dfdtheta_lu = theta.grad.clone().detach()
            t_lu = time.time() - start

        assert torch.allclose(dfdtheta_dense, dfdtheta_lu, atol=1e-6), f"Gradient mismatch for n={n}"

    ns.append(n)
    dense_times.append(t_dense)
    sparse_times.append(t_sparse)
    if torch_sparse_solve_available:
        lu_times.append(t_lu)

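
# --- illustrative addition (not part of the original benchmark) ---
# Estimate the empirical scaling exponent p in time ≈ C·n^p for each method by
# fitting a straight line to (log n, log time); the slope of the fit is p.
def scaling_exponent(sizes, times):
    slope, _ = np.polyfit(np.log(np.asarray(sizes, dtype=float)),
                          np.log(np.asarray(times, dtype=float)), 1)
    return slope

print(f"empirical exponent, torch (dense): {scaling_exponent(ns, dense_times):.2f}")
print(f"empirical exponent, SparseSolver:  {scaling_exponent(ns, sparse_times):.2f}")
if torch_sparse_solve_available:
    print(f"empirical exponent, torch_sparse_solve: {scaling_exponent(ns, lu_times):.2f}")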


# The rest is all plotting

plt.rcParams.update({
    'font.size': 16,
    'axes.titlesize': 18,
    'axes.labelsize': 18,
    'legend.fontsize': 14,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'figure.facecolor': '1.0',
    'axes.facecolor': '0.95',
    'grid.color': 'white',
    'grid.linestyle': '-',
    'grid.linewidth': 1.2
})


# Plot
ns = np.array(ns)
dense_times = np.array(dense_times)
sparse_times = np.array(sparse_times)

fig, ax = plt.subplots(figsize=(8, 6))

ax.loglog(ns, dense_times, 'o-', label='torch', linewidth=3)
ax.loglog(ns, sparse_times, 's-', label='SparseSolve', linewidth=3)
if torch_sparse_solve_available:
    ax.loglog(ns, lu_times, '^-', label='torch_sparse_solve', linewidth=3)

# Guide lines and text annotations
x0 = ns[0]
y0 = sparse_times[0]
orders = [1, 2, 3]
labels = [r'$\mathcal{O}(n)$', r'$\mathcal{O}(n^2)$', r'$\mathcal{O}(n^3)$']

for p, label in zip(orders, labels):
    guide_y = y0 * (ns / x0) ** p
    ax.loglog(ns, guide_y, '--', color='black')

    x_last = ns[-1]
    y_last = guide_y[-1]

    # Shift label slightly down and left (log scale aware)
    x_shift = x_last / 1.1
    y_shift = y_last * 0.9

    ax.text(x_shift, y_shift, label,
            color='black', fontsize=14,
            ha='right', va='bottom')

# Styling
ax.set_xlabel('Problem size $n$')
ax.set_ylabel('Time (s)')
ax.set_title(r'∂|A(θ)⁻¹ b|²/∂θ computation time')
ax.legend()
ax.grid(axis='y', which='major')
ax.xaxis.grid(False)
ax.set_axisbelow(True)

plt.tight_layout()
plt.savefig("benchmark.png", dpi=300)

--------------------------------------------------------------------------------
/examples/minimal.py:
--------------------------------------------------------------------------------
import torch

# size of problem
n = 10
b = torch.ones(n, dtype=torch.double)
# seed
torch.manual_seed(0)
# E is a #E by 2 list of edges (i,j) ∈ [0,n)²
E = torch.unique(torch.randint(0, n, (n*6, 2), dtype=torch.int64), dim=0)
# w is a #E vector of parameters
w = torch.ones(E.shape[0], dtype=torch.double, requires_grad=True)
# A = I + WeightedGraphLaplacian(E,w)
diag = torch.arange(n)
indices = torch.stack([torch.cat([diag,E[:,0],E[:,1],E[:,0],E[:,1]]),torch.cat([diag,E[:,1],E[:,0],E[:,0],E[:,1]])])
values = torch.cat([torch.ones(n, dtype=torch.double), -w, -w, w, w])

# Build A as dense matrix (default for torch)
A = torch.zeros(n, n, dtype=torch.double)
A.index_put_((indices[0], indices[1]), values, accumulate=True)

# x = A⁻¹ b
x = torch.linalg.solve(A,b)
# L(w) = ½‖x‖² = ½‖A(w)⁻¹ b‖²
L = torch.sum(x**2) / 2
w.grad = None
# very slow for large n
L.backward()
dLdw = w.grad.clone().detach()

import sparse_solver

# build A as a torch.sparse_coo_tensor
A_sparse = torch.sparse_coo_tensor(indices, values, size=(n, n), dtype=torch.double).coalesce()
# x = A⁻¹ b
x = sparse_solver.SparseSolver.apply(A_sparse, b)
# L(w) = ½‖x‖² = ½‖A(w)⁻¹ b‖²
L = torch.sum(x**2) / 2
w.grad = None
# Efficient even for large n
L.backward()
dLdw_sparse = w.grad.clone().detach()

print("‖dLdw - dLdw_sparse‖ = ", torch.linalg.norm(dLdw - dLdw_sparse).item())

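# --- illustrative addition (not part of the original example) ---
# Independently sanity-check a few entries of the gradient with central finite
# differences of L(w) = ½‖A(w)⁻¹ b‖², using the dense construction and solve.
def L_of_w(w_vec):
    values_fd = torch.cat([torch.ones(n, dtype=torch.double), -w_vec, -w_vec, w_vec, w_vec])
    A_fd = torch.zeros(n, n, dtype=torch.double)
    A_fd.index_put_((indices[0], indices[1]), values_fd, accumulate=True)
    return torch.sum(torch.linalg.solve(A_fd, b)**2) / 2

h = 1e-6
with torch.no_grad():
    for e in range(min(3, E.shape[0])):
        dw = torch.zeros_like(w)
        dw[e] = h
        fd = (L_of_w(w + dw) - L_of_w(w - dw)) / (2 * h)
        print(f"edge {e}: autograd {dLdw_sparse[e].item():.8f}  finite difference {fd.item():.8f}")
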
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "sparse_solver"
description = "A differentiable sparse linear solver (Cholesky) for PyTorch."
readme = "README.md"
authors = [
    {name = "Alec Jacobson", email = "alecjacobson@gmail.com"},
    {name = "Aravind Ramakrishnan", email = "aravind947@gmail.com"},
]
license = {text = "MIT"}
dependencies = [
    "torch",
    "cholespy",
]
requires-python = ">=3.7"
dynamic = ["version"]

[tool.setuptools]
packages = ["sparse_solver"]

[tool.setuptools.package-data]
sparse_solver = ["*.py"]

[tool.setuptools.dynamic]
version = {attr = "sparse_solver.__version__"}

[project.urls]
"Homepage" = "https://github.com/alecjacobson/sparse_solver"

--------------------------------------------------------------------------------
/sparse_solver/__init__.py:
--------------------------------------------------------------------------------
from .sparse_solver import SparseSolver

__version__ = "0.1.0"

--------------------------------------------------------------------------------
/sparse_solver/sparse_solver.py:
--------------------------------------------------------------------------------
import torch
from cholespy import CholeskySolverD, MatrixType

class SparseSolver(torch.autograd.Function):
    CHOL = None  # cached Cholesky factorization (cholespy solver)
    A0 = torch.sparse_coo_tensor((1,1)).coalesce()  # current matrix corresponding to CHOL

    @staticmethod
    def forward(ctx, A, b):
        if A.layout != torch.sparse_coo:
            A = A.to_sparse_coo()

        if not (torch.equal(SparseSolver.A0.indices(), A.indices()) and torch.equal(SparseSolver.A0.values(), A.values())):
            # don't factor matrix unless necessary
            SparseSolver.A0 = A

            ind = SparseSolver.A0.indices()
            rows = ind[0,:]
            cols = ind[1,:]
            vals = SparseSolver.A0.values()

            SparseSolver.CHOL = CholeskySolverD(SparseSolver.A0.size(0), rows, cols, vals, MatrixType.COO)

        x = torch.zeros_like(b, dtype=torch.double)
        SparseSolver.CHOL.solve(b, x)
        ctx.save_for_backward(b, SparseSolver.A0.indices(), x)
        return x

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        b, ind, res = ctx.saved_tensors
        grad_A = grad_b = None

        if ctx.needs_input_grad[0]:
            # partial1 = −diag(inv(A⊙S)^⊤⋅g)⋅S⋅diag(inv(A⊙S)⋅b) # but for symmetric A, S, we can drop the transposes
            n = b.size(0)
            S = torch.sparse_coo_tensor(ind, torch.ones(ind.size(1)), (n,n), dtype=torch.double)

            p1left = torch.zeros_like(b, dtype=torch.double)
            SparseSolver.CHOL.solve(grad_output.clone().detach().double(), p1left)
            p1left = torch.sparse_coo_tensor(torch.stack([torch.arange(n), torch.arange(n)], 0), p1left)

            p1right = torch.sparse_coo_tensor(torch.stack([torch.arange(n), torch.arange(n)], 0), res, dtype=torch.double)

            grad_A = -p1left @ S @ p1right

        if ctx.needs_input_grad[1]:
            # partial2 = inv(A⊙S)^⊤⋅g # drop transpose bc symmetric
            grad_b = torch.zeros_like(b, dtype=torch.double)
            SparseSolver.CHOL.solve(grad_output.clone().detach().double(), grad_b)

        return grad_A, grad_b

--------------------------------------------------------------------------------