├── .gitattributes ├── docs ├── api.rst ├── images │ └── logo-block.ico ├── _templates │ └── autosummary │ │ ├── function.rst │ │ └── class.rst ├── index.rst ├── _static │ └── css │ │ └── custom.css ├── Makefile └── conf.py ├── tests ├── fdem │ ├── RHS.npy │ ├── A_data.npy │ ├── A_indices.npy │ └── A_indptr.npy ├── test_uninstalled.py ├── test_Triangle.py ├── test_conjugate.py ├── test_Mumps.py ├── test_BicgJacobi.py ├── test_Scipy.py ├── test_Wrappers.py ├── test_Pardiso.py └── test_Basic.py ├── MANIFEST.in ├── .git_archival.txt ├── Makefile ├── .gitignore ├── pymatsolver ├── direct │ ├── __init__.py │ ├── mumps.py │ └── pardiso.py ├── __init__.py ├── iterative.py ├── wrappers.py └── solvers.py ├── LICENSE ├── README.rst ├── pyproject.toml └── .github └── workflows └── python-package-conda.yml /.gitattributes: -------------------------------------------------------------------------------- 1 | .git_archival.txt export-subst -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | .. 
automodule:: pymatsolver 2 | -------------------------------------------------------------------------------- /tests/fdem/RHS.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simpeg/pymatsolver/main/tests/fdem/RHS.npy -------------------------------------------------------------------------------- /tests/fdem/A_data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simpeg/pymatsolver/main/tests/fdem/A_data.npy -------------------------------------------------------------------------------- /tests/fdem/A_indices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simpeg/pymatsolver/main/tests/fdem/A_indices.npy -------------------------------------------------------------------------------- /tests/fdem/A_indptr.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simpeg/pymatsolver/main/tests/fdem/A_indptr.npy -------------------------------------------------------------------------------- /docs/images/logo-block.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simpeg/pymatsolver/main/docs/images/logo-block.ico -------------------------------------------------------------------------------- /docs/_templates/autosummary/function.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. 
autofunction:: {{ objname }} 6 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | prune .github 2 | prune docs 3 | prune tests 4 | exclude .gitignore MANIFEST.in .pre-commit-config.yaml 5 | exclude .git_archival.txt .gitattributes Makefile -------------------------------------------------------------------------------- /.git_archival.txt: -------------------------------------------------------------------------------- 1 | node: 55ef77d46261e83d75de6eb3b1782a7af2c6a1b3 2 | node-date: 2025-10-15T12:55:55-06:00 3 | describe-name: v0.4.0 4 | ref-names: HEAD -> main, tag: v0.4.0 -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: coverage tests docs 2 | 3 | coverage: 4 | pytest --cov --cov-config=pyproject.toml -s -v 5 | coverage xml 6 | 7 | tests: 8 | pytest 9 | 10 | docs: 11 | cd docs;make html 12 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | 3 | .. 
toctree:: 4 | :maxdepth: 2 5 | 6 | api 7 | 8 | Indices and tables 9 | ================== 10 | 11 | * :ref:`genindex` 12 | * :ref:`modindex` 13 | * :ref:`search` 14 | 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .coverage 2 | coverage.xml 3 | coverage_html_report/ 4 | *.pyc 5 | *.so 6 | build/ 7 | dist/ 8 | pymatsolver.egg-info/ 9 | *.sublime-workspace 10 | *.sublime-project 11 | *.dSYM 12 | docs/_build/* 13 | docs/generated 14 | 15 | .idea/ 16 | -------------------------------------------------------------------------------- /pymatsolver/direct/__init__.py: -------------------------------------------------------------------------------- 1 | from scipy.sparse.linalg import spsolve, splu 2 | 3 | from ..wrappers import WrapDirect 4 | from .pardiso import Pardiso 5 | from .mumps import Mumps 6 | 7 | Solver = WrapDirect(spsolve, factorize=False, name="Solver") 8 | SolverLU = WrapDirect(splu, factorize=True, name="SolverLU") 9 | 10 | __all__ = ["Solver", "SolverLU", "Pardiso", "Mumps"] 11 | -------------------------------------------------------------------------------- /tests/test_uninstalled.py: -------------------------------------------------------------------------------- 1 | import pymatsolver 2 | import pytest 3 | import scipy.sparse as sp 4 | 5 | @pytest.mark.skipif(pymatsolver.AvailableSolvers["Mumps"], reason="Mumps is installed.") 6 | def test_mumps_uninstalled(): 7 | with pytest.raises(ImportError): 8 | pymatsolver.Mumps(sp.eye(4)) 9 | 10 | @pytest.mark.skipif(pymatsolver.AvailableSolvers["Pardiso"], reason="Pardiso is installed.") 11 | def test_pydiso_uninstalled(): 12 | with pytest.raises(ImportError): 13 | pymatsolver.Pardiso(sp.eye(4)) 14 | -------------------------------------------------------------------------------- /docs/_static/css/custom.css: 
-------------------------------------------------------------------------------- 1 | @import url(https://fonts.googleapis.com/css?family=Raleway); 2 | 3 | .navbar-light { 4 | background: #acd6af!important; 5 | } 6 | .navbar-nav>.active>.nav-link { 7 | color: #000000!important; 8 | } 9 | 10 | #navbar-icon-links i.fa, #navbar-icon-links i.fab, 11 | #navbar-icon-links i.far, #navbar-icon-links i.fas { 12 | color: #000000; 13 | } 14 | 15 | body{ 16 | font-family: "Helvetica Neue", Helvetica, Arial, sans-serif; 17 | } 18 | 19 | h1{ 20 | font-family: "Raleway", Helvetica, Arial, sans-serif; font-weight: bold; 21 | } 22 | h2{ 23 | font-family: "Raleway", Helvetica, Arial, sans-serif; font-weight: bold; 24 | } 25 | h3{ 26 | font-family: "Raleway", Helvetica, Arial, sans-serif; font-weight: bold; 27 | } 28 | .column > h3{ 29 | font-family: "Raleway", Helvetica, Arial, sans-serif; 30 | } 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2013-2016 Rowan Cockett 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /tests/test_Triangle.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.testing as npt 3 | import scipy.sparse as sp 4 | import pymatsolver 5 | import pytest 6 | 7 | TOL = 1e-12 8 | 9 | @pytest.mark.parametrize("solver", [pymatsolver.Triangle, pymatsolver.Forward, pymatsolver.Backward]) 10 | @pytest.mark.parametrize("transpose", [True, False]) 11 | def test_solve(solver, transpose): 12 | n = 50 13 | nrhs = 20 14 | A = sp.rand(n, n, 0.4) + sp.identity(n) 15 | sol = np.ones((n, nrhs)) 16 | if solver is pymatsolver.Backward: 17 | A = sp.triu(A) 18 | lower = False 19 | else: 20 | A = sp.tril(A) 21 | lower = True 22 | 23 | if transpose: 24 | rhs = A.T @ sol 25 | Ainv = solver(A, lower=lower).T 26 | else: 27 | rhs = A @ sol 28 | Ainv = solver(A, lower=lower) 29 | 30 | npt.assert_allclose(Ainv * rhs, sol, atol=TOL) 31 | npt.assert_allclose(Ainv * rhs[:, 0], sol[:, 0], atol=TOL) 32 | 33 | 34 | def test_triangle_errors(): 35 | A = sp.eye(5, format='csc') 36 | 37 | with pytest.raises(TypeError, match="lower must be a bool."): 38 | Ainv = pymatsolver.Forward(A) 39 | Ainv.lower = 1 40 | 41 | 42 | def test_mat_convert(): 43 | Ainv = pymatsolver.Forward(sp.eye(5, format='coo')) 44 | x = np.arange(5) 45 | npt.assert_allclose(Ainv @ x, x) 46 | 47 | 48 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname }} 2 | {{ underline }} 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. 
inheritance-diagram:: {{ objname }} 7 | :parts: 1 8 | 9 | .. autoclass:: {{ objname }} 10 | 11 | {% block methods %} 12 | .. HACK -- the point here is that we don't want this to appear in the output, but the autosummary should still generate the pages. 13 | .. autosummary:: 14 | :toctree: 15 | {% for item in all_methods %} 16 | {%- if not item.startswith('_') or item in ['__call__', '__mul__', '__getitem__', '__len__'] %} 17 | {{ name }}.{{ item }} 18 | {%- endif -%} 19 | {%- endfor %} 20 | {% for item in inherited_members %} 21 | {%- if item in ['__call__', '__mul__', '__getitem__', '__len__'] %} 22 | {{ name }}.{{ item }} 23 | {%- endif -%} 24 | {%- endfor %} 25 | {% endblock %} 26 | 27 | {% block attributes %} 28 | {% if attributes %} 29 | .. HACK -- the point here is that we don't want this to appear in the output, but the autosummary should still generate the pages. 30 | .. autosummary:: 31 | :toctree: 32 | {% for item in all_attributes %} 33 | {%- if not item.startswith('_') %} 34 | {{ name }}.{{ item }} 35 | {%- endif -%} 36 | {%- endfor %} 37 | {% endif %} 38 | {% endblock %} 39 | -------------------------------------------------------------------------------- /tests/test_conjugate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pymatsolver 3 | import numpy as np 4 | import scipy.sparse as sp 5 | import numpy.testing as npt 6 | 7 | 8 | @pytest.mark.parametrize('solver_class', [pymatsolver.Solver, pymatsolver.SolverLU, pymatsolver.Pardiso, pymatsolver.Mumps]) 9 | @pytest.mark.parametrize('dtype', [np.float64, np.complex128]) 10 | @pytest.mark.parametrize('n_rhs', [1, 4]) 11 | def test_conjugate_solve(solver_class, dtype, n_rhs): 12 | if solver_class is pymatsolver.Pardiso and not pymatsolver.AvailableSolvers['Pardiso']: 13 | pytest.skip("pydiso not installed.") 14 | if solver_class is pymatsolver.Mumps and not pymatsolver.AvailableSolvers['Mumps']: 15 | pytest.skip("python-mumps not 
installed.") 16 | 17 | n = 10 18 | D = sp.diags(np.linspace(1, 10, n)) 19 | if dtype == np.float64: 20 | L = sp.diags([1, -1], [0, -1], shape=(n, n)) 21 | 22 | sol = np.linspace(0.9, 1.1, n) 23 | # non-symmetric real matrix 24 | else: 25 | # non-symmetric 26 | L = sp.diags([1, -1j], [0, -1], shape=(n, n)) 27 | sol = np.linspace(0.9, 1.1, n) - 1j * np.linspace(0.9, 1.1, n)[::-1] 28 | 29 | if n_rhs > 1: 30 | sol = np.pad(sol[:, None], [(0, 0), (0, n_rhs - 1)], mode='constant') 31 | 32 | A = D @ L @ D @ L.T 33 | 34 | # double check it solves 35 | rhs = A @ sol 36 | Ainv = solver_class(A) 37 | npt.assert_allclose(Ainv @ rhs, sol) 38 | 39 | # is conjugate solve correct? 40 | rhs_conj = A.conjugate() @ sol 41 | Ainv_conj = Ainv.conjugate() 42 | npt.assert_allclose(Ainv_conj @ rhs_conj, sol) 43 | 44 | # is conjugate -> conjugate solve correct? 45 | Ainv2 = Ainv_conj.conjugate() 46 | npt.assert_allclose(Ainv2 @ rhs, sol) -------------------------------------------------------------------------------- /tests/test_Mumps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import pymatsolver 4 | import pytest 5 | import numpy.testing as npt 6 | 7 | if not pymatsolver.AvailableSolvers['Mumps']: 8 | pytest.skip(reason="MUMPS solver is not installed", allow_module_level=True) 9 | 10 | TOL = 1e-11 11 | 12 | @pytest.fixture() 13 | def test_mat_data(): 14 | nSize = 100 15 | A = sp.rand(nSize, nSize, 0.05, format='csr', random_state=100) 16 | A = A + sp.spdiags(np.ones(nSize), 0, nSize, nSize) 17 | A = A.T*A 18 | A = A.tocsr() 19 | sol = np.linspace(0.9, 1.1, nSize) 20 | sol = np.repeat(sol[:, None], 5, axis=-1) 21 | return A, sol 22 | 23 | 24 | @pytest.mark.parametrize('transpose', [True, False]) 25 | @pytest.mark.parametrize('dtype', [np.float64, np.complex128]) 26 | @pytest.mark.parametrize('symmetric', [True, False]) 27 | def test_solve(test_mat_data, dtype, transpose, symmetric): 28 | A, 
sol = test_mat_data 29 | sol = sol.astype(dtype) 30 | A = A.astype(dtype) 31 | if not symmetric: 32 | D = sp.diags(np.linspace(2, 3, A.shape[0])) 33 | A = D @ A 34 | rhs = A @ sol 35 | if transpose: 36 | Ainv = pymatsolver.Mumps(A.T, is_symmetric=symmetric).T 37 | else: 38 | Ainv = pymatsolver.Mumps(A, is_symmetric=symmetric) 39 | for i in range(rhs.shape[1]): 40 | npt.assert_allclose(Ainv * rhs[:, i], sol[:, i], atol=TOL) 41 | npt.assert_allclose(Ainv * rhs, sol, atol=TOL) 42 | 43 | 44 | def test_refactor(test_mat_data): 45 | A, sol = test_mat_data 46 | rhs = A @ sol 47 | Ainv = pymatsolver.Mumps(A, is_symmetric=True) 48 | npt.assert_allclose(Ainv * rhs, sol, atol=TOL) 49 | 50 | # scale rows and columns 51 | D = sp.diags(np.random.rand(A.shape[0]) + 1.0) 52 | A2 = D.T @ A @ D 53 | 54 | rhs2 = A2 @ sol 55 | Ainv.factor(A2) 56 | npt.assert_allclose(Ainv * rhs2, sol, atol=TOL) -------------------------------------------------------------------------------- /tests/test_BicgJacobi.py: -------------------------------------------------------------------------------- 1 | from pymatsolver import BicgJacobi 2 | import numpy as np 3 | import numpy.testing as npt 4 | import scipy.sparse as sp 5 | import pytest 6 | 7 | RTOL = 1e-5 8 | 9 | @pytest.fixture() 10 | def test_mat_data(): 11 | nSize = 100 12 | A = sp.rand(nSize, nSize, 0.05, format='csr', random_state=100) 13 | A = A + sp.spdiags(np.ones(nSize), 0, nSize, nSize) 14 | A = A.T*A 15 | A = A.tocsr() 16 | np.random.seed(1) 17 | sol = np.random.rand(nSize, 4) 18 | return A, sol 19 | 20 | 21 | @pytest.mark.parametrize('transpose', [True, False]) 22 | @pytest.mark.parametrize('dtype', [np.float64, np.complex128]) 23 | @pytest.mark.parametrize('symmetric', [True, False]) 24 | def test_solve(test_mat_data, dtype, transpose, symmetric): 25 | A, sol = test_mat_data 26 | A = A.astype(dtype) 27 | sol = sol.astype(dtype) 28 | if not symmetric: 29 | D = sp.diags(np.linspace(2, 3, A.shape[0])) 30 | A = D @ A 31 | rhs = A @ sol 32 | 
if transpose: 33 | Ainv = BicgJacobi(A.T, is_symmetric=symmetric).T 34 | else: 35 | Ainv = BicgJacobi(A, is_symmetric=symmetric) 36 | Ainv.maxiter = 2000 37 | solb = Ainv * rhs 38 | npt.assert_allclose(rhs, A @ solb, rtol=RTOL) 39 | 40 | def test_errors_and_warnings(test_mat_data): 41 | A, sol = test_mat_data 42 | with pytest.raises(TypeError, match="The symmetric keyword.*"): 43 | Ainv = BicgJacobi(A, symmetric=True) 44 | 45 | with pytest.raises(ValueError): 46 | Ainv = BicgJacobi(A, rtol=0.0) 47 | 48 | with pytest.raises(ValueError): 49 | Ainv = BicgJacobi(A, atol=-1.0) 50 | 51 | def test_shallow_copy(test_mat_data): 52 | A, sol = test_mat_data 53 | Ainv = BicgJacobi(A, maxiter=100, rtol=1.0E-3, atol=1.0E-16) 54 | 55 | attrs = Ainv.get_attributes() 56 | 57 | new_Ainv = BicgJacobi(A, **attrs) 58 | assert attrs == new_Ainv.get_attributes() -------------------------------------------------------------------------------- /tests/test_Scipy.py: -------------------------------------------------------------------------------- 1 | from pymatsolver import Solver, Diagonal, SolverCG, SolverLU 2 | import scipy.sparse as sp 3 | from scipy.sparse.linalg import aslinearoperator 4 | import numpy as np 5 | import numpy.testing as npt 6 | import pytest 7 | 8 | 9 | TOLD = 1e-10 10 | TOLI = 1e-3 11 | 12 | @pytest.fixture() 13 | def a_matrix(): 14 | nx, ny, nz = 10, 10, 10 15 | n = nx * ny * nz 16 | Gz = sp.kron( 17 | sp.eye(nx), 18 | sp.kron( 19 | sp.eye(ny), 20 | sp.diags([-1, 1], [-1, 0], shape=(nz+1, nz)) 21 | ) 22 | ) 23 | Gy = sp.kron( 24 | sp.eye(nx), 25 | sp.kron( 26 | sp.diags([-1, 1], [-1, 0], shape=(ny+1, ny)), 27 | sp.eye(nz), 28 | ) 29 | ) 30 | Gx = sp.kron( 31 | sp.diags([-1, 1], [-1, 0], shape=(nx+1, nx)), 32 | sp.kron( 33 | sp.eye(ny), 34 | sp.eye(nz), 35 | ) 36 | ) 37 | A = Gx.T @ Gx + Gy.T @ Gy + Gz.T @ Gz 38 | return A 39 | 40 | 41 | @pytest.mark.parametrize('n_rhs', [1, 5]) 42 | @pytest.mark.parametrize('solver', [Solver, SolverLU, SolverCG]) 43 | def 
test_solver(a_matrix, n_rhs, solver): 44 | if solver is SolverCG: 45 | tol = TOLI 46 | else: 47 | tol = TOLD 48 | 49 | n = a_matrix.shape[0] 50 | b = np.linspace(0.9, 1.1, n) 51 | if n_rhs > 1: 52 | b = np.repeat(b[:, None], n_rhs, axis=-1) 53 | rhs = a_matrix @ b 54 | 55 | Ainv = solver(a_matrix) 56 | x = Ainv * rhs 57 | Ainv.clean() 58 | 59 | npt.assert_allclose(x, b, atol=tol) 60 | 61 | @pytest.mark.parametrize('dtype', [np.float64, np.complex128]) 62 | def test_iterative_solver_linear_op(dtype): 63 | n = 10 64 | A = aslinearoperator(sp.eye(n).astype(dtype)) 65 | 66 | Ainv = SolverCG(A) 67 | 68 | rhs = np.linspace(0.9, 1.1, n) 69 | 70 | npt.assert_allclose(Ainv @ rhs, rhs) 71 | 72 | @pytest.mark.parametrize('n_rhs', [1, 5]) 73 | def test_diag_solver(n_rhs): 74 | n = 10 75 | A = sp.diags(np.linspace(2, 3, n)) 76 | b = np.linspace(0.9, 1.1, n) 77 | if n_rhs > 1: 78 | b = np.repeat(b[:, None], n_rhs, axis=-1) 79 | rhs = A @ b 80 | 81 | Ainv = Diagonal(A) 82 | x = Ainv * rhs 83 | 84 | npt.assert_allclose(x, b, atol=TOLD) -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | pymatsolver 2 | *********** 3 | 4 | .. image:: https://img.shields.io/pypi/v/pymatsolver.svg 5 | :target: https://pypi.python.org/pypi/pymatsolver 6 | :alt: Latest PyPI version 7 | 8 | .. image:: https://img.shields.io/badge/license-MIT-blue.svg 9 | :target: https://github.com/simpeg/pymatsolver/blob/master/LICENSE 10 | :alt: MIT license. 11 | 12 | .. image:: https://codecov.io/gh/simpeg/pymatsolver/branch/main/graph/badge.svg?token=8uQoxzxf3r 13 | :target: https://codecov.io/gh/simpeg/pymatsolver 14 | :alt: Coverage status 15 | 16 | 17 | A (sparse) matrix solver for python. 18 | 19 | Solving Ax = b should be as easy as: 20 | 21 | .. 
code-block:: python 22 | 23 | Ainv = Solver(A) 24 | x = Ainv * b 25 | 26 | pymatsolver provides a number of thin wrappers around existing numerical packages. Nothing fancy here. 27 | 28 | Solvers Available 29 | ================= 30 | 31 | All solvers work with :code:`scipy.sparse` matrices and accept single or multiple right-hand sides as :code:`numpy` arrays: 32 | 33 | * L/U Triangular Solves 34 | * Wrapping of SciPy matrix solvers (direct and iterative) 35 | * Pardiso solvers 36 | * Mumps solvers 37 | 38 | 39 | Installing Solvers 40 | ================== 41 | Often, faster solvers are available for your system than the default SciPy factorizations. 42 | pymatsolver provides a consistent interface to both MKL's ``Pardiso`` routines and the ``MUMPS`` solver package. To 43 | make use of these, pymatsolver relies on intermediate wrapper libraries that must be installed separately. 44 | 45 | Pardiso 46 | ------- 47 | The Pardiso interface is recommended for Intel-processor-based systems. The interface is enabled by 48 | the ``pydiso`` python package, which can be installed through conda-forge as: 49 | 50 | ..
code:: 60 | 61 | conda install -c conda-forge python-mumps 62 | 63 | 64 | 65 | Code: 66 | https://github.com/simpeg/pymatsolver 67 | 68 | 69 | Tests: 70 | https://github.com/simpeg/pymatsolver/actions 71 | 72 | 73 | Bugs & Issues: 74 | https://github.com/simpeg/pymatsolver/issues 75 | 76 | License: MIT 77 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=64", "setuptools_scm>=8"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = 'pymatsolver' 7 | description = "pymatsolver: Matrix Solvers for Python" 8 | readme = 'README.rst' 9 | requires-python = '>=3.10' 10 | authors = [ 11 | {name = 'SimPEG developers', email = 'rowanc1@gmail.com'}, 12 | ] 13 | keywords = [ 14 | 'matrix solver', 15 | ] 16 | dependencies = [ 17 | "numpy>=1.21", 18 | "scipy>=1.8", 19 | "packaging", 20 | ] 21 | classifiers = [ 22 | 'Development Status :: 4 - Beta', 23 | 'Intended Audience :: Developers', 24 | 'Intended Audience :: Science/Research', 25 | 'License :: OSI Approved :: MIT License', 26 | 'Programming Language :: Python', 27 | 'Topic :: Scientific/Engineering', 28 | 'Topic :: Scientific/Engineering :: Mathematics', 29 | 'Topic :: Scientific/Engineering :: Physics', 30 | 'Operating System :: Microsoft :: Windows', 31 | 'Operating System :: POSIX', 32 | 'Operating System :: Unix', 33 | 'Operating System :: MacOS', 34 | 'Natural Language :: English', 35 | ] 36 | dynamic = ["version"] 37 | 38 | [project.license] 39 | file = 'LICENSE' 40 | 41 | [project.urls] 42 | Homepage = 'https://simpeg.xyz' 43 | Documentation = 'https://simpeg.xyz/pymatsolver/' 44 | Repository = 'https://github.com/simpeg/pymatsolver' 45 | 46 | [project.optional-dependencies] 47 | pardiso = ["pydiso"] 48 | mumps = ["python-mumps"] 49 | docs = [ 50 | "sphinx", 51 | "numpydoc", 52 | "pydata-sphinx-theme" 53 | ] 54 | 55 | 
tests = [ 56 | "pytest", 57 | "pytest-cov", 58 | ] 59 | 60 | build = [ 61 | "setuptools_scm>=8", 62 | "setuptools>=64", 63 | ] 64 | 65 | [tool.setuptools.packages.find] 66 | include = ["pymatsolver*"] 67 | 68 | [tool.setuptools_scm] 69 | 70 | [tool.coverage.run] 71 | source = ["pymatsolver", "tests"] 72 | 73 | [tool.coverage.report] 74 | ignore_errors = false 75 | show_missing = true 76 | # Regexes for lines to exclude from consideration 77 | exclude_also = [ 78 | # Don't complain about missing debug-only code: 79 | "def __repr__", 80 | "if self\\.debug", 81 | 82 | # Don't complain if tests don't hit defensive assertion code: 83 | "raise AssertionError", 84 | "raise NotImplementedError", 85 | "AbstractMethodError", 86 | 87 | # Don't complain if non-runnable code isn't run: 88 | "if 0:", 89 | "if __name__ == .__main__.:", 90 | "except PackageNotFoundError:", 91 | 92 | # Don't complain about abstract methods, they aren't run: 93 | "@(abc\\.)?abstractmethod", 94 | ] 95 | 96 | [tool.coverage.html] 97 | directory = "coverage_html_report" 98 | 99 | -------------------------------------------------------------------------------- /tests/test_Wrappers.py: -------------------------------------------------------------------------------- 1 | from pymatsolver import SolverCG, SolverLU, wrap_direct, wrap_iterative 2 | from pymatsolver.solvers import UnusedArgumentWarning 3 | import pytest 4 | import scipy.sparse as sp 5 | import warnings 6 | import numpy.testing as npt 7 | import numpy as np 8 | import re 9 | 10 | 11 | @pytest.mark.parametrize("solver_class", [SolverCG, SolverLU]) 12 | def test_wrapper_unused_kwargs(solver_class): 13 | A = sp.eye(10) 14 | 15 | with pytest.warns(UnusedArgumentWarning, match="Unused keyword argument.*"): 16 | solver_class(A, not_a_keyword_arg=True) 17 | 18 | 19 | def test_good_arg_iterative(): 20 | # Ensure this doesn't throw a warning! 
21 | with warnings.catch_warnings(): 22 | warnings.simplefilter("error") 23 | SolverCG(sp.eye(10), rtol=1e-4) 24 | 25 | 26 | def test_good_arg_direct(): 27 | # Ensure this doesn't throw a warning! 28 | with warnings.catch_warnings(): 29 | warnings.simplefilter("error") 30 | SolverLU(sp.eye(10, format='csc'), permc_spec='NATURAL') 31 | 32 | 33 | def test_bad_direct_function(): 34 | def bad_direct_func(A): 35 | class Empty(): 36 | def __init__(self, A): 37 | self.A = A 38 | # this object returned by the function doesn't have a solve method: 39 | return Empty(A) 40 | WrappedClass = wrap_direct(bad_direct_func, factorize=True) 41 | 42 | with pytest.raises(TypeError, match="instance returned by.*"): 43 | WrappedClass(sp.eye(2)) 44 | 45 | 46 | def test_direct_clean_function(): 47 | def direct_func(A): 48 | class Empty(): 49 | def __init__(self, A): 50 | self.A = A 51 | 52 | def solve(self, x): 53 | return x 54 | 55 | def clean(self): 56 | self.A = None 57 | 58 | return Empty(A) 59 | WrappedClass = wrap_direct(direct_func, factorize=True) 60 | 61 | A = sp.eye(2) 62 | Ainv = WrappedClass(A) 63 | assert Ainv.A is A 64 | assert Ainv.solver.A is A 65 | 66 | rhs = np.array([0.9, 1.0]) 67 | npt.assert_equal(Ainv @ rhs, rhs) 68 | 69 | Ainv.clean() 70 | assert Ainv.solver.A is None 71 | 72 | 73 | def test_iterative_removals(): 74 | 75 | with pytest.raises( 76 | TypeError, 77 | match=re.escape("wrap_iterative() got an unexpected keyword argument 'check_accuracy'") 78 | ): 79 | wrap_iterative(lambda a, x: x, check_accuracy=True) 80 | 81 | with pytest.raises( 82 | TypeError, 83 | match=re.escape("wrap_iterative() got an unexpected keyword argument 'accuracy_tol'") 84 | ): 85 | wrap_iterative(lambda a, x: x, accuracy_tol=1E-3) 86 | 87 | 88 | def test_non_scipy_iterative(): 89 | def iterative_solver(A, x): 90 | return x 91 | 92 | Wrapped = wrap_iterative(iterative_solver) 93 | 94 | Ainv = Wrapped(sp.eye(4)) 95 | npt.assert_equal(Ainv @ np.arange(4), np.arange(4)) 96 | 
-------------------------------------------------------------------------------- /pymatsolver/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | === 3 | API 4 | === 5 | .. currentmodule:: pymatsolver 6 | 7 | .. autosummary:: 8 | :toctree: generated/ 9 | 10 | solvers.Base 11 | 12 | Basic Solvers 13 | ============= 14 | 15 | Diagonal 16 | -------- 17 | .. autosummary:: 18 | :toctree: generated/ 19 | 20 | Diagonal 21 | 22 | Triangular 23 | ---------- 24 | .. autosummary:: 25 | :toctree: generated/ 26 | 27 | Triangle 28 | Forward 29 | Backward 30 | 31 | Iterative Solvers 32 | ================= 33 | 34 | .. autosummary:: 35 | :toctree: generated/ 36 | 37 | SolverCG 38 | BiCGJacobi 39 | 40 | Direct Solvers 41 | ============== 42 | 43 | .. autosummary:: 44 | :toctree: generated/ 45 | 46 | Solver 47 | SolverLU 48 | Pardiso 49 | Mumps 50 | """ 51 | 52 | # Simple solvers 53 | from .solvers import Diagonal, Triangle, Forward, Backward 54 | from .wrappers import wrap_direct, WrapDirect 55 | from .wrappers import wrap_iterative, WrapIterative 56 | 57 | # Scipy Iterative solvers 58 | from .iterative import SolverCG 59 | from .iterative import SolverBiCG 60 | from .iterative import BiCGJacobi 61 | 62 | # Scipy direct solvers 63 | from .direct import Solver, pardiso 64 | from .direct import SolverLU 65 | 66 | from .solvers import SolverAccuracyError 67 | from .direct import Pardiso, Mumps 68 | from .direct.pardiso import _available as _pardiso_available 69 | from .direct.mumps import _available as _mumps_available 70 | 71 | SolverHelp = {} 72 | AvailableSolvers = { 73 | "Diagonal": True, 74 | "Solver": True, 75 | "SolverLU": True, 76 | "SolverCG": True, 77 | "Triangle": True, 78 | "Pardiso": _pardiso_available, 79 | "Mumps": _mumps_available, 80 | } 81 | 82 | BicgJacobi = BiCGJacobi # backwards compatibility 83 | PardisoSolver = Pardiso # backwards compatibility 84 | 85 | if not AvailableSolvers["Pardiso"]: 86 | 
SolverHelp['Pardiso'] = """Pardiso is not working 87 | 88 | Ensure that you have pydiso installed, which may also require Python 89 | to be installed through conda. 90 | """ 91 | 92 | if not AvailableSolvers["Mumps"]: 93 | SolverHelp['Mumps'] = """Mumps is not working. 94 | 95 | Ensure that you have python-mumps installed, which may also require Python 96 | to be installed through conda. 97 | """ 98 | 99 | __author__ = 'SimPEG Team' 100 | __license__ = 'MIT' 101 | __copyright__ = '2013 - 2024, SimPEG Team, https://simpeg.xyz' 102 | 103 | from importlib.metadata import version, PackageNotFoundError 104 | 105 | # Version: read from the installed distribution's metadata 106 | try: 107 | # - Released versions are just tags: 0.8.0 108 | # - GitHub commits add .dev#+hash: 0.8.1.dev4+g2785721 109 | # - Uncommitted changes add timestamp: 0.8.1.dev4+g2785721.d20191022 110 | __version__ = version("pymatsolver") 111 | except PackageNotFoundError: 112 | # If it was not installed, then we don't know the version. We could throw a 113 | # warning here, but this case *should* be rare. pymatsolver should be 114 | # installed properly! 115 | from datetime import datetime 116 | 117 | __version__ = "unknown-" + datetime.today().strftime("%Y%m%d") -------------------------------------------------------------------------------- /pymatsolver/iterative.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | import scipy 4 | import scipy.sparse as sp 5 | from scipy.sparse.linalg import bicgstab, cg, aslinearoperator 6 | from packaging.version import Version 7 | from .wrappers import WrapIterative 8 | from .solvers import Base 9 | 10 | # The tol kwarg was removed from bicgstab in scipy 1.14.0.
11 | # See https://docs.scipy.org/doc/scipy-1.12.0/reference/generated/scipy.sparse.linalg.bicgstab.html 12 | RTOL_ARG_NAME = "rtol" if Version(scipy.__version__) >= Version("1.14.0") else "tol" 13 | 14 | SolverCG = WrapIterative(cg, name="SolverCG") 15 | SolverBiCG = WrapIterative(bicgstab, name="SolverBiCG") 16 | 17 | class BiCGJacobi(Base): 18 | """Diagonal pre-conditioned BiCG solver. 19 | 20 | Parameters 21 | ---------- 22 | A : matrix 23 | The matrix to solve, must have a ``diagonal()`` method. 24 | maxiter : int, optional 25 | The maximum number of BiCG iterations to perform. 26 | rtol : float, optional 27 | The relative tolerance for the BiCG solver to terminate. 28 | atol : float, optional 29 | The absolute tolerance for the BiCG solver to terminate. 30 | check_accuracy : bool, optional 31 | Whether to check the accuracy of the solution. 32 | check_rtol : float, optional 33 | The relative tolerance to check against for accuracy. 34 | check_atol : float, optional 35 | The absolute tolerance to check against for accuracy. 36 | **kwargs 37 | Extra keyword arguments passed to the base class. 
38 | """ 39 | 40 | def __init__(self, A, maxiter=1000, rtol=1E-6, atol=0.0, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs): 41 | if "symmetric" in kwargs: 42 | raise TypeError("The symmetric keyword argument was been removed in pymatsolver 0.4.0.") 43 | super().__init__(A, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, **kwargs) 44 | self._factored = False 45 | self.maxiter = maxiter 46 | self.rtol = rtol 47 | self.atol = atol 48 | 49 | @property 50 | def maxiter(self): 51 | return self._maxiter 52 | 53 | @maxiter.setter 54 | def maxiter(self, value): 55 | self._maxiter = int(value) 56 | 57 | @property 58 | def rtol(self): 59 | return self._rtol 60 | 61 | @rtol.setter 62 | def rtol(self, value): 63 | value = float(value) 64 | if value > 0: 65 | self._rtol = value 66 | else: 67 | raise ValueError("rtol must be greater than 0.") 68 | 69 | @property 70 | def atol(self): 71 | return self._atol 72 | 73 | @atol.setter 74 | def atol(self, value): 75 | value = float(value) 76 | if value >= 0: 77 | self._atol = value 78 | else: 79 | raise ValueError("atol must be greater than or equal to 0.") 80 | 81 | def get_attributes(self): 82 | attrs = super().get_attributes() 83 | attrs["maxiter"] = self.maxiter 84 | attrs["rtol"] = self.rtol 85 | attrs["atol"] = self.atol 86 | return attrs 87 | 88 | def factor(self): 89 | if self._factored: 90 | return 91 | nSize = self.A.shape[0] 92 | Ainv = sp.spdiags(1./self.A.diagonal(), 0, nSize, nSize) 93 | self.M = aslinearoperator(Ainv) 94 | self._factored = True 95 | 96 | @property 97 | def _tols(self): 98 | return {RTOL_ARG_NAME: self.rtol, 'atol': self.atol} 99 | 100 | 101 | def _solve_single(self, rhs): 102 | self.factor() 103 | sol, info = bicgstab( 104 | self.A, rhs, 105 | maxiter=self.maxiter, 106 | M=self.M, 107 | **self._tols, 108 | ) 109 | return sol 110 | 111 | def _solve_multiple(self, rhs): 112 | self.factor() 113 | sol = np.empty_like(rhs) 114 | for icol in range(rhs.shape[1]): 115 | 
sol[:, icol] = self._solve_single(rhs[:, icol]) 116 | return sol 117 | 118 | -------------------------------------------------------------------------------- /.github/workflows/python-package-conda.yml: -------------------------------------------------------------------------------- 1 | name: Testing 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build-and-test: 7 | name: Testing (${{ matrix.python-version }} on ${{ matrix.os }}) with ${{ matrix.solver }}. 8 | runs-on: ${{ matrix.os }} 9 | defaults: 10 | run: 11 | shell: bash -l {0} 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | os: [ubuntu-latest, windows-latest, macOS-15-intel, macOS-latest] 16 | solver: [mumps, pardiso] 17 | python-version : ['3.11', '3.12', '3.13', '3.14'] 18 | include: 19 | - os: ubuntu-latest 20 | python-version: '3.13' 21 | full-test: true 22 | exclude: 23 | - os: macOS-latest 24 | solver: pardiso 25 | 26 | steps: 27 | - uses: actions/checkout@v4 28 | - name: Setup Conda 29 | uses: conda-incubator/setup-miniconda@v3 30 | with: 31 | auto-update-conda: true 32 | channels: conda-forge 33 | python-version: ${{ matrix.python-version }} 34 | - name: Install Base Env 35 | run: | 36 | conda info 37 | conda list 38 | conda config --show 39 | conda install --quiet --yes pip numpy scipy pytest pytest-cov 40 | 41 | - name: Install MKL solver interface 42 | if: ${{ matrix.solver == 'pardiso' }} 43 | run: 44 | conda install --quiet --yes "pydiso>=0.1" 45 | 46 | - name: Install MUMPS solver interface 47 | if: ${{ matrix.solver == 'mumps' }} 48 | run: 49 | conda install --quiet --yes python-mumps 50 | 51 | - name: Install Our Package 52 | run: | 53 | pip install -v -e . 54 | conda list 55 | 56 | - name: Run Tests 57 | run: | 58 | make coverage 59 | 60 | - name: Generate Source Distribution 61 | if: ${{ matrix.full-test }} 62 | run: | 63 | pip install build twine 64 | python -m build --sdist . 
65 | twine check dist/* 66 | 67 | - name: Test Documentation 68 | if: ${{ matrix.full-test }} 69 | run: | 70 | pip install .[docs] 71 | cd docs 72 | make html 73 | cd .. 74 | 75 | - name: Upload coverage 76 | if: ${{ matrix.full-test }} 77 | uses: codecov/codecov-action@v4 78 | with: 79 | verbose: true # optional (default = false) 80 | env: 81 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 82 | 83 | distribute: 84 | name: Distributing 85 | needs: build-and-test 86 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') 87 | runs-on: ubuntu-latest 88 | defaults: 89 | run: 90 | shell: bash -l {0} 91 | env: 92 | PYTHON_VERSION: '3.13' 93 | 94 | steps: 95 | - uses: actions/checkout@v4 96 | - name: Setup Conda 97 | uses: conda-incubator/setup-miniconda@v3 98 | with: 99 | auto-update-conda: true 100 | channels: conda-forge 101 | python-version: ${{ env.PYTHON_VERSION }} 102 | 103 | - name: Install Base Env 104 | run: | 105 | conda info 106 | conda list 107 | conda config --show 108 | conda install --quiet --yes pip numpy scipy 109 | 110 | - name: Install Our Package 111 | run: | 112 | pip install -v -e .[docs] 113 | 114 | - name: Generate Source Distribution 115 | run: | 116 | pip install build twine 117 | python -m build --sdist . 118 | twine check dist/* 119 | 120 | - name: Build Documentation 121 | run: | 122 | cd docs 123 | make html 124 | cd .. 
125 | 126 | - name: GitHub Pages 127 | uses: crazy-max/ghaction-github-pages@v4 128 | with: 129 | build_dir: docs/_build/html 130 | jekyll: false 131 | keep_history: true 132 | env: 133 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 134 | 135 | - name: pypi-publish 136 | uses: pypa/gh-action-pypi-publish@release/v1 137 | with: 138 | user: __token__ 139 | password: ${{ secrets.PYPI_API_TOKEN }} 140 | skip-existing: true 141 | -------------------------------------------------------------------------------- /pymatsolver/direct/mumps.py: -------------------------------------------------------------------------------- 1 | from pymatsolver.solvers import Base 2 | try: 3 | from mumps import Context 4 | _available = True 5 | except ImportError: 6 | Context = None 7 | _available = False 8 | 9 | class Mumps(Base): 10 | """The MUMPS direct solver. 11 | 12 | This solver uses the python-mumps wrappers to factorize a sparse matrix, and uses that factorization for solving. 13 | 14 | Parameters 15 | ---------- 16 | A 17 | Matrix to solve with. 18 | ordering : str, default 'metis' 19 | Which ordering algorithm to use. See the `python-mumps` documentation for more details. 20 | is_symmetric : bool, optional 21 | Whether the matrix is symmetric. By default, it will perform some simple tests to check for symmetry, and 22 | default to ``False`` if those fail. 23 | is_positive_definite : bool, optional 24 | Whether the matrix is positive definite. 25 | check_accuracy : bool, optional 26 | Whether to check the accuracy of the solution. 27 | check_rtol : float, optional 28 | The relative tolerance to check against for accuracy. 29 | check_atol : float, optional 30 | The absolute tolerance to check against for accuracy. 31 | **kwargs 32 | Extra keyword arguments. If there are any left here a warning will be raised. 
33 | """ 34 | _transposed = False 35 | 36 | def __init__(self, A, ordering=None, is_symmetric=None, is_positive_definite=False, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs): 37 | if not _available: 38 | raise ImportError( 39 | "The Mumps solver requires the python-mumps package to be installed." 40 | ) 41 | is_hermitian = kwargs.pop('is_hermitian', False)  # accepted through **kwargs (not a named parameter); defaults to False 42 | super().__init__(A, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, is_hermitian=is_hermitian, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, **kwargs) 43 | if ordering is None: 44 | ordering = "metis" 45 | self.ordering = ordering 46 | self.solver = Context() 47 | self._set_A(self.A) 48 | 49 | def _set_A(self, A): 50 | self.solver.set_matrix( 51 | A, 52 | symmetric=self.is_symmetric, 53 | ) 54 | 55 | @property 56 | def ordering(self): 57 | """The ordering algorithm to use. 58 | 59 | Returns 60 | ------- 61 | str 62 | """ 63 | return self._ordering 64 | 65 | @ordering.setter 66 | def ordering(self, value): 67 | self._ordering = str(value) 68 | 69 | @property 70 | def _factored(self): 71 | return self.solver.factored 72 | 73 | def get_attributes(self): 74 | attrs = super().get_attributes() 75 | attrs['ordering'] = self.ordering 76 | return attrs 77 | 78 | def transpose(self): 79 | trans_obj = Mumps.__new__(Mumps) 80 | trans_obj._A = self.A 81 | for attr, value in self.get_attributes().items(): 82 | setattr(trans_obj, attr, value) 83 | trans_obj.solver = self.solver  # shares the underlying solver (and its factorization) with this object 84 | trans_obj._transposed = not self._transposed 85 | return trans_obj 86 | 87 | def factor(self, A=None): 88 | """(Re)factor the A matrix. 89 | 90 | Parameters 91 | ---------- 92 | A : scipy.sparse.spmatrix, optional 93 | The matrix to be factorized. If a previous factorization has been performed, this will 94 | reuse the previous factorization's analysis. If `A` is ``None`` (or is the current matrix) and a factorization already exists, nothing is done. 
95 | """ 96 | reuse_analysis = self._factored 97 | do_factor = not self._factored 98 | if A is not None and A is not self.A: 99 | # if it was previously factored then re-use the analysis. 100 | self._set_A(A) 101 | self._A = A 102 | do_factor = True 103 | if do_factor: 104 | pivot_tol = 0.0 if self.is_positive_definite else 0.01  # threshold pivoting is disabled only for positive-definite matrices 105 | self.solver.factor( 106 | ordering=self.ordering, reuse_analysis=reuse_analysis, pivot_tol=pivot_tol 107 | ) 108 | 109 | def _solve_multiple(self, rhs): 110 | self.factor() 111 | if self._transposed: 112 | self.solver.mumps_instance.icntl[9] = 0  # MUMPS ICNTL(9) != 1: solve with the transpose of A 113 | else: 114 | self.solver.mumps_instance.icntl[9] = 1  # MUMPS ICNTL(9) == 1: solve A x = b 115 | sol = self.solver.solve(rhs) 116 | return sol 117 | 118 | _solve_single = _solve_multiple -------------------------------------------------------------------------------- /pymatsolver/direct/pardiso.py: -------------------------------------------------------------------------------- 1 | from pymatsolver.solvers import Base 2 | try: 3 | from pydiso.mkl_solver import MKLPardisoSolver 4 | from pydiso.mkl_solver import set_mkl_pardiso_threads, get_mkl_pardiso_max_threads 5 | _available = True 6 | except ImportError: 7 | _available = False 8 | 9 | class Pardiso(Base): 10 | """The Pardiso direct solver. 11 | 12 | This solver uses the `pydiso` Intel MKL wrapper to factorize a sparse matrix, and uses that 13 | factorization for solving. 14 | 15 | Parameters 16 | ---------- 17 | A : scipy.sparse.spmatrix 18 | Matrix to solve with. 19 | n_threads : int, optional 20 | Number of threads to use for the `Pardiso` routine in Intel's MKL. 21 | is_symmetric : bool, optional 22 | Whether the matrix is symmetric. By default, it will perform some simple tests to check for symmetry, and 23 | default to ``False`` if those fail. 24 | is_positive_definite : bool, optional 25 | Whether the matrix is positive definite. 26 | is_hermitian : bool, optional 27 | Whether the matrix is hermitian. 
By default, it will perform some simple tests to check, and default to 28 | ``False`` if those fail. 29 | check_accuracy : bool, optional 30 | Whether to check the accuracy of the solution. 31 | check_rtol : float, optional 32 | The relative tolerance to check against for accuracy. 33 | check_atol : float, optional 34 | The absolute tolerance to check against for accuracy. 35 | **kwargs 36 | Extra keyword arguments. If there are any left here a warning will be raised. 37 | """ 38 | 39 | _transposed = False 40 | 41 | def __init__(self, A, n_threads=None, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs): 42 | if not _available: 43 | raise ImportError("Pardiso solver requires the pydiso package to be installed.") 44 | super().__init__(A, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, is_hermitian=is_hermitian, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, **kwargs) 45 | self.solver = MKLPardisoSolver( 46 | self.A, 47 | matrix_type=self._matrixType(), 48 | factor=False 49 | ) 50 | if n_threads is not None: 51 | self.n_threads = n_threads 52 | 53 | def _matrixType(self): 54 | """ 55 | Return the MKL Pardiso matrix type code for this matrix (codes 1 and 3 are listed for reference but are never returned): 56 | 57 | Real:: 58 | 59 | 1: structurally symmetric 60 | 2: symmetric positive definite 61 | -2: symmetric indefinite 62 | 11: nonsymmetric 63 | 64 | Complex:: 65 | 66 | 6: symmetric 67 | 4: hermitian positive definite 68 | -4: hermitian indefinite 69 | 3: structurally symmetric 70 | 13: nonsymmetric 71 | 72 | """ 73 | if self.is_real: 74 | if self.is_symmetric: 75 | if self.is_positive_definite: 76 | return 2 77 | else: 78 | return -2 79 | else: 80 | return 11 81 | else: 82 | if self.is_symmetric: 83 | return 6 84 | elif self.is_hermitian: 85 | if self.is_positive_definite: 86 | return 4 87 | else: 88 | return -4 89 | else: 90 | return 13 91 | 92 | def factor(self, A=None): 93 | """(Re)factor the A matrix. 
94 | 95 | Parameters 96 | ---------- 97 | A : scipy.sparse.spmatrix, optional 98 | The matrix to be factorized. If a previous factorization has been performed, this will 99 | reuse the previous factorization's analysis. The current matrix is always refactored, even when `A` is ``None``. 100 | """ 101 | if A is not None and self.A is not A: 102 | self._A = A 103 | self.solver.refactor(self.A)  # runs unconditionally, even if A is None or unchanged 104 | 105 | def _solve_multiple(self, rhs): 106 | sol = self.solver.solve(rhs, transpose=self._transposed)  # NOTE(review): factor() is not called here; presumably pydiso factors lazily on the first solve - confirm 107 | return sol 108 | 109 | def transpose(self): 110 | trans_obj = Pardiso.__new__(Pardiso) 111 | trans_obj._A = self.A 112 | for attr, value in self.get_attributes().items(): 113 | setattr(trans_obj, attr, value) 114 | trans_obj.solver = self.solver  # shares the underlying solver with this object 115 | trans_obj._transposed = not self._transposed 116 | return trans_obj 117 | 118 | @property 119 | def n_threads(self): 120 | """Number of threads to use for the Pardiso solver routine. 121 | 122 | This property is global to all Pardiso solver objects for a single python process. 123 | 124 | Returns 125 | ------- 126 | int 127 | """ 128 | return get_mkl_pardiso_max_threads() 129 | 130 | @n_threads.setter 131 | def n_threads(self, n_threads): 132 | set_mkl_pardiso_threads(n_threads) 133 | 134 | _solve_single = _solve_multiple 135 | -------------------------------------------------------------------------------- /tests/test_Pardiso.py: -------------------------------------------------------------------------------- 1 | import pymatsolver 2 | import numpy as np 3 | import numpy.testing as npt 4 | import pytest 5 | import scipy.sparse as sp 6 | import os 7 | 8 | if not pymatsolver.AvailableSolvers['Pardiso']: 9 | pytest.skip(reason="Pardiso solver is not installed", allow_module_level=True) 10 | else: 11 | from pydiso.mkl_solver import get_mkl_pardiso_max_threads 12 | 13 | TOL = 1e-10 14 | 15 | @pytest.fixture() 16 | def test_mat_data(): 17 | nSize = 100 18 | A = sp.rand(nSize, nSize, 0.05, format='csr', random_state=100) 19 | A = A + sp.spdiags(np.ones(nSize), 0, nSize, nSize) 20 | A = A.T*A 21 | A = 
A.tocsr() 22 | sol = np.linspace(0.9, 1.1, nSize) 23 | sol = np.repeat(sol[:, None], 5, axis=-1) 24 | return A, sol 25 | 26 | @pytest.mark.parametrize('transpose', [True, False]) 27 | @pytest.mark.parametrize('dtype', [np.float64, np.complex128]) 28 | @pytest.mark.parametrize('symmetry', ["S", "H", None]) 29 | def test_solve(test_mat_data, dtype, transpose, symmetry): 30 | 31 | A, sol = test_mat_data 32 | 33 | if symmetry is None: 34 | D = sp.diags(np.linspace(2, 3, A.shape[0])) 35 | A = D @ A 36 | symmetric = False 37 | hermitian = False 38 | elif symmetry == "H": 39 | D = sp.diags(np.linspace(2, 3, A.shape[0])) 40 | if np.issubdtype(dtype, np.complexfloating): 41 | D = D + 1j * sp.diags(np.linspace(3, 4, A.shape[0])) 42 | A = D @ A @ D.T.conjugate() 43 | symmetric = False 44 | hermitian = True 45 | else: 46 | symmetric = True 47 | hermitian = False 48 | 49 | sol = sol.astype(dtype) 50 | A = A.astype(dtype) 51 | 52 | Ainv = pymatsolver.Pardiso(A, is_symmetric=symmetric, is_hermitian=hermitian) 53 | if transpose: 54 | rhs = A.T @ sol 55 | Ainv = Ainv.T 56 | else: 57 | rhs = A @ sol 58 | 59 | for i in range(rhs.shape[1]): 60 | npt.assert_allclose(Ainv * rhs[:, i], sol[:, i], atol=TOL) 61 | npt.assert_allclose(Ainv * rhs, sol, atol=TOL) 62 | 63 | @pytest.mark.parametrize('transpose', [True, False]) 64 | @pytest.mark.parametrize('dtype', [np.float64, np.complex128]) 65 | def test_pardiso_positive_definite(dtype, transpose): 66 | n = 5 67 | if dtype == np.float64: 68 | L = sp.diags([1, -1], [0, -1], shape=(n, n)) 69 | else: 70 | L = sp.diags([1, -1j], [0, -1], shape=(n, n)) 71 | D = sp.diags(np.linspace(1, 2, n)) 72 | A_pd = L @ D @ (L.T.conjugate()) 73 | 74 | sol = np.linspace(0.9, 1.1, n) 75 | 76 | is_symmetric = dtype == np.float64 77 | Ainv = pymatsolver.Pardiso(A_pd, is_symmetric=is_symmetric, is_hermitian=True, is_positive_definite=True) 78 | if transpose: 79 | rhs = A_pd.T @ sol 80 | Ainv = Ainv.T 81 | else: 82 | rhs = A_pd @ sol 83 | 84 | 
npt.assert_allclose(Ainv @ rhs, sol) 85 | 86 | 87 | def test_refactor(test_mat_data): 88 | A, sol = test_mat_data 89 | rhs = A @ sol 90 | Ainv = pymatsolver.Pardiso(A, is_symmetric=True) 91 | npt.assert_allclose(Ainv * rhs, sol, atol=TOL) 92 | 93 | # scale rows and columns 94 | D = sp.diags(np.random.rand(A.shape[0]) + 1.0) 95 | A2 = D.T @ A @ D 96 | 97 | rhs2 = A2 @ sol 98 | Ainv.factor(A2) 99 | npt.assert_allclose(Ainv * rhs2, sol, atol=TOL) 100 | 101 | def test_n_threads(test_mat_data): 102 | A, sol = test_mat_data 103 | 104 | max_threads = get_mkl_pardiso_max_threads() 105 | print(f'testing setting n_threads to 1 and {max_threads}') 106 | Ainv = pymatsolver.Pardiso(A, is_symmetric=True, n_threads=1) 107 | assert Ainv.n_threads == 1 108 | 109 | Ainv2 = pymatsolver.Pardiso(A, is_symmetric=True, n_threads=max_threads) 110 | assert Ainv2.n_threads == max_threads 111 | 112 | # the n_threads setting is global so setting Ainv2's n_threads will 113 | # change Ainv's n_threads. 114 | assert Ainv2.n_threads == Ainv.n_threads 115 | 116 | # setting one object's n_threads should change all 117 | Ainv.n_threads = 1 118 | assert Ainv.n_threads == 1 119 | assert Ainv2.n_threads == Ainv.n_threads 120 | 121 | with pytest.raises(TypeError): 122 | Ainv.n_threads = "2" 123 | 124 | def test_inacurrate_symmetry(test_mat_data): 125 | A, sol = test_mat_data 126 | rhs = A @ sol 127 | # make A not symmetric 128 | D = sp.diags(np.linspace(2, 3, A.shape[0])) 129 | A = A @ D 130 | Ainv = pymatsolver.Pardiso(A, is_symmetric=True, check_accuracy=True) 131 | with pytest.raises(pymatsolver.SolverAccuracyError): 132 | Ainv * rhs 133 | 134 | 135 | 136 | def test_pardiso_fdem(): 137 | base_path = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'fdem') 138 | 139 | data = np.load(os.path.join(base_path, 'A_data.npy')) 140 | indices = np.load(os.path.join(base_path, 'A_indices.npy')) 141 | indptr = np.load(os.path.join(base_path, 'A_indptr.npy')) 142 | 143 | A = sp.csr_matrix((data, 
indices, indptr), shape=(13872, 13872)) 144 | rhs = np.load(os.path.join(base_path, 'RHS.npy')) 145 | 146 | Ainv = pymatsolver.Pardiso(A, check_accuracy=True) 147 | print(Ainv.is_symmetric) 148 | 149 | sol = Ainv * rhs 150 | 151 | npt.assert_allclose(A @ sol, rhs, atol=TOL) -------------------------------------------------------------------------------- /tests/test_Basic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.testing as npt 3 | import pytest 4 | import scipy.sparse as sp 5 | import pymatsolver 6 | from pymatsolver import Diagonal 7 | from pymatsolver.solvers import UnusedArgumentWarning 8 | 9 | TOL = 1e-12 10 | 11 | class IdentitySolver(pymatsolver.solvers.Base): 12 | """"A concrete implementation of Base, for testing purposes""" 13 | def _solve_single(self, rhs): 14 | return rhs 15 | 16 | def _solve_multiple(self, rhs): 17 | return rhs 18 | 19 | def clean(self): 20 | # this is to test that the __del__ still executes if the object doesn't successfully clean. 
21 | raise MemoryError("Nothing to cleanup!") 22 | 23 | class NotTransposableIdentitySolver(IdentitySolver): 24 | """ A class that can't be transposed.""" 25 | 26 | @property 27 | def _transpose_class(self): 28 | return None 29 | 30 | 31 | def test_basics(): 32 | 33 | Ainv = IdentitySolver(np.eye(4)) 34 | assert Ainv.is_symmetric 35 | assert Ainv.is_hermitian 36 | assert Ainv.shape == (4, 4) 37 | assert Ainv.is_real 38 | 39 | Ainv = IdentitySolver(np.eye(4) + 0j) 40 | assert Ainv.is_symmetric 41 | assert Ainv.is_hermitian 42 | assert not Ainv.is_real 43 | 44 | Ainv = IdentitySolver(sp.eye(4)) 45 | assert Ainv.is_symmetric 46 | assert Ainv.is_hermitian 47 | assert Ainv.shape == (4, 4) 48 | 49 | Ainv = IdentitySolver(sp.eye(4).astype(np.complex128)) 50 | assert Ainv.is_symmetric 51 | assert Ainv.is_hermitian 52 | assert Ainv.shape == (4, 4) 53 | assert not Ainv.is_real 54 | 55 | def test_basic_solve(): 56 | Ainv = IdentitySolver(np.eye(4)) 57 | 58 | rhs = np.arange(4) 59 | rhs1d = np.arange(4).reshape(4, 1) 60 | rhs2d = np.arange(8).reshape(4, 2) 61 | rhs3d = np.arange(24).reshape(3, 4, 2) 62 | 63 | npt.assert_equal(Ainv @ rhs, rhs) 64 | npt.assert_equal(Ainv @ rhs1d, rhs1d) 65 | npt.assert_equal(Ainv @ rhs2d, rhs2d) 66 | npt.assert_equal(Ainv @ rhs3d, rhs3d) 67 | 68 | npt.assert_equal(rhs @ Ainv, rhs) 69 | npt.assert_equal(rhs * Ainv, rhs) 70 | 71 | 72 | npt.assert_equal(rhs1d.T @ Ainv, rhs1d.T) 73 | npt.assert_equal(rhs1d.T * Ainv, rhs1d.T) 74 | 75 | npt.assert_equal(rhs2d.T @ Ainv, rhs2d.T) 76 | npt.assert_equal(rhs2d.T * Ainv, rhs2d.T) 77 | 78 | npt.assert_equal(rhs3d.swapaxes(-1, -2) @ Ainv, rhs3d.swapaxes(-1, -2)) 79 | npt.assert_equal(rhs3d.swapaxes(-1, -2) * Ainv, rhs3d.swapaxes(-1, -2)) 80 | 81 | 82 | def test_errors_and_warnings(): 83 | 84 | # from Base... 
85 | with pytest.raises(ValueError, match="A must be 2-dimensional."): 86 | IdentitySolver(np.full((3, 3, 3), 1)) 87 | 88 | with pytest.raises(ValueError, match="A is not a square matrix."): 89 | IdentitySolver(np.full((3, 5), 1)) 90 | 91 | with pytest.raises(TypeError, match=r"'accuracy_tol' was removed.*"): 92 | IdentitySolver(np.full((4, 4), 1), accuracy_tol=0.41) 93 | 94 | with pytest.warns(UnusedArgumentWarning, match="Unused keyword arguments.*"): 95 | IdentitySolver(np.full((4, 4), 1), not_an_argument=4) 96 | 97 | with pytest.raises(TypeError, match="is_symmetric must be a boolean."): 98 | IdentitySolver(np.full((4, 4), 1), is_symmetric="True") 99 | 100 | with pytest.raises(TypeError, match="is_hermitian must be a boolean."): 101 | IdentitySolver(np.full((4, 4), 1), is_hermitian="True") 102 | 103 | with pytest.raises(TypeError, match="is_positive_definite must be a boolean."): 104 | IdentitySolver(np.full((4, 4), 1), is_positive_definite="True") 105 | 106 | with pytest.raises(ValueError, match="check_rtol must.*"): 107 | IdentitySolver(np.full((4, 4), 1), check_rtol=0.0) 108 | 109 | with pytest.raises(ValueError, match="check_atol must.*"): 110 | IdentitySolver(np.full((4, 4), 1), check_atol=-1.0) 111 | 112 | with pytest.raises(ValueError, match="Expected a vector of length.*"): 113 | Ainv = IdentitySolver(np.eye(4, 4)) 114 | Ainv @ np.ones(3) 115 | 116 | with pytest.raises(ValueError, match="Second to last dimension should be.*"): 117 | Ainv = IdentitySolver(np.eye(4, 4)) 118 | Ainv @ np.ones((3, 2)) 119 | 120 | with pytest.raises(NotImplementedError, match="The transpose for the.*"): 121 | Ainv = NotTransposableIdentitySolver(np.eye(4, 4), is_symmetric=False) 122 | Ainv.T 123 | 124 | 125 | 126 | def test_DiagonalSolver(): 127 | 128 | A = sp.identity(5)*2.0 129 | rhs = np.c_[np.arange(1, 6), np.arange(2, 11, 2)] 130 | X = Diagonal(A) * rhs 131 | x = Diagonal(A) * rhs[:, 0] 132 | 133 | sol = rhs/2.0 134 | 135 | with pytest.raises(TypeError): 136 | 
Diagonal(A, check_accuracy=np.array([1, 2, 3])) 137 | with pytest.raises(ValueError): 138 | Diagonal(A, check_rtol=0) 139 | 140 | npt.assert_allclose(sol, X, atol=TOL) 141 | npt.assert_allclose(sol[:, 0], x, atol=TOL) 142 | 143 | def test_diagonal_errors(): 144 | 145 | with pytest.raises(TypeError, match="A must have a diagonal.*"): 146 | Diagonal( 147 | [ 148 | [2, 0], 149 | [0, 1] 150 | ] 151 | ) 152 | 153 | with pytest.raises(ValueError, match="Diagonal matrix has a zero along the diagonal."): 154 | Diagonal( 155 | np.array( 156 | [ 157 | [0, 0], 158 | [0, 1] 159 | ] 160 | ) 161 | ) 162 | 163 | def test_diagonal_inferance(): 164 | 165 | Ainv = Diagonal( 166 | np.array( 167 | [ 168 | [2., 0.], 169 | [0., 1.], 170 | ] 171 | ), 172 | ) 173 | 174 | assert Ainv.is_symmetric 175 | assert Ainv.is_positive_definite 176 | assert Ainv.is_hermitian 177 | assert Ainv.is_real 178 | 179 | Ainv = Diagonal( 180 | np.array( 181 | [ 182 | [2.0, 0], 183 | [0, -1.0], 184 | ] 185 | ), 186 | ) 187 | 188 | assert Ainv.is_symmetric 189 | assert not Ainv.is_positive_definite 190 | assert Ainv.is_hermitian 191 | assert Ainv.is_real 192 | 193 | Ainv = Diagonal( 194 | np.array( 195 | [ 196 | [2 + 0j, 0], 197 | [0, 2 + 0j], 198 | ] 199 | ) 200 | ) 201 | assert not Ainv.is_real 202 | assert Ainv.is_symmetric 203 | assert Ainv.is_hermitian 204 | assert Ainv.is_positive_definite 205 | 206 | Ainv = Diagonal( 207 | np.array( 208 | [ 209 | [2 + 0j, 0], 210 | [0, -2 + 0j], 211 | ] 212 | ) 213 | ) 214 | assert not Ainv.is_real 215 | assert Ainv.is_symmetric 216 | assert Ainv.is_hermitian 217 | assert not Ainv.is_positive_definite 218 | 219 | Ainv = Diagonal( 220 | np.array( 221 | [ 222 | [2 + 1j, 0], 223 | [0, 2 + 0j], 224 | ] 225 | ) 226 | ) 227 | assert not Ainv.is_real 228 | assert Ainv.is_symmetric 229 | assert not Ainv.is_hermitian 230 | assert not Ainv.is_positive_definite 231 | -------------------------------------------------------------------------------- /pymatsolver/wrappers.py: 
-------------------------------------------------------------------------------- 1 | import warnings 2 | from inspect import signature 3 | import numpy as np 4 | 5 | from pymatsolver.solvers import Base, UnusedArgumentWarning 6 | 7 | def _valid_kwargs_for_func(func, **kwargs): 8 | """Validates keyword arguments for a function by inspecting its signature. 9 | 10 | This will issue a warning if the function does not accept the keyword. 11 | 12 | Returns 13 | ------- 14 | valid_kwargs : dict 15 | Arguments able to be passed to the function (based on its signature). 16 | 17 | Notes 18 | ----- 19 | If a function's signature accepts `**kwargs` then all keyword arguments are 20 | valid by definition, even if the function might throw its own exceptions based 21 | off of the arguments you pass. This function will not check for those keyword 22 | arguments. 23 | """ 24 | sig = signature(func) 25 | valid_kwargs = {} 26 | for key, value in kwargs.items(): 27 | try: 28 | sig.bind_partial(**{key: value}) 29 | valid_kwargs[key] = value 30 | except TypeError: 31 | warnings.warn(f'Unused keyword argument "{key}" for {func.__name__}.', UnusedArgumentWarning, stacklevel=3) 32 | # stack level of three because we want the warning issued at the call 33 | # to the wrapped solver's `__init__` method. 34 | return valid_kwargs 35 | 36 | 37 | def wrap_direct(fun, factorize=True, name=None): 38 | """Wraps a direct Solver. 39 | 40 | Parameters 41 | ---------- 42 | fun : callable 43 | The solver function to be wrapped. 44 | factorize : bool 45 | Set to ``True`` if `fun` will return a factorized object that has a ``solve()`` 46 | method. This allows it to be re-used for repeated solve calls. 47 | name : str, optional 48 | The name of the wrapped class to return. 49 | 50 | Returns 51 | ------- 52 | wrapped : pymatsolver.solvers.Base 53 | The wrapped function as a `pymatsolver` class. 
54 | 55 | Notes 56 | ----- 57 | Keyword arguments passed to the returned object on initialization will be checked 58 | against `fun`'s signature. If `factorize` is ``True``, then they will additionally be 59 | checked against the factorized object's ``solve()`` method signature. These checks 60 | will not cause errors, but will issue warnings saying they are unused. 61 | 62 | Examples 63 | -------- 64 | >>> import pymatsolver 65 | >>> from scipy.sparse.linalg import spsolve, splu 66 | 67 | Scipy's ``spsolve`` does not support reuse, so we must pass ``factorize=false``. 68 | >>> Solver = pymatsolver.WrapDirect(spsolve, factorize=False) 69 | 70 | Scipy's ``splu`` returns an `SuperLU` object that has a `solve` method, and therefore 71 | does support reuse, so we must pass ``factorize=true``. 72 | >>> SolverLU = pymatsolver.WrapDirect(splu, factorize=True) 73 | """ 74 | 75 | def __init__(self, A, check_accuracy=False, check_rtol=1E-6, check_atol=0, **kwargs): 76 | Base.__init__( 77 | self, A, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, 78 | ) 79 | self.kwargs = kwargs 80 | if factorize: 81 | self.solver = fun(self.A, **self.kwargs) 82 | if not hasattr(self.solver, "solve"): 83 | raise TypeError(f"instance returned by {fun.__name__} must have a solve() method.") 84 | 85 | @property 86 | def kwargs(self): 87 | return self._kwargs 88 | 89 | @kwargs.setter 90 | def kwargs(self, keyword_arguments): 91 | self._kwargs = _valid_kwargs_for_func(fun, **keyword_arguments) 92 | 93 | def _solve_single(self, rhs): 94 | rhs = rhs.astype(self.dtype) 95 | 96 | if factorize: 97 | X = self.solver.solve(rhs) 98 | else: 99 | X = fun(self.A, rhs, **self.kwargs) 100 | 101 | return X 102 | 103 | def _solve_multiple(self, rhs): 104 | rhs = rhs.astype(self.dtype) 105 | 106 | X = np.empty_like(rhs) 107 | for i in range(rhs.shape[1]): 108 | X[:, i] = self._solve_single(rhs[:, i]) 109 | 110 | return X 111 | 112 | def clean(self): 113 | if factorize and 
hasattr(self.solver, 'clean'): 114 | self.solver.clean() 115 | 116 | class_name = str(name if name is not None else fun.__name__) 117 | WrappedClass = type( 118 | class_name, 119 | (Base,), 120 | { 121 | "__init__": __init__, 122 | "_solve_single": _solve_single, 123 | "_solve_multiple": _solve_multiple, 124 | "kwargs": kwargs, 125 | "clean": clean, 126 | } 127 | ) 128 | WrappedClass.__doc__ = f"""Wrapped {class_name} solver. 129 | 130 | Parameters 131 | ---------- 132 | A 133 | The matrix to use for the solver. 134 | check_accuracy : bool, optional 135 | Whether to check the accuracy of the solution. 136 | check_rtol : float, optional 137 | The relative tolerance to check against for accuracy. 138 | check_atol : float, optional 139 | The absolute tolerance to check against for accuracy. 140 | **kwargs 141 | Extra keyword arguments which will attempted to be passed to the wrapped function. 142 | """ 143 | return WrappedClass 144 | 145 | 146 | def wrap_iterative(fun, name=None): 147 | """ 148 | Wraps an iterative Solver. 149 | 150 | Parameters 151 | ---------- 152 | fun : callable 153 | The iterative Solver function. 154 | name : string, optional 155 | The name of the wrapper class to construct. Defaults to the name of `fun`. 156 | 157 | Returns 158 | ------- 159 | wrapped : pymatsolver.solvers.Base 160 | 161 | Notes 162 | ----- 163 | Keyword arguments passed to the returned object on initialization will be checked 164 | against `fun`'s signature.These checks will not cause errors, but will issue warnings 165 | saying they are unused. 
166 | 167 | Examples 168 | -------- 169 | >>> import pymatsolver 170 | >>> from scipy.sparse.linalg import cg 171 | >>> SolverCG = pymatsolver.WrapIterative(cg) 172 | 173 | """ 174 | 175 | def __init__(self, A, check_accuracy=False, check_rtol=1E-6, check_atol=0, **kwargs): 176 | Base.__init__( 177 | self, A, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, 178 | ) 179 | self.kwargs = kwargs 180 | 181 | @property 182 | def kwargs(self): 183 | return self._kwargs 184 | 185 | @kwargs.setter 186 | def kwargs(self, keyword_arguments): 187 | self._kwargs = _valid_kwargs_for_func(fun, **keyword_arguments) 188 | 189 | def _solve_single(self, rhs): 190 | 191 | out = fun(self.A, rhs, **self.kwargs) 192 | if type(out) is tuple and len(out) == 2: 193 | # We are dealing with scipy output with an info! 194 | X = out[0] 195 | self.info = out[1] 196 | else: 197 | X = out 198 | return X 199 | 200 | def _solve_multiple(self, rhs): 201 | 202 | X = np.empty_like(rhs) 203 | for i in range(rhs.shape[1]): 204 | X[:, i] = self._solve_single(rhs[:, i]) 205 | return X 206 | 207 | class_name = str(name if name is not None else fun.__name__) 208 | WrappedClass = type( 209 | class_name, 210 | (Base,), 211 | { 212 | "__init__": __init__, 213 | "_solve_single": _solve_single, 214 | "_solve_multiple": _solve_multiple, 215 | "kwargs": kwargs, 216 | } 217 | ) 218 | WrappedClass.__doc__ = f"""Wrapped {class_name} solver. 219 | 220 | Parameters 221 | ---------- 222 | A 223 | The matrix to use for the solver. 224 | check_accuracy : bool, optional 225 | Whether to check the accuracy of the solution. 226 | check_rtol : float, optional 227 | The relative tolerance to check against for accuracy. 228 | check_atol : float, optional 229 | The absolute tolerance to check against for accuracy. 230 | **kwargs 231 | Extra keyword arguments which will attempted to be passed to the wrapped function. 
232 | """ 233 | 234 | return WrappedClass 235 | 236 | 237 | WrapDirect = wrap_direct 238 | WrapIterative = wrap_iterative -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
21 | 22 | .PHONY: help 23 | help: 24 | @echo "Please use \`make ' where is one of" 25 | @echo " html to make standalone HTML files" 26 | @echo " dirhtml to make HTML files named index.html in directories" 27 | @echo " singlehtml to make a single large HTML file" 28 | @echo " pickle to make pickle files" 29 | @echo " json to make JSON files" 30 | @echo " htmlhelp to make HTML files and a HTML help project" 31 | @echo " qthelp to make HTML files and a qthelp project" 32 | @echo " applehelp to make an Apple Help Book" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | @echo " coverage to run coverage check of the documentation (if enabled)" 49 | 50 | .PHONY: clean 51 | clean: 52 | rm -rf $(BUILDDIR)/* 53 | 54 | .PHONY: html 55 | html: 56 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 57 | @echo 58 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 59 | 60 | .PHONY: dirhtml 61 | dirhtml: 62 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 63 | @echo 64 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 
65 | 66 | .PHONY: singlehtml 67 | singlehtml: 68 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 69 | @echo 70 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 71 | 72 | .PHONY: pickle 73 | pickle: 74 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 75 | @echo 76 | @echo "Build finished; now you can process the pickle files." 77 | 78 | .PHONY: json 79 | json: 80 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 81 | @echo 82 | @echo "Build finished; now you can process the JSON files." 83 | 84 | .PHONY: htmlhelp 85 | htmlhelp: 86 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 87 | @echo 88 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 89 | ".hhp project file in $(BUILDDIR)/htmlhelp." 90 | 91 | .PHONY: qthelp 92 | qthelp: 93 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 94 | @echo 95 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 96 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 97 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/pymatsolver.qhcp" 98 | @echo "To view the help file:" 99 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/pymatsolver.qhc" 100 | 101 | .PHONY: applehelp 102 | applehelp: 103 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp 104 | @echo 105 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." 106 | @echo "N.B. You won't be able to view it unless you put it in" \ 107 | "~/Library/Documentation/Help or install it in your application" \ 108 | "bundle." 109 | 110 | .PHONY: devhelp 111 | devhelp: 112 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 113 | @echo 114 | @echo "Build finished." 
115 | @echo "To view the help file:" 116 | @echo "# mkdir -p $$HOME/.local/share/devhelp/pymatsolver" 117 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/pymatsolver" 118 | @echo "# devhelp" 119 | 120 | .PHONY: epub 121 | epub: 122 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 123 | @echo 124 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 125 | 126 | .PHONY: latex 127 | latex: 128 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 129 | @echo 130 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 131 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 132 | "(use \`make latexpdf' here to do that automatically)." 133 | 134 | .PHONY: latexpdf 135 | latexpdf: 136 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 137 | @echo "Running LaTeX files through pdflatex..." 138 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 139 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 140 | 141 | .PHONY: latexpdfja 142 | latexpdfja: 143 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 144 | @echo "Running LaTeX files through platex and dvipdfmx..." 145 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 146 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 147 | 148 | .PHONY: text 149 | text: 150 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 151 | @echo 152 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 153 | 154 | .PHONY: man 155 | man: 156 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 157 | @echo 158 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 159 | 160 | .PHONY: texinfo 161 | texinfo: 162 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 163 | @echo 164 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 165 | @echo "Run \`make' in that directory to run these through makeinfo" \ 166 | "(use \`make info' here to do that automatically)." 
167 | 168 | .PHONY: info 169 | info: 170 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 171 | @echo "Running Texinfo files through makeinfo..." 172 | make -C $(BUILDDIR)/texinfo info 173 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 174 | 175 | .PHONY: gettext 176 | gettext: 177 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 178 | @echo 179 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 180 | 181 | .PHONY: changes 182 | changes: 183 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 184 | @echo 185 | @echo "The overview file is in $(BUILDDIR)/changes." 186 | 187 | .PHONY: linkcheck 188 | linkcheck: 189 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 190 | @echo 191 | @echo "Link check complete; look for any errors in the above output " \ 192 | "or in $(BUILDDIR)/linkcheck/output.txt." 193 | 194 | .PHONY: doctest 195 | doctest: 196 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 197 | @echo "Testing of doctests in the sources finished, look at the " \ 198 | "results in $(BUILDDIR)/doctest/output.txt." 199 | 200 | .PHONY: coverage 201 | coverage: 202 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage 203 | @echo "Testing of coverage in the sources finished, look at the " \ 204 | "results in $(BUILDDIR)/coverage/python.txt." 205 | 206 | .PHONY: xml 207 | xml: 208 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 209 | @echo 210 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 211 | 212 | .PHONY: pseudoxml 213 | pseudoxml: 214 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 215 | @echo 216 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 
217 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # pymatsolver documentation build configuration file, created by 4 | # sphinx-quickstart on Sat Jan 7 17:12:34 2017. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | import sys 16 | import os 17 | import pymatsolver 18 | 19 | # If extensions (or modules to document with autodoc) are in another directory, 20 | # add these directories to sys.path here. If the directory is relative to the 21 | # documentation root, use os.path.abspath to make it absolute, like shown here. 22 | #sys.path.insert(0, os.path.abspath('.')) 23 | 24 | # -- General configuration ------------------------------------------------ 25 | 26 | # If your documentation needs a minimal Sphinx version, state it here. 27 | #needs_sphinx = '1.0' 28 | 29 | # Add any Sphinx extension module names here, as strings. They can be 30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 31 | # ones. 32 | extensions = [ 33 | 'sphinx.ext.intersphinx', 34 | 'sphinx.ext.coverage', 35 | 'sphinx.ext.mathjax', 36 | 'sphinx.ext.autodoc', 37 | "numpydoc", 38 | "sphinx.ext.autosummary", 39 | ] 40 | 41 | # Add any paths that contain templates here, relative to this directory. 42 | templates_path = ['_templates'] 43 | 44 | # Autosummary pages will be generated by sphinx-autogen instead of sphinx-build 45 | autosummary_generate = True 46 | 47 | # The suffix(es) of source filenames. 
48 | # You can specify multiple suffix as a list of string: 49 | # source_suffix = ['.rst', '.md'] 50 | source_suffix = '.rst' 51 | 52 | # The encoding of source files. 53 | #source_encoding = 'utf-8-sig' 54 | 55 | # The master toctree document. 56 | master_doc = 'index' 57 | 58 | # General information about the project. 59 | project = u'pymatsolver' 60 | copyright = u'2017, Rowan Cockett' 61 | author = u'Rowan Cockett' 62 | 63 | # The version info for the project you're documenting, acts as replacement for 64 | # |version| and |release|, also used in various other places throughout the 65 | # built documents. 66 | # 67 | # The full version, including alpha/beta/rc tags. 68 | release = pymatsolver.__version__ 69 | # The short X.Y version. 70 | version = ".".join(release.split(".")[:2]) 71 | 72 | # The language for content autogenerated by Sphinx. Refer to documentation 73 | # for a list of supported languages. 74 | # 75 | # This is also used if you do content translation via gettext catalogs. 76 | # Usually you set "language" from the command line for these cases. 77 | language = None 78 | 79 | # There are two options for replacing |today|: either, you set today to some 80 | # non-false value, then it is used: 81 | #today = '' 82 | # Else, today_fmt is used as the format for a strftime call. 83 | #today_fmt = '%B %d, %Y' 84 | 85 | # List of patterns, relative to source directory, that match files and 86 | # directories to ignore when looking for source files. 87 | exclude_patterns = ['_build'] 88 | 89 | # The reST default role (used for this markup: `text`) to use for all 90 | # documents. 91 | #default_role = None 92 | 93 | # If true, '()' will be appended to :func: etc. cross-reference text. 94 | #add_function_parentheses = True 95 | 96 | # If true, the current module name will be prepended to all description 97 | # unit titles (such as .. function::). 
98 | #add_module_names = True 99 | 100 | # If true, sectionauthor and moduleauthor directives will be shown in the 101 | # output. They are ignored by default. 102 | #show_authors = False 103 | 104 | # The name of the Pygments (syntax highlighting) style to use. 105 | pygments_style = 'sphinx' 106 | 107 | # A list of ignored prefixes for module index sorting. 108 | #modindex_common_prefix = [] 109 | 110 | # If true, keep warnings as "system message" paragraphs in the built documents. 111 | #keep_warnings = False 112 | 113 | # If true, `todo` and `todoList` produce output, else they produce nothing. 114 | todo_include_todos = False 115 | 116 | # source code links 117 | link_github = True 118 | # You can build old with link_github = False 119 | 120 | if link_github: 121 | 122 | import inspect 123 | from os.path import relpath, dirname 124 | 125 | extensions.append('sphinx.ext.linkcode') 126 | 127 | def linkcode_resolve(domain, info): 128 | if domain != "py": 129 | return None 130 | 131 | modname = info["module"] 132 | fullname = info["fullname"] 133 | 134 | submod = sys.modules.get(modname) 135 | if submod is None: 136 | return None 137 | 138 | obj = submod 139 | for part in fullname.split("."): 140 | try: 141 | obj = getattr(obj, part) 142 | except Exception: 143 | return None 144 | 145 | try: 146 | unwrap = inspect.unwrap 147 | except AttributeError: 148 | pass 149 | else: 150 | obj = unwrap(obj) 151 | 152 | try: 153 | fn = inspect.getsourcefile(obj) 154 | except Exception: 155 | fn = None 156 | if not fn: 157 | return None 158 | 159 | try: 160 | source, lineno = inspect.getsourcelines(obj) 161 | except Exception: 162 | lineno = None 163 | 164 | if lineno: 165 | linespec = "#L%d-L%d" % (lineno, lineno + len(source) - 1) 166 | else: 167 | linespec = "" 168 | 169 | try: 170 | fn = relpath(fn, start=dirname(pymatsolver.__file__)) 171 | except ValueError: 172 | return None 173 | 174 | return f"https://github.com/simpeg/pymatsolver/blob/main/pymatsolver/{fn}{linespec}" 
175 | else: 176 | extensions.append('sphinx.ext.viewcode') 177 | 178 | 179 | # -- Options for HTML output ---------------------------------------------- 180 | 181 | # The theme to use for HTML and HTML Help pages. See the documentation for 182 | # a list of builtin themes. 183 | 184 | try: 185 | import pydata_sphinx_theme 186 | 187 | html_theme = "pydata_sphinx_theme" 188 | 189 | # If false, no module index is generated. 190 | html_use_modindex = True 191 | 192 | html_theme_options = { 193 | "external_links": [ 194 | {"name": "SimPEG", "url": "https://simpeg.xyz"}, 195 | {"name": "Contact", "url": "http://slack.simpeg.xyz"} 196 | ], 197 | "icon_links": [ 198 | { 199 | "name": "GitHub", 200 | "url": "https://github.com/simpeg/pymatsolver", 201 | "icon": "fab fa-github", 202 | }, 203 | { 204 | "name": "Slack", 205 | "url": "http://slack.simpeg.xyz/", 206 | "icon": "fab fa-slack", 207 | }, 208 | { 209 | "name": "Discourse", 210 | "url": "https://simpeg.discourse.group/", 211 | "icon": "fab fa-discourse", 212 | }, 213 | { 214 | "name": "Youtube", 215 | "url": "https://www.youtube.com/c/geoscixyz", 216 | "icon": "fab fa-youtube", 217 | }, 218 | { 219 | "name": "Twitter", 220 | "url": "https://twitter.com/simpegpy", 221 | "icon": "fab fa-twitter", 222 | }, 223 | ], 224 | "use_edit_page_button": False, 225 | } 226 | 227 | html_static_path = ['_static'] 228 | 229 | html_css_files = [ 230 | 'css/custom.css', 231 | ] 232 | 233 | html_context = { 234 | "github_user": "simpeg", 235 | "github_repo": "pymatsolver", 236 | "github_version": "main", 237 | "doc_path": "docs", 238 | } 239 | except Exception: 240 | html_theme = "default" 241 | 242 | # Theme options are theme-specific and customize the look and feel of a theme 243 | # further. For a list of options available for each theme, see the 244 | # documentation. 245 | #html_theme_options = {} 246 | 247 | # Add any paths that contain custom themes here, relative to this directory. 
248 | #html_theme_path = [] 249 | 250 | # The name for this set of Sphinx documents. If None, it defaults to 251 | # " v documentation". 252 | #html_title = None 253 | 254 | # A shorter title for the navigation bar. Default is the same as html_title. 255 | #html_short_title = None 256 | 257 | # The name of an image file (relative to this directory) to place at the top 258 | # of the sidebar. 259 | #html_logo = None 260 | 261 | # The name of an image file (within the static path) to use as favicon of the 262 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 263 | # pixels large. 264 | html_favicon = "./images/logo-block.ico" 265 | 266 | # Add any paths that contain custom static files (such as style sheets) here, 267 | # relative to this directory. They are copied after the builtin static files, 268 | # so a file named "default.css" will overwrite the builtin "default.css". 269 | html_static_path = ['_static'] 270 | 271 | # Add any extra paths that contain custom files (such as robots.txt or 272 | # .htaccess) here, relative to this directory. These files are copied 273 | # directly to the root of the documentation. 274 | #html_extra_path = [] 275 | 276 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 277 | # using the given strftime format. 278 | #html_last_updated_fmt = '%b %d, %Y' 279 | 280 | # If true, SmartyPants will be used to convert quotes and dashes to 281 | # typographically correct entities. 282 | #html_use_smartypants = True 283 | 284 | # Custom sidebar templates, maps document names to template names. 285 | #html_sidebars = {} 286 | 287 | # Additional templates that should be rendered to pages, maps page names to 288 | # template names. 289 | #html_additional_pages = {} 290 | 291 | # If false, no module index is generated. 292 | #html_domain_indices = True 293 | 294 | # If false, no index is generated. 
295 | #html_use_index = True 296 | 297 | # If true, the index is split into individual pages for each letter. 298 | #html_split_index = False 299 | 300 | # If true, links to the reST sources are added to the pages. 301 | #html_show_sourcelink = True 302 | 303 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 304 | #html_show_sphinx = True 305 | 306 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 307 | #html_show_copyright = True 308 | 309 | # If true, an OpenSearch description file will be output, and all pages will 310 | # contain a tag referring to it. The value of this option must be the 311 | # base URL from which the finished HTML is served. 312 | #html_use_opensearch = '' 313 | 314 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 315 | #html_file_suffix = None 316 | 317 | # Language to be used for generating the HTML full-text search index. 318 | # Sphinx supports the following languages: 319 | # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' 320 | # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' 321 | #html_search_language = 'en' 322 | 323 | # A dictionary with options for the search language support, empty by default. 324 | # Now only 'ja' uses this config value 325 | #html_search_options = {'type': 'default'} 326 | 327 | # The name of a javascript file (relative to the configuration directory) that 328 | # implements a search results scorer. If empty, the default will be used. 329 | #html_search_scorer = 'scorer.js' 330 | 331 | # Output file base name for HTML help builder. 332 | htmlhelp_basename = 'pymatsolver' 333 | 334 | # -- Options for LaTeX output --------------------------------------------- 335 | 336 | latex_elements = { 337 | # The paper size ('letterpaper' or 'a4paper'). 338 | #'papersize': 'letterpaper', 339 | 340 | # The font size ('10pt', '11pt' or '12pt'). 341 | #'pointsize': '10pt', 342 | 343 | # Additional stuff for the LaTeX preamble. 
344 | #'preamble': '', 345 | 346 | # Latex figure (float) alignment 347 | #'figure_align': 'htbp', 348 | } 349 | 350 | # Grouping the document tree into LaTeX files. List of tuples 351 | # (source start file, target name, title, 352 | # author, documentclass [howto, manual, or own class]). 353 | latex_documents = [ 354 | (master_doc, 'pymatsolver.tex', u'pymatsolver Documentation', 355 | u'Rowan Cockett', 'manual'), 356 | ] 357 | 358 | # The name of an image file (relative to this directory) to place at the top of 359 | # the title page. 360 | #latex_logo = None 361 | 362 | # For "manual" documents, if this is true, then toplevel headings are parts, 363 | # not chapters. 364 | #latex_use_parts = False 365 | 366 | # If true, show page references after internal links. 367 | #latex_show_pagerefs = False 368 | 369 | # If true, show URL addresses after external links. 370 | #latex_show_urls = False 371 | 372 | # Documents to append as an appendix to all manuals. 373 | #latex_appendices = [] 374 | 375 | # If false, no module index is generated. 376 | #latex_domain_indices = True 377 | 378 | 379 | # -- Options for manual page output --------------------------------------- 380 | 381 | # One entry per manual page. List of tuples 382 | # (source start file, name, description, authors, manual section). 383 | man_pages = [ 384 | (master_doc, 'pymatsolver', u'pymatsolver Documentation', 385 | [author], 1) 386 | ] 387 | 388 | # If true, show URL addresses after external links. 389 | #man_show_urls = False 390 | 391 | 392 | # -- Options for Texinfo output ------------------------------------------- 393 | 394 | # Grouping the document tree into Texinfo files. 
List of tuples 395 | # (source start file, target name, title, author, 396 | # dir menu entry, description, category) 397 | texinfo_documents = [ 398 | (master_doc, 'pymatsolver', u'pymatsolver Documentation', 399 | author, 'pymatsolver', 'One line description of project.', 400 | 'Miscellaneous'), 401 | ] 402 | 403 | # Documents to append as an appendix to all manuals. 404 | #texinfo_appendices = [] 405 | 406 | # If false, no module index is generated. 407 | #texinfo_domain_indices = True 408 | 409 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 410 | #texinfo_show_urls = 'footnote' 411 | 412 | # If true, do not generate a @detailmenu in the "Top" node's menu. 413 | #texinfo_no_detailmenu = False 414 | 415 | # Intersphinx 416 | intersphinx_mapping = { 417 | "python": ("https://docs.python.org/3/", None), 418 | "numpy": ("https://numpy.org/doc/stable/", None), 419 | "scipy": ("https://docs.scipy.org/doc/scipy/", None), 420 | } 421 | numpydoc_xref_param_type = True 422 | -------------------------------------------------------------------------------- /pymatsolver/solvers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import warnings 3 | import scipy.sparse as sp 4 | from scipy.sparse.linalg import spsolve_triangular 5 | from scipy.linalg import issymmetric, ishermitian 6 | from abc import ABC, abstractmethod 7 | import copy 8 | 9 | 10 | class SolverAccuracyError(Exception): 11 | pass 12 | 13 | 14 | class UnusedArgumentWarning(UserWarning): 15 | pass 16 | 17 | 18 | class Base(ABC): 19 | """Base class for all solvers used in the pymatsolver package. 20 | 21 | Parameters 22 | ---------- 23 | A 24 | Matrix to solve with. 25 | is_symmetric : bool, optional 26 | Whether the matrix is symmetric. By default, it will perform some simple tests to check for symmetry, and 27 | default to ``False`` if those fail. 
28 | is_positive_definite : bool, optional 29 | Whether the matrix is positive definite. 30 | is_hermitian : bool, optional 31 | Whether the matrix is hermitian. By default, it will perform some simple tests to check, and default to 32 | ``False`` if those fail. 33 | check_accuracy : bool, optional 34 | Whether to check the accuracy of the solution. 35 | check_rtol : float, optional 36 | The relative tolerance to check against for accuracy. 37 | check_atol : float, optional 38 | The absolute tolerance to check against for accuracy. 39 | **kwargs 40 | Extra keyword arguments. If there are any left here a warning will be raised. 41 | """ 42 | 43 | __numpy_ufunc__ = True 44 | __array_ufunc__ = None 45 | 46 | _is_conjugate = False 47 | 48 | def __init__( 49 | self, A, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs 50 | ): 51 | # don't make any assumptions on what A is, let the individual solvers handle that 52 | shape = A.shape 53 | if len(shape) != 2: 54 | raise ValueError("A must be 2-dimensional.") 55 | if shape[0] != shape[1]: 56 | raise ValueError("A is not a square matrix.") 57 | self._A = A 58 | self._dtype = np.dtype(A.dtype) 59 | 60 | if 'accuracy_tol' in kwargs: 61 | raise TypeError("'accuracy_tol' was removed in v0.4.0, use 'check_rtol' and 'check_atol'.") 62 | 63 | self.check_accuracy = check_accuracy 64 | self.check_rtol = check_rtol 65 | self.check_atol = check_atol 66 | 67 | # do some symmetry checks that likely speed up the defualt solve operation 68 | if is_symmetric is None: 69 | if sp.issparse(A): 70 | is_symmetric = (A.T != A).nnz == 0 71 | elif isinstance(A, np.ndarray): 72 | is_symmetric = issymmetric(A) 73 | else: 74 | is_symmetric = False 75 | self.is_symmetric = is_symmetric 76 | if is_hermitian is None: 77 | if self.is_real: 78 | is_hermitian = self.is_symmetric 79 | else: 80 | if sp.issparse(A): 81 | is_hermitian = (A.T.conjugate() != A).nnz == 0 82 | elif 
isinstance(A, np.ndarray): 83 | is_hermitian = ishermitian(A) 84 | else: 85 | is_hermitian = False 86 | 87 | self.is_hermitian = is_hermitian 88 | 89 | # Can't check for positive definiteness until it is factored. 90 | # This should be defaulted to False. If the user knows ahead of time that it is positive definite 91 | # they should set this to be true. 92 | self.is_positive_definite = is_positive_definite 93 | 94 | if kwargs: 95 | warnings.warn( 96 | f"Unused keyword arguments for {self.__class__.__name__}: {kwargs.keys()}", 97 | UnusedArgumentWarning, 98 | stacklevel=3 99 | ) 100 | 101 | @property 102 | def A(self): 103 | """The matrix to solve with.""" 104 | return self._A 105 | 106 | @property 107 | def dtype(self): 108 | """The data type of the matrix. 109 | 110 | Returns 111 | ------- 112 | numpy.dtype 113 | """ 114 | return self._dtype 115 | 116 | @property 117 | def shape(self): 118 | """The input matrix dimensions. 119 | 120 | Returns 121 | ------- 122 | (2, ) tuple 123 | """ 124 | return self.A.shape 125 | 126 | @property 127 | def is_real(self): 128 | """Whether the matrix is real. 129 | 130 | Returns 131 | ------- 132 | bool 133 | """ 134 | return np.issubdtype(self.A.dtype, np.floating) 135 | 136 | @property 137 | def is_symmetric(self): 138 | """Whether the matrix is symmetric. 139 | 140 | Returns 141 | ------- 142 | bool 143 | """ 144 | return self._is_symmetric 145 | 146 | @is_symmetric.setter 147 | def is_symmetric(self, value): 148 | if isinstance(value, bool): 149 | self._is_symmetric = value 150 | else: 151 | raise TypeError("is_symmetric must be a boolean.") 152 | 153 | @property 154 | def is_hermitian(self): 155 | """Whether the matrix is hermitian. 
156 | 157 | Returns 158 | ------- 159 | bool 160 | """ 161 | if self.is_real and self.is_symmetric: 162 | return True 163 | else: 164 | return self._is_hermitian 165 | 166 | @is_hermitian.setter 167 | def is_hermitian(self, value): 168 | if isinstance(value, bool): 169 | self._is_hermitian = value 170 | else: 171 | raise TypeError("is_hermitian must be a boolean.") 172 | 173 | @property 174 | def is_positive_definite(self): 175 | """Whether the matrix is positive definite. 176 | 177 | Returns 178 | ------- 179 | bool 180 | """ 181 | return self._is_positive_definite 182 | 183 | @is_positive_definite.setter 184 | def is_positive_definite(self, value): 185 | if isinstance(value, bool): 186 | self._is_positive_definite = value 187 | else: 188 | raise TypeError("is_positive_definite must be a boolean.") 189 | 190 | @property 191 | def check_accuracy(self): 192 | """Whether to check the accuracy after a solve. 193 | 194 | If enabled, ``solve`` tests ``norm(rhs - A @ x_solve) <= max(check_rtol * norm(rhs), check_atol)`` 195 | and raises ``SolverAccuracyError`` when that test fails. 196 | 197 | Returns 198 | ------- 199 | bool 200 | """ 201 | return self._check_accuracy 202 | 203 | @check_accuracy.setter 204 | def check_accuracy(self, value): 205 | if isinstance(value, bool): 206 | self._check_accuracy = value 207 | else: 208 | raise TypeError("check_accuracy must be a boolean.") 209 | 210 | @property 211 | def check_rtol(self): 212 | """The relative tolerance used to check the solve operation. 213 | 214 | Returns 215 | ------- 216 | float 217 | """ 218 | return self._check_rtol 219 | 220 | @check_rtol.setter 221 | def check_rtol(self, value): 222 | value = float(value) 223 | if value > 0: 224 | self._check_rtol = float(value) 225 | else: 226 | raise ValueError("check_rtol must be greater than zero.") 227 | 228 | @property 229 | def check_atol(self): 230 | """The absolute tolerance used to check the solve operation.
231 | 232 | Returns 233 | ------- 234 | bool 235 | """ 236 | return self._check_atol 237 | 238 | @check_atol.setter 239 | def check_atol(self, value): 240 | value = float(value) 241 | if value >= 0: 242 | self._check_atol = float(value) 243 | else: 244 | raise ValueError("check_atol must be greater than or equal to zero.") 245 | 246 | @property 247 | def _transpose_class(self): 248 | return self.__class__ 249 | 250 | def transpose(self): 251 | """Return the transposed solve operator. 252 | 253 | Returns 254 | ------- 255 | pymatsolver.solvers.Base 256 | """ 257 | 258 | if self.is_symmetric: 259 | return self 260 | if self._transpose_class is None: 261 | raise NotImplementedError( 262 | 'The transpose for the {} class is not possible.'.format( 263 | self.__class__.__name__ 264 | ) 265 | ) 266 | newS = self._transpose_class(self.A.T, **self.get_attributes()) 267 | return newS 268 | 269 | @property 270 | def T(self): 271 | """The transposed solve operator 272 | 273 | See Also 274 | -------- 275 | transpose 276 | `T` is an alias for `transpose()`. 277 | """ 278 | return self.transpose() 279 | 280 | def conjugate(self): 281 | """Return the complex conjugate version of this solver. 282 | 283 | Returns 284 | ------- 285 | pymatsolver.solvers.Base 286 | """ 287 | if self.is_real: 288 | return self 289 | else: 290 | # make a shallow copy of myself 291 | conjugated = copy.copy(self) 292 | conjugated._is_conjugate = not self._is_conjugate 293 | return conjugated 294 | 295 | conj = conjugate 296 | 297 | def _compute_accuracy(self, rhs, x): 298 | resid_norm = np.linalg.norm(rhs - self.A @ x) 299 | rhs_norm = np.linalg.norm(rhs) 300 | tolerance = max(self.check_rtol * rhs_norm, self.check_atol) 301 | if resid_norm > tolerance: 302 | raise SolverAccuracyError( 303 | f'Accuracy on solve is above tolerance: {resid_norm} > {tolerance}' 304 | ) 305 | 306 | def solve(self, rhs): 307 | """Solves the system of equations for the given right hand side. 
308 | 309 | Parameters 310 | ---------- 311 | rhs : (..., M, N) or (M, ) array_like 312 | The right handside of A @ x = b. 313 | 314 | Returns 315 | ------- 316 | x : (..., M, N) or (M, ) array_like 317 | The solution to the system of equations. 318 | 319 | See Also 320 | -------- 321 | numpy.linalg.solve 322 | Examples of how broadcasting works for this operation. 323 | """ 324 | # Make this broadcast just like numpy.linalg.solve! 325 | 326 | n = self.A.shape[0] 327 | ndim = len(rhs.shape) 328 | if ndim == 1: 329 | if len(rhs) != n: 330 | raise ValueError(f'Expected a vector of length {n}, got {len(rhs)}') 331 | if self._is_conjugate: 332 | rhs = rhs.conjugate() 333 | x = self._solve_single(rhs) 334 | else: 335 | if rhs.shape[-2] != n: 336 | raise ValueError(f'Second to last dimension should be {n}, got {rhs.shape}') 337 | do_broadcast = rhs.ndim > 2 338 | if do_broadcast: 339 | # swap last two dimensions 340 | rhs = rhs.swapaxes(-1, -2) 341 | in_shape = rhs.shape 342 | # Then collapse all other vectors into the first dimension 343 | rhs = rhs.reshape((-1, in_shape[-1])) 344 | # Then reverse the two axes to get the array to end up in fortran order 345 | # (which is more common for direct solvers). 346 | rhs = rhs.transpose() 347 | # should end up with shape (n, -1) 348 | if self._is_conjugate: 349 | rhs = rhs.conjugate() 350 | x = self._solve_multiple(rhs) 351 | if do_broadcast: 352 | # undo the reshaping above 353 | # so first, reverse the axes again. 354 | x = x.transpose() 355 | # then expand out the first dimension into multiple dimensions. 356 | x = x.reshape(in_shape) 357 | # then switch last two dimensions again. 358 | x = x.swapaxes(-1, -2) 359 | 360 | if self.check_accuracy: 361 | self._compute_accuracy(rhs, x) 362 | 363 | if self._is_conjugate: 364 | x = x.conjugate() 365 | return x 366 | 367 | @abstractmethod 368 | def _solve_single(self, rhs): 369 | ... 370 | 371 | @abstractmethod 372 | def _solve_multiple(self, rhs): 373 | ... 
374 | 375 | def clean(self): 376 | pass 377 | 378 | def __del__(self): 379 | """Destruct to call clean when object is garbage collected.""" 380 | try: 381 | # make sure clean is called in case the underlying solver 382 | # doesn't automatically cleanup itself when garbage collected... 383 | self.clean() 384 | except: 385 | pass 386 | 387 | def __mul__(self, val): 388 | return self.__matmul__(val) 389 | 390 | def __rmul__(self, val): 391 | return self.__rmatmul__(val) 392 | 393 | def __matmul__(self, val): 394 | return self.solve(val) 395 | 396 | def __rmatmul__(self, val): 397 | tran_solver = self.transpose() 398 | # transpose last two axes of val 399 | if val.ndim > 1: 400 | val = val.swapaxes(-1, -2) 401 | out = tran_solver.solve(val) 402 | if val.ndim > 1: 403 | out = out.swapaxes(-1, -2) 404 | return out 405 | 406 | def get_attributes(self): 407 | attrs = { 408 | "is_symmetric": self.is_symmetric, 409 | "is_hermitian": self.is_hermitian, 410 | "is_positive_definite": self.is_positive_definite, 411 | "check_accuracy": self.check_accuracy, 412 | "check_rtol": self.check_rtol, 413 | "check_atol": self.check_atol, 414 | } 415 | return attrs 416 | 417 | 418 | class Diagonal(Base): 419 | """A solver for a diagonal matrix. 420 | 421 | Parameters 422 | ---------- 423 | A 424 | The diagonal matrix, must have a ``diagonal()`` method. 425 | check_accuracy : bool, optional 426 | Whether to check the accuracy of the solution. 427 | check_rtol : float, optional 428 | The relative tolerance to check against for accuracy. 429 | check_atol : float, optional 430 | The absolute tolerance to check against for accuracy. 431 | **kwargs 432 | Extra keyword arguments passed to the base class. 433 | """ 434 | 435 | def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs): 436 | try: 437 | self._diagonal = np.asarray(A.diagonal()) 438 | if not np.all(self._diagonal): 439 | # this works because 0.0 evaluates as False! 
440 | raise ValueError("Diagonal matrix has a zero along the diagonal.") 441 | except AttributeError: 442 | raise TypeError("A must have a diagonal() method.") 443 | kwargs.pop("is_symmetric", None) 444 | is_hermitian = kwargs.pop("is_hermitian", None) 445 | is_positive_definite = kwargs.pop("is_positive_definite", None) 446 | super().__init__( 447 | A, is_symmetric=True, is_hermitian=False, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, **kwargs 448 | ) 449 | if is_positive_definite is None: 450 | if self.is_real: 451 | is_positive_definite = self._diagonal.min() > 0 452 | else: 453 | is_positive_definite = (not np.any(self._diagonal.imag)) and self._diagonal.real.min() > 0 454 | is_positive_definite = bool(is_positive_definite) 455 | self.is_positive_definite = is_positive_definite 456 | 457 | if is_hermitian is None: 458 | if self.is_real: 459 | is_hermitian = True 460 | else: 461 | # can only be hermitian if all imaginary components on diagonal are zero. 462 | is_hermitian = not np.any(self._diagonal.imag) 463 | self.is_hermitian = is_hermitian 464 | 465 | def _solve_single(self, rhs): 466 | return rhs / self._diagonal 467 | 468 | def _solve_multiple(self, rhs): 469 | # broadcast the division 470 | return rhs / self._diagonal[:, None] 471 | 472 | 473 | class Triangle(Base): 474 | """A solver for a diagonal matrix. 475 | 476 | Parameters 477 | ---------- 478 | A : scipy.sparse.sparray or scipy.sparse.spmatrix 479 | The matrix to solve. 480 | lower : bool, optional 481 | Whether A is lower triangular (``True``), or upper triangular (``False``). 482 | check_accuracy : bool, optional 483 | Whether to check the accuracy of the solution. 484 | check_rtol : float, optional 485 | The relative tolerance to check against for accuracy. 486 | check_atol : float, optional 487 | The absolute tolerance to check against for accuracy. 488 | **kwargs 489 | Extra keyword arguments passed to the base class. 
490 | """ 491 | 492 | def __init__(self, A, lower=True, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs): 493 | # pop off unneeded keyword arguments. 494 | is_hermitian = kwargs.pop("is_hermitian", False) 495 | is_symmetric = kwargs.pop("is_symmetric", False) 496 | is_positive_definite = kwargs.pop("is_positive_definite", False) 497 | if not (sp.issparse(A) and A.format in ['csr', 'csc']): 498 | A = sp.csc_matrix(A) 499 | A.sum_duplicates() 500 | super().__init__(A, is_hermitian=is_hermitian, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, **kwargs) 501 | 502 | self.lower = lower 503 | 504 | @property 505 | def lower(self): 506 | return self._lower 507 | 508 | @lower.setter 509 | def lower(self, value): 510 | if isinstance(value, bool): 511 | self._lower = value 512 | else: 513 | raise TypeError("lower must be a bool.") 514 | 515 | def _solve_multiple(self, rhs): 516 | return spsolve_triangular(self.A, rhs, lower=self.lower) 517 | 518 | _solve_single = _solve_multiple 519 | 520 | def transpose(self): 521 | trans = super().transpose() 522 | trans.lower = not self.lower 523 | return trans 524 | 525 | 526 | class Forward(Triangle): 527 | """A solver for a lower triangular matrix. 528 | 529 | Parameters 530 | ---------- 531 | A : scipy.sparse.sparray or scipy.sparse.spmatrix 532 | The lower triangular matrix to solve. 533 | check_accuracy : bool, optional 534 | Whether to check the accuracy of the solution. 535 | check_rtol : float, optional 536 | The relative tolerance to check against for accuracy. 537 | check_atol : float, optional 538 | The absolute tolerance to check against for accuracy. 539 | **kwargs 540 | Extra keyword arguments passed to the base class. 
541 | """ 542 | 543 | def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs): 544 | kwargs.pop("lower", None) 545 | super().__init__(A, lower=True, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, **kwargs) 546 | 547 | 548 | class Backward(Triangle): 549 | """A solver for ann upper triangular matrix. 550 | 551 | Parameters 552 | ---------- 553 | A : scipy.sparse.sparray or scipy.sparse.spmatrix 554 | The upper triangular matrix to solve. 555 | check_accuracy : bool, optional 556 | Whether to check the accuracy of the solution. 557 | check_rtol : float, optional 558 | The relative tolerance to check against for accuracy. 559 | check_atol : float, optional 560 | The absolute tolerance to check against for accuracy. 561 | **kwargs 562 | Extra keyword arguments passed to the base class. 563 | """ 564 | 565 | _transpose_class = Forward 566 | 567 | def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, **kwargs): 568 | kwargs.pop("lower", None) 569 | super().__init__(A, lower=False, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, **kwargs) 570 | 571 | 572 | Forward._transpose_class = Backward --------------------------------------------------------------------------------