├── .gitignore
├── README.md
├── demo.py
├── figs
│   ├── relative_cg_test_2048.png
│   ├── relative_cg_test_4096.png
│   ├── relative_gmres_test_1024.png
│   └── relative_gmres_test_4096.png
├── linalg.py
├── requirements.txt
├── test.py
└── test2.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Solving linear systems in PyTorch
This repository implements the most commonly used iterative methods for solving the linear system Ax=b in PyTorch, so they can run on GPUs.

It currently includes Conjugate Gradient (CG) and GMRES.

The figures below compare my implementations of CG and GMRES (`torch cg`, `torch gmres` in the legends)
against those in SciPy and JAX, in both single and double precision.

A few useful features:
1. Fully implemented in PyTorch and able to run on GPUs (see the sketch at the end of this README).
2. Accepts not only an explicit matrix A, but also any custom linear operator that computes Ax.
3. Stable convergence.

![](figs/relative_cg_test_4096.png)

![](figs/relative_gmres_test_4096.png)

## How to use

### Demo 1
```python
import torch
from linalg import CG, GMRES

A = torch.tensor([[3.0, 1.0, 0.0],
                  [1.0, 2.0, -1.0],
                  [0.0, -1.0, 1.0]])

b = torch.tensor([1.0, 2.0, 3.0])

sol1, info = CG(A, b)
print(f'Solution by CG: {sol1}')

sol2, info = GMRES(A, b)
print(f'Solution by GMRES: {sol2}')
```
Remark: `info` is a tuple where `info[0]` is the number of iterations and `info[1]` is a list of per-iteration residual errors. The list is populated only when the solver is called with `track=True` (GMRES records relative residuals; CG records squared residual norms).

### Demo 2

```python
import torch
from linalg import CG, GMRES
from functools import partial

A = torch.tensor([[3.0, 1.0, 0.0],
                  [1.0, 2.0, -1.0],
                  [0.0, -1.0, 1.0]])

b = torch.tensor([1.0, 2.0, 3.0])

def Avp(A, vec):
    return A @ vec

# create a custom linear operator that produces Ax
LinOp = partial(Avp, A)

sol3, info = CG(LinOp, b)
print(f'Solution by CG: {sol3}')

sol4, info = GMRES(LinOp, b)
print(f'Solution by GMRES: {sol4}')
```
See more examples in `test.py`.
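
### Running on a GPU

Since both solvers are plain PyTorch, it is enough to move the inputs to a CUDA device; the solvers allocate their work buffers on the same device as `b`. A minimal sketch (assuming a CUDA-capable machine; the random SPD system is illustrative only):

```python
import torch
from linalg import CG

device = 'cuda' if torch.cuda.is_available() else 'cpu'

n = 1024
M = torch.randn(n, n, device=device)
A = M.T @ M + torch.eye(n, device=device)  # random SPD matrix, so CG applies
b = torch.randn(n, device=device)

sol, (num_iter, _) = CG(A, b)
rel_res = torch.norm(A @ sol - b) / torch.norm(b)
print(f'{num_iter} iterations, relative residual {rel_res:.2e}')
```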

--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
import torch
from linalg import CG, GMRES
from functools import partial

A = torch.tensor([[3.0, 1.0, 0.0],
                  [1.0, 2.0, -1.0],
                  [0.0, -1.0, 1.0]])

b = torch.tensor([1.0, 2.0, 3.0])

sol1, _ = CG(A, b)
print(f'Solution by CG: {sol1}')

sol2, _ = GMRES(A, b)
print(f'Solution by GMRES: {sol2}')


def Avp(A, vec):
    return A @ vec

# create a custom linear operator that produces Ax
LinOp = partial(Avp, A)

sol3, _ = CG(LinOp, b)
print(f'Solution by CG: {sol3}')

sol4, _ = GMRES(LinOp, b)
print(f'Solution by GMRES: {sol4}')

--------------------------------------------------------------------------------
/figs/relative_cg_test_2048.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devzhk/Pytorch-linalg/970a63ffe69d5a5f4f3aef3812b82b1cf37a414f/figs/relative_cg_test_2048.png

--------------------------------------------------------------------------------
/figs/relative_cg_test_4096.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devzhk/Pytorch-linalg/970a63ffe69d5a5f4f3aef3812b82b1cf37a414f/figs/relative_cg_test_4096.png

--------------------------------------------------------------------------------
/figs/relative_gmres_test_1024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devzhk/Pytorch-linalg/970a63ffe69d5a5f4f3aef3812b82b1cf37a414f/figs/relative_gmres_test_1024.png

--------------------------------------------------------------------------------
/figs/relative_gmres_test_4096.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devzhk/Pytorch-linalg/970a63ffe69d5a5f4f3aef3812b82b1cf37a414f/figs/relative_gmres_test_4096.png

--------------------------------------------------------------------------------
/linalg.py:
--------------------------------------------------------------------------------
import torch
from functools import partial


def _check_nan(vec, msg):
    if torch.isnan(vec).any():
        raise ValueError(msg)


def _safe_normalize(x, threshold=None):
    norm = torch.norm(x)
    if threshold is None:
        threshold = torch.finfo(norm.dtype).eps
    normalized_x = x / norm if norm > threshold else torch.zeros_like(x)
    return normalized_x, norm


def Mvp(A, vec):
    return A @ vec


def arnoldi(vec,  # matrix-vector product A @ V[j-1]
            V,    # list of existing basis vectors
            H,    # Hessenberg matrix
            j):   # number of basis vectors
    '''
    One step of the Arnoldi iteration: orthonormalize `vec` against the
    existing basis V to obtain the j-th l2-orthonormal basis vector, and
    fill in the (j-1)-th column of the Hessenberg matrix H.
    '''
    _check_nan(vec, 'Matrix-vector product is NaN')

    for i in range(j):
        H[i, j - 1] = torch.dot(vec, V[i])
        vec = vec - H[i, j - 1] * V[i]
    new_v, vnorm = _safe_normalize(vec)
    H[j, j - 1] = vnorm
    return new_v


def cal_rotation(a, b):
    '''
    Args:
        a: element h at position j
        b: element h at position j+1
    Returns:
        cosine = a / \sqrt{a^2 + b^2}
        sine = - b / \sqrt{a^2 + b^2}
    '''
    c = torch.sqrt(a * a + b * b)
    return a / c, - b / c


def apply_given_rotation(H, cs, ss, j):
    '''
    Apply Givens rotations to the j-th column of H.
    :param H: Hessenberg matrix
    :param cs: cosines of the previous rotations
    :param ss: sines of the previous rotations
    :param j: index of the current column
    :return: updated H, cs, ss
    '''
    # apply the previous rotations to the new column
    for i in range(j):
        tmp = cs[i] * H[i, j] - ss[i] * H[i + 1, j]
        H[i + 1, j] = cs[i] * H[i + 1, j] + ss[i] * H[i, j]
        H[i, j] = tmp
    # compute a new rotation that zeroes out the subdiagonal entry
    cs[j], ss[j] = cal_rotation(H[j, j], H[j + 1, j])
    H[j, j] = cs[j] * H[j, j] - ss[j] * H[j + 1, j]
    H[j + 1, j] = 0
    return H, cs, ss


'''
GMRES solver for Ax=b.
Reference: https://web.stanford.edu/class/cme324/saad-schultz.pdf
'''

def GMRES(A,              # linear operator: a matrix or a function computing Ax
          b,              # RHS of the linear system
          x0=None,        # initial guess, same shape as b
          max_iter=None,  # maximum number of GMRES iterations
          tol=1e-6,       # relative tolerance
          atol=1e-6,      # absolute tolerance
          track=False):   # if True, track the residual error at each iteration
    '''
    Return:
        sol: solution
        (num_iter, err_history):
            num_iter is the number of iterations used to achieve the target accuracy;
            err_history is a list of the relative residual errors at each iteration
            if track=True, an empty list otherwise.
    '''
    if isinstance(A, torch.Tensor):
        Avp = partial(Mvp, A)
    elif hasattr(A, '__call__'):
        Avp = A
    else:
        raise ValueError('A must be a function or a matrix')

    bnorm = torch.norm(b)

    if max_iter is None:
        max_iter = b.shape[0]

    if x0 is None:
        x0 = torch.zeros_like(b)
        r0 = b
    else:
        r0 = b - Avp(x0)

    # trivial cases: zero iteration budget or (near-)zero RHS
    if max_iter == 0 or bnorm < 1e-8:
        return x0, (0, [])

    new_v, rnorm = _safe_normalize(r0)
    # residual norm of the initial guess
    beta = torch.zeros(max_iter + 1, dtype=b.dtype, device=b.device)
    beta[0] = rnorm
    err_history = []
    if track:
        err_history.append((rnorm / bnorm).item())

    V = []
    V.append(new_v)
    H = torch.zeros((max_iter + 1, max_iter + 1), dtype=b.dtype, device=b.device)
    cs = torch.zeros(max_iter, dtype=b.dtype, device=b.device)  # cosine values at each step
    ss = torch.zeros(max_iter, dtype=b.dtype, device=b.device)  # sine values at each step

    for j in range(max_iter):
        p = Avp(V[j])
        new_v = arnoldi(p, V, H, j + 1)  # Arnoldi iteration to get the (j+1)-th basis vector
        V.append(new_v)

        H, cs, ss = apply_given_rotation(H, cs, ss, j)
        _check_nan(cs, f'{j}-th cosine contains NaN')
        _check_nan(ss, f'{j}-th sine contains NaN')
        beta[j + 1] = ss[j] * beta[j]
        beta[j] = cs[j] * beta[j]
        residual = torch.abs(beta[j + 1])
        if track:
            err_history.append((residual / bnorm).item())
        if residual < tol * bnorm or residual < atol:
            break
    # after the Givens rotations, H[0:j+1, 0:j+1] is upper triangular
    y = torch.linalg.solve_triangular(H[0:j + 1, 0:j + 1],
                                      beta[0:j + 1].unsqueeze(-1), upper=True)
    V = torch.stack(V[:-1], dim=0)
    sol = x0 + V.T @ y.squeeze(-1)
    return sol, (j + 1, err_history)


'''
Conjugate Gradient method for Ax=b with symmetric positive-definite A.
Reference: https://en.wikipedia.org/wiki/Conjugate_gradient_method
'''

def CG(A,              # linear operator: a matrix or a function computing Ax
       b,              # RHS of the linear system
       x0=None,        # initial guess
       max_iter=None,  # maximum number of iterations
       tol=1e-5,       # relative tolerance
       atol=1e-6,      # absolute tolerance
       track=False,    # if True, track the residual error at each iteration
       ):
    '''
    Return:
        sol: solution
        (num_iter, err_history):
            num_iter is the number of iterations used to achieve the target accuracy;
            err_history is a list of the squared residual norms ||r||^2 at each
            iteration if track=True (unlike GMRES, these are not normalized by
            ||b||), an empty list otherwise.
    '''
    if isinstance(A, torch.Tensor):
        Avp = partial(Mvp, A)
    elif hasattr(A, '__call__'):
        Avp = A
    else:
        raise ValueError('A must be a function or a square matrix')

    if max_iter is None:
        max_iter = b.shape[0]
    if x0 is None:
        x = torch.zeros_like(b)
        r = b.detach().clone()
    else:
        Av = Avp(x0)
        r = b.detach().clone() - Av
        x = x0

    p = r.clone()
    rdotr = torch.dot(r, r)
    err_history = []
    if track:
        err_history.append(rdotr.item())

    residual_tol = max(tol * tol * torch.dot(b, b), atol * atol)
    if rdotr < residual_tol:
        return x, (0, err_history)

    for i in range(max_iter):
        Ap = Avp(p)

        alpha = rdotr / torch.dot(p, Ap)  # step size along the search direction
        x.add_(alpha * p)
        r.add_(-alpha * Ap)
        new_rdotr = torch.dot(r, r)
        beta = new_rdotr / rdotr
        p = r + beta * p  # new conjugate search direction
        rdotr = new_rdotr
        if track:
            err_history.append(rdotr.item())
        if rdotr < residual_tol:
            break
    return x, (i + 1, err_history)
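

if __name__ == '__main__':
    # Quick self-check (an added sketch, not part of the original module):
    # solve a random symmetric positive-definite system with both solvers
    # and compare against PyTorch's direct solver.
    torch.manual_seed(0)
    n = 64
    M = torch.randn(n, n)
    A = M.T @ M + torch.eye(n)  # SPD, so CG is applicable
    b = torch.randn(n)

    x_ref = torch.linalg.solve(A, b)
    x_cg, (it_cg, _) = CG(A, b)
    x_gmres, (it_gmres, _) = GMRES(A, b)
    print(f'CG:    {it_cg} iters, error {torch.norm(x_cg - x_ref):.2e}')
    print(f'GMRES: {it_gmres} iters, error {torch.norm(x_gmres - x_ref):.2e}')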

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
jax
scipy
torch
numpy
matplotlib
tqdm

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
from tqdm import tqdm

import numpy as np
from scipy.sparse.linalg import cg as sp_cg
from scipy.sparse.linalg import gmres as sp_gmres
from jax.scipy.sparse.linalg import cg as jx_cg
from jax.scipy.sparse.linalg import gmres as jax_gmres

import matplotlib.pyplot as plt
import torch
from functools import partial
from linalg import CG, GMRES


def Mvp(A, vec):
    return A @ vec


def test_cg(A, b, maxiter, x0=None):
    th_A = torch.from_numpy(A)
    th_b = torch.from_numpy(b)
    if x0 is not None:
        th_x0 = torch.from_numpy(x0)
    else:
        th_x0 = None
    LinOp = partial(Mvp, th_A)

    sp_sol = sp_cg(A, b, x0=x0, maxiter=maxiter)
    res_sp = np.linalg.norm(A @ sp_sol[0] - b)

    jx_sol = jx_cg(A, b, x0=x0, maxiter=maxiter)
    res_jx = np.linalg.norm(A @ jx_sol[0] - b)

    sol, _ = CG(LinOp, th_b, x0=th_x0, tol=1e-5, max_iter=maxiter)
    res_th = torch.norm(LinOp(sol) - th_b).item()
    return res_sp, res_jx, res_th


def test_gmres(A, b, maxiter, x0=None):
    th_A = torch.from_numpy(A)
    th_b = torch.from_numpy(b)
    if x0 is not None:
        th_x0 = torch.from_numpy(x0)
    else:
        th_x0 = None
    LinOp = partial(Mvp, th_A)

    sp_sol = sp_gmres(A, b, x0=x0, restart=maxiter, maxiter=1)
    res_sp = np.linalg.norm(A @ sp_sol[0] - b)

    jx_sol = jax_gmres(A, b, x0=x0, restart=maxiter, maxiter=1)
    res_jx = np.linalg.norm(A @ jx_sol[0] - b)

    sol, _ = GMRES(LinOp, th_b, x0=th_x0, tol=1e-5, max_iter=maxiter)
    res_th = torch.norm(LinOp(sol) - th_b).item()
    return res_sp, res_jx, res_th


def plot_test(size_sys, algo='cg', err_type='abs', init_guess=False):
    dtypes = [np.float64, np.float32]

    K = int(np.log2(size_sys)) + 1

    mat = np.random.randn(size_sys, size_sys)
    A = mat.T @ mat + np.identity(size_sys)  # random SPD system
    b = np.random.randn(size_sys)
    x0 = np.random.randn(size_sys) if init_guess else None

    iter_list = [2 ** k for k in range(K)]

    for dt in dtypes:
        if dt == np.float32:
            A = A.astype(dt)
            b = b.astype(dt)
            if x0 is not None:
                x0 = x0.astype(dt)
            torch.set_default_dtype(torch.float32)
        else:
            torch.set_default_dtype(torch.float64)
        sp_list = []
        jx_list = []
        th_list = []
        for k in iter_list:
            maxiter = k
            if algo == 'cg':
                res_sp, res_jx, res_th = test_cg(A, b, maxiter, x0)
            elif algo == 'gmres':
                res_sp, res_jx, res_th = test_gmres(A, b, maxiter, x0)
            else:
                raise ValueError(f'{algo} not supported')
            if err_type == 'relative':
                bnorm = np.linalg.norm(b)
                res_sp = res_sp / bnorm
                res_jx = res_jx / bnorm
                res_th = res_th / bnorm

            sp_list.append(res_sp)
            jx_list.append(res_jx)
            th_list.append(res_th)
        plt.plot(iter_list, sp_list, label=f'scipy {algo}-{dt}', alpha=0.5, marker='*')
        plt.plot(iter_list, jx_list, label=f'jax {algo}-{dt}', alpha=0.5, marker='+')
        plt.plot(iter_list, th_list, label=f'torch {algo}-{dt}', alpha=0.5, marker='o')
    plt.legend()
    plt.yscale('log')
    # plt.xscale('log')
    plt.xlabel('Number of iterations')
    plt.ylabel(f'L2 error ({err_type})')
    plt.savefig(f'figs/{err_type}_{algo}_test_{size_sys}.png')
    plt.cla()


if __name__ == '__main__':
    sizes = [128, 256, 512, 1024, 2048]
    algo = 'cg'  # set to 'gmres' to reproduce the GMRES figures
    err_type = 'relative'

    for size_sys in tqdm(sizes):
        plot_test(size_sys, algo, err_type, init_guess=False)

--------------------------------------------------------------------------------
/test2.py:
--------------------------------------------------------------------------------
'''
This file tests the accuracy of the residual error tracked by CG.
'''
#%%
import numpy as np
import matplotlib.pyplot as plt

import torch
from linalg import CG

#%%
def test(size_sys, init_guess=False):
    mat = np.random.randn(size_sys, size_sys)
    A = mat.T @ mat + np.identity(size_sys)
    b = np.random.randn(size_sys)
    x0 = np.random.randn(size_sys) if init_guess else None

    A = torch.from_numpy(A)  # .to(torch.float32)
    b = torch.from_numpy(b)  # .to(torch.float32)
    x0 = torch.from_numpy(x0) if x0 is not None else None

    sol, (num_iter, err_list) = CG(A, b, x0, track=True)
    res_gt = torch.norm(b - A @ sol)
    bnorm = torch.norm(b)
    rel_err = res_gt / bnorm
    print(f'Relative error: {rel_err}')
    # CG tracks the squared residual norms ||r||^2, so take the square root
    # and normalize by ||b|| to get relative residuals
    return np.sqrt(np.array(err_list)) / bnorm.item()

#%%
torch.set_default_dtype(torch.float64)

errs = test(512, True)
# %%
xs = list(range(len(errs)))

plt.plot(xs, errs)
plt.xlabel('Steps')
plt.ylabel('Relative error')
plt.yscale('log')
plt.show()
# %%
print(errs[-1])
# %%
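# Summary (an added sketch, not in the original script): overall reduction of
# the tracked relative residual from the first to the last CG iterate.
print(f'Residual reduced by a factor of {errs[0] / errs[-1]:.3e} '
      f'over {len(errs) - 1} iterations')
# %%

--------------------------------------------------------------------------------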