├── tests ├── __init__.py ├── test_conditions.py └── test_objective_functions.py ├── requirements-dev.txt ├── src └── inversion_ideas │ ├── minimize │ ├── __init__.py │ ├── _utils.py │ ├── _functions.py │ └── _minimizers.py │ ├── regularization │ ├── __init__.py │ └── _general.py │ ├── base │ ├── directive.py │ ├── __init__.py │ ├── minimizer.py │ ├── simulation.py │ ├── conditions.py │ └── objective_function.py │ ├── errors.py │ ├── typing.py │ ├── __init__.py │ ├── _utils.py │ ├── preconditioners.py │ ├── simulations.py │ ├── utils.py │ ├── data_misfit.py │ ├── inversion.py │ ├── inversion_log.py │ ├── conditions.py │ ├── recipes.py │ └── directives.py ├── environment.yml ├── Makefile ├── LICENSE ├── README.md ├── notebooks ├── regressor.py ├── 05_caching-data-misfit-values.ipynb ├── 08_conditions.ipynb └── 02_linear-regressor.ipynb ├── .github └── workflows │ ├── style.yml │ └── test.yml ├── pyproject.toml └── .gitignore /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | numpy>=2.0.0 2 | scipy 3 | rich 4 | pytest 5 | ruff 6 | mypy 7 | scipy-stubs 8 | discretize==0.12.0 9 | simpeg==0.25.0 10 | pymatsolver==0.4.0 11 | -------------------------------------------------------------------------------- /src/inversion_ideas/minimize/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minimizer functions and classes. 
3 | """ 4 | 5 | from ._functions import conjugate_gradient 6 | from ._minimizers import GaussNewtonConjugateGradient 7 | 8 | __all__ = ["GaussNewtonConjugateGradient", "conjugate_gradient"] 9 | -------------------------------------------------------------------------------- /src/inversion_ideas/regularization/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Regularization classes. 3 | """ 4 | 5 | from ._general import TikhonovZero 6 | from ._mesh_based import Flatness, Smallness, SparseSmallness 7 | 8 | __all__ = [ 9 | "Flatness", 10 | "Smallness", 11 | "SparseSmallness", 12 | "TikhonovZero", 13 | ] 14 | -------------------------------------------------------------------------------- /src/inversion_ideas/base/directive.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for directives. 3 | """ 4 | 5 | from abc import ABC, abstractmethod 6 | 7 | from ..typing import Model 8 | 9 | 10 | class Directive(ABC): 11 | """ 12 | Abstract class for directives. 
class NotInitializedError(Exception):
    """
    Error signalling that an inversion has not been initialized yet.

    Parameters
    ----------
    message : str, optional
        Explanation of the error. When omitted (or None), an empty string is
        used instead.
    """

    def __init__(self, message=None):
        # Normalize a missing message to an empty string, so both the ``message``
        # attribute and the exception arguments always hold a string-like value.
        if message is None:
            message = ""
        self.message = message
        super().__init__(message)
class SparseRegularization(Protocol):
    """
    Structural type for sparse regularizations usable with an IRLS algorithm.

    Any object that exposes a boolean ``irls`` attribute together with the
    ``update_irls`` and ``activate_irls`` methods satisfies this protocol.
    """

    # Boolean IRLS flag (presumably tells whether IRLS is active; to be confirmed
    # against the classes that implement this protocol).
    irls: bool

    def activate_irls(self, model_previous: Model) -> None:
        """
        Placeholder for the IRLS activation step; receives ``model_previous``.

        Raises ``NotImplementedError`` unless overridden.
        """
        raise NotImplementedError()

    def update_irls(self, model: Model) -> None:
        """
        Placeholder for the IRLS update step; receives ``model``.

        Raises ``NotImplementedError`` unless overridden.
        """
        raise NotImplementedError()
3 | """ 4 | 5 | from abc import ABC, abstractmethod 6 | 7 | import numpy as np 8 | from numpy.typing import NDArray 9 | from scipy.sparse.linalg import LinearOperator 10 | 11 | from ..typing import Model 12 | 13 | 14 | class Simulation(ABC): 15 | """ 16 | Abstract representation of a simulation. 17 | """ 18 | 19 | @abstractmethod 20 | def __init__(self): 21 | pass 22 | 23 | @property 24 | @abstractmethod 25 | def n_params(self) -> int: 26 | """ 27 | Number of model parameters. 28 | """ 29 | 30 | @property 31 | @abstractmethod 32 | def n_data(self) -> int: 33 | """ 34 | Number of data values. 35 | """ 36 | 37 | @abstractmethod 38 | def __call__(self, model: Model) -> NDArray[np.float64]: 39 | """ 40 | Evaluate simulation for a given model. 41 | """ 42 | 43 | @abstractmethod 44 | def jacobian(self, model: Model) -> NDArray[np.float64] | LinearOperator: 45 | """ 46 | Jacobian matrix for a given model. 47 | """ 48 | -------------------------------------------------------------------------------- /src/inversion_ideas/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ideas for inversion framework. 3 | """ 4 | 5 | from . 
import base, typing, utils 6 | from ._version import __version__ 7 | from .conditions import ChiTarget, CustomCondition, ModelChanged, ObjectiveChanged 8 | from .data_misfit import DataMisfit 9 | from .directives import ( 10 | Irls, 11 | MultiplierCooler, 12 | UpdateSensitivityWeights, 13 | ) 14 | from .errors import ConvergenceWarning 15 | from .inversion import Inversion 16 | from .inversion_log import InversionLog, InversionLogRich 17 | from .minimize import GaussNewtonConjugateGradient, conjugate_gradient 18 | from .preconditioners import JacobiPreconditioner, get_jacobi_preconditioner 19 | from .recipes import ( 20 | create_l2_inversion, 21 | create_sparse_inversion, 22 | create_tikhonov_regularization, 23 | ) 24 | from .regularization import Flatness, Smallness, SparseSmallness, TikhonovZero 25 | from .simulations import wrap_simulation 26 | 27 | __all__ = [ 28 | "ChiTarget", 29 | "ConvergenceWarning", 30 | "CustomCondition", 31 | "DataMisfit", 32 | "Flatness", 33 | "GaussNewtonConjugateGradient", 34 | "Inversion", 35 | "InversionLog", 36 | "InversionLogRich", 37 | "Irls", 38 | "JacobiPreconditioner", 39 | "ModelChanged", 40 | "MultiplierCooler", 41 | "ObjectiveChanged", 42 | "Smallness", 43 | "SparseSmallness", 44 | "TikhonovZero", 45 | "UpdateSensitivityWeights", 46 | "__version__", 47 | "base", 48 | "conjugate_gradient", 49 | "create_l2_inversion", 50 | "create_sparse_inversion", 51 | "create_tikhonov_regularization", 52 | "get_jacobi_preconditioner", 53 | "typing", 54 | "utils", 55 | "wrap_simulation", 56 | ] 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Draft ideas for a new design of the inversion framework 2 | 3 | This repo was created as motivation to start implementing some of the ideas 4 | laid out in https://curvenote.com/@simpeg/simpeg-inversion-refactor. 
5 | 6 | > [!WARNING] 7 | > This repository will never host a stable and well-tested codebase. Its 8 | > purpose is to have a place where we can freely try new ideas without having 9 | > to worry about breaking existing code, supporting backward compatibility, or 10 | > providing support to the community. 11 | 12 | 13 | ## Goals 14 | 15 | The main goal of the redesign is to create a new public inversion framework in 16 | SimPEG that: 17 | 18 | * Is modular, allowing users to plug custom or third-party minimizers, 19 | simulations, regularizations, etc. 20 | * Can be easily extended (with minimum extra work) to problems outside the 21 | traditional $\phi(m) = \phi_d(m) + \beta \phi_m(m)$ inversion problem. 22 | * Is implemented in such a way that it is easier to read, study, and understand by 23 | beginners who are taking their first steps in inversion theory. 24 | * Is implemented in such a way that it can be easily extended to more complex and/or 25 | complicated use cases. 26 | * Defines a clear and minimal interface for each one of the classes. 27 | * Makes use of [abstract classes](https://docs.python.org/3/library/abc.html) 28 | to enforce implementation of required methods and properties. 29 | * Simplifies the inheritance tree by lowering the number of inheritance levels. 30 | 31 | > [!NOTE] 32 | > These goals are not set in stone and are flexible. We are free to add, 33 | > remove, and edit them at any point. 34 | 35 | ## License 36 | 37 | The code in this repository is made available under an **MIT License**. 38 | A copy of this license is provided in [`LICENSE`](LICENSE). 39 | -------------------------------------------------------------------------------- /notebooks/regressor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inversion framework to implement a linear regressor.
class LinearRegressor(Simulation):
    r"""
    Simulation whose predicted data is a linear function of the model.

    .. math::

        \mathbf{y} = \mathbf{X} \cdot \mathbf{m}

    Parameters
    ----------
    X : (n_data, n_params) array
        Matrix of the linear problem.
    linop : bool, optional
        If True, ``jacobian`` returns a ``LinearOperator`` built from ``X``
        instead of ``X`` itself.
    sleep : float, optional
        Time passed to ``time.sleep`` on every evaluation of the simulation.
        No wait is performed when it is zero.
    cache : bool, optional
        Stored in the ``cache`` attribute. It is not read by this class itself
        (presumably consumed by ``cache_on_model``; to be confirmed).
    """

    def __init__(self, X, linop=False, sleep=0, cache=True):
        self.X = X
        self.linop = linop
        self.sleep = sleep
        self.cache = cache

    @property
    def n_data(self) -> int:
        """
        Number of data values (rows of ``X``).
        """
        return self.X.shape[0]

    @property
    def n_params(self) -> int:
        """
        Number of model parameters (columns of ``X``).
        """
        return self.X.shape[1]

    @cache_on_model
    def __call__(self, model) -> NDArray[np.float64]:
        """
        Evaluate the simulation for a model as ``X @ model``.
        """
        # Optional artificial delay before computing the prediction.
        if self.sleep != 0:
            time.sleep(self.sleep)
        return self.X @ model

    def jacobian(self, model) -> NDArray[np.float64] | LinearOperator:  # noqa: ARG002
        """
        Jacobian of the simulation: ``X`` or a ``LinearOperator`` wrapping it.

        The ``model`` argument is not used, since the problem is linear.
        """
        if not self.linop:
            return self.X
        return LinearOperator(
            shape=(self.n_data, self.n_params),
            matvec=lambda vector: self.X @ vector,
            rmatvec=lambda vector: self.X.T @ vector,
            dtype=np.float64,
        )
Pushes to branches will only 10 | # be built when a PR is opened. This avoids duplicated builds in PRs coming 11 | # from branches in the origin repository (1 for PR and 1 for push). 12 | on: 13 | pull_request: 14 | push: 15 | branches: 16 | - main 17 | release: 18 | types: 19 | - published 20 | 21 | permissions: {} 22 | 23 | jobs: 24 | style: 25 | name: Style 26 | runs-on: ubuntu-latest 27 | 28 | steps: 29 | # Cancel any previous run of the style job 30 | # We pin the commit hash corresponding to v0.5.0, and not pinning the tag 31 | # because we are giving full access through the github.token. 32 | - name: Cancel Previous Runs 33 | uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa 34 | with: 35 | access_token: ${{ github.token }} 36 | 37 | # Checks-out your repository under $GITHUB_WORKSPACE 38 | - name: Checkout 39 | uses: actions/checkout@v4 40 | with: 41 | # The GitHub token is preserved by default but this job doesn't need 42 | # to be able to push to GitHub. 43 | persist-credentials: false 44 | 45 | - name: Setup Python 46 | uses: actions/setup-python@v5 47 | with: 48 | python-version: 3.12 49 | 50 | - name: Install requirements 51 | run: python -m pip install --requirement requirements-dev.txt 52 | 53 | - name: List installed packages 54 | run: python -m pip freeze 55 | 56 | - name: Check format with Ruff 57 | run: make check-format 58 | 59 | - name: Check style with Ruff 60 | run: make check-style 61 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # Run tests in GitHub Actions 2 | # 3 | # NOTE: Pin actions to a specific commit to avoid having the authentication 4 | # token stolen if the Action is compromised. See the comments and links here: 5 | # https://github.com/pypa/gh-action-pypi-publish/issues/27 6 | # 7 | name: test 8 | 9 | # Only build PRs, the main branch, and releases.
Pushes to branches will only 10 | # be built when a PR is opened. This avoids duplicated builds in PRs coming 11 | # from branches in the origin repository (1 for PR and 1 for push). 12 | on: 13 | pull_request: 14 | push: 15 | branches: 16 | - main 17 | release: 18 | types: 19 | - published 20 | 21 | permissions: {} 22 | 23 | jobs: 24 | ############################################################################# 25 | # Run tests 26 | test: 27 | name: Test 28 | runs-on: ubuntu-latest 29 | 30 | steps: 31 | # Cancel any previous run of the test job 32 | # We pin the commit hash corresponding to v0.5.0, and not pinning the tag 33 | # because we are giving full access through the github.token. 34 | - name: Cancel Previous Runs 35 | uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa 36 | with: 37 | access_token: ${{ github.token }} 38 | 39 | # Checks-out your repository under $GITHUB_WORKSPACE 40 | - name: Checkout 41 | uses: actions/checkout@v4 42 | with: 43 | # The GitHub token is preserved by default but this job doesn't need 44 | # to be able to push to GitHub. 45 | persist-credentials: false 46 | 47 | - name: Setup Python 48 | uses: actions/setup-python@v5 49 | with: 50 | python-version: 3.12 51 | 52 | - name: Install requirements 53 | run: python -m pip install --requirement requirements-dev.txt 54 | 55 | - name: Install the package 56 | run: python -m pip install --no-deps --editable . 57 | 58 | - name: List installed packages 59 | run: python -m pip freeze 60 | 61 | - name: Run the tests 62 | run: make test 63 | -------------------------------------------------------------------------------- /src/inversion_ideas/_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code utilities. 3 | 4 | Objects in this submodule are meant to be private.
def prod_arrays(arrays: Iterator[npt.NDArray[np.float64]]) -> npt.NDArray[np.float64]:
    """
    Compute the element-wise product of the arrays yielded by an iterator.

    Parameters
    ----------
    arrays : Iterator
        Iterator that yields the arrays to multiply. Any other iterable (like
        a list or a tuple of arrays) is accepted as well.

    Returns
    -------
    array
        New array with the product of all the arrays. The first array is
        copied before accumulating the product in place on the copy, so none
        of the input arrays gets modified.

    Raises
    ------
    ValueError
        If ``arrays`` doesn't yield any array.
    """
    # Iterators are always truthy, so an emptiness check through ``not arrays``
    # never triggers for them: the only way to detect an empty iterator is to try
    # to pull its first element. Calling ``iter`` also allows plain iterables.
    arrays = iter(arrays)
    try:
        first = next(arrays)
    except StopIteration:
        msg = "Invalid empty 'arrays' iterator when computing the product."
        raise ValueError(msg) from None

    result = copy(first)
    for array in arrays:
        result *= array
    return result
59 | """ 60 | if not isinstance(objective, Iterable): 61 | if isinstance(objective, Scaled): 62 | extracted = extract_from_combo(objective.function, condition) 63 | else: 64 | extracted = [objective] if condition(objective) else [] 65 | return extracted 66 | 67 | extracted = [] 68 | for reg in objective: 69 | extracted += extract_from_combo(reg, condition) 70 | 71 | return extracted 72 | -------------------------------------------------------------------------------- /src/inversion_ideas/minimize/_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for minimizers. 3 | """ 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | 8 | from ..base import Objective 9 | from ..typing import Model 10 | 11 | 12 | def backtracking_line_search( 13 | phi: Objective, 14 | model: Model, 15 | search_direction: npt.NDArray[np.float64], 16 | *, 17 | contraction_factor: float = 0.5, 18 | c_factor: float = 0.5, 19 | phi_value: float | None = None, 20 | phi_gradient: npt.NDArray[np.float64] | None = None, 21 | maxiter: int = 20, 22 | ) -> tuple[float | None, int]: 23 | """ 24 | Implement the backtracking line search algorithm. 25 | 26 | Parameters 27 | ---------- 28 | phi : Objective 29 | Objective function to which the line search will be applied. 30 | model : (n_params) array 31 | Current model. 32 | search_direction: (n_params) array 33 | Vector used as a search direction. 34 | contraction_factor : float 35 | Contraction factor for the step length. Must be greater than 0 and lower than 1. 36 | c_factor : float 37 | The c factor used in the descent condition. 38 | Must be greater than 0 and lower than 1. 39 | phi_value : float or None, optional 40 | Precomputed value of ``phi(model)``. If None, it will be computed. 41 | phi_gradient : (n_params) array, optional 42 | Precomputed value of ``phi.gradient(model)``. If None, it will be computed. 
43 | maxiter : int, optional 44 | Maximum number of line search iterations. 45 | 46 | Returns 47 | ------- 48 | step_length : float or None 49 | Alpha for which `x_new = x0 + alpha * pk`, or None if the line search algorithm 50 | did not converge. 51 | n_iterations : int 52 | Number of line search iterations. 53 | 54 | Notes 55 | ----- 56 | TODO 57 | 58 | Nocedal & Wright (1999), page 41. 59 | 60 | References 61 | ---------- 62 | Nocedal, J., & Wright, S. J. (1999). Numerical optimization. Springer. 63 | """ 64 | phi_value = phi_value if phi_value is not None else phi(model) 65 | phi_gradient = phi_gradient if phi_gradient is not None else phi.gradient(model) 66 | 67 | def stop_condition(step_length): 68 | return ( 69 | phi(model + step_length * search_direction) 70 | <= phi_value + c_factor * step_length * phi_gradient @ search_direction 71 | ) 72 | 73 | step_length = 1.0 74 | n_iterations = 0 75 | while not stop_condition(step_length): 76 | step_length *= contraction_factor 77 | n_iterations += 1 78 | 79 | if n_iterations >= maxiter: 80 | return None, n_iterations 81 | 82 | return step_length, n_iterations 83 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "inversion_ideas" 3 | dynamic = ["version"] 4 | description = "New design ideas for inversion framework." 
5 | readme = "README.md" 6 | license = "MIT" 7 | requires-python = ">=3.11" 8 | dependencies = [ 9 | "numpy>=2.0.0", 10 | "scipy", 11 | "rich", 12 | "simpeg==0.25.0", 13 | "pymatsolver==0.4.0", 14 | "discretize==0.12.0", 15 | ] 16 | 17 | [build-system] 18 | requires = ["setuptools>=61", "wheel", "setuptools_scm[toml]>=6.2"] 19 | build-backend = "setuptools.build_meta" 20 | 21 | [tool.setuptools_scm] 22 | version_scheme = "post-release" 23 | local_scheme = "no-local-version" 24 | write_to = "src/inversion_ideas/_version.py" 25 | 26 | [tool.ruff] 27 | line-length = 88 28 | exclude = [ 29 | "src/inversion_ideas/_version.py", 30 | ] 31 | 32 | [tool.ruff.lint] 33 | extend-select = [ 34 | "ARG", # flake8-unused-arguments 35 | "B", # flake8-bugbear 36 | "C4", # flake8-comprehensions 37 | "D", # pydocstyle 38 | "EM", # flake8-errmsg 39 | "EXE", # flake8-executable 40 | "FURB", # refurb 41 | "G", # flake8-logging-format 42 | "I", # isort 43 | "ICN", # flake8-import-conventions 44 | "NPY", # numPy specific rules 45 | "PD", # pandas-vet 46 | "PGH", # pygrep-hooks 47 | "PIE", # flake8-pie 48 | "PL", # pylint 49 | "PT", # flake8-pytest-style 50 | "PTH", # flake8-use-pathlib 51 | "PYI", # flake8-pyi 52 | "RET", # flake8-return 53 | "RUF", # ruff-specific 54 | "SIM", # flake8-simplify 55 | "T20", # flake8-print 56 | "UP", # pyupgrade 57 | "YTT", # flake8-2020 58 | ] 59 | ignore = [ 60 | "ISC001", # Conflicts with formatter 61 | "PLR09", # Too many <...> 62 | "PLR2004", # Magic value used in comparison 63 | "RET504", # Allow variable assignment only for return 64 | "PT001", # Conventions for parenthesis on pytest.fixture 65 | "D200", # Allow single line docstrings in their own line 66 | # Temporary ignores: 67 | "D102", # Allow no docstrings in public methods 68 | "D105", # Allow no docstrings in magic methods 69 | "D414", # Allow empty sections in docstrings 70 | ] 71 | 72 | [tool.ruff.lint.per-file-ignores] 73 | "__init__.py" = [ 74 | "F401", # Disable unused-imports errors on 
__init__.py 75 | ] 76 | "test/**" = [ 77 | "D", # Ignore pydocstyle warnings in tests 78 | "T20", # Allow print statements in tests 79 | ] 80 | "notebooks/**.ipynb" = [ 81 | "B018", # Allow unused expression (prints on notebooks) 82 | "D", # Ignore pydocstyle warnings in notebooks 83 | "PD901", # Allow to use df as variables for pandas DataFrames 84 | "T20", # Allow print statements in notebooks 85 | ] 86 | 87 | [tool.ruff.lint.pydocstyle] 88 | convention = "numpy" 89 | -------------------------------------------------------------------------------- /src/inversion_ideas/preconditioners.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classes and functions to build preconditioners. 3 | """ 4 | 5 | from scipy.sparse import diags_array, sparray 6 | 7 | from .base import Objective 8 | from .typing import Model 9 | 10 | 11 | class JacobiPreconditioner: 12 | """ 13 | Jacobi preconditioner for a given objective function. 14 | 15 | Use this class to define a dynamic Jacobi preconditioner from an objective function. 16 | This class is a callable that will update the preconditioner for the given model 17 | each time it gets called. Use this class if you want to update the preconditioner 18 | on every iteration of the `Inversion`. 19 | 20 | Parameters 21 | ---------- 22 | objective_function : Objective 23 | Objective function for which the Jacobi preconditioner will be built. 24 | 25 | See Also 26 | -------- 27 | get_jacobi_preconditioner 28 | """ 29 | 30 | def __init__(self, objective_function: Objective): 31 | self.objective_function = objective_function 32 | 33 | def __call__(self, model: Model) -> sparray: 34 | """ 35 | Generate a Jacobi preconditioner as a sparse diagonal array for a given model. 36 | 37 | Parameters 38 | ---------- 39 | model : (n_params) array 40 | Model that will be used to build the Jacobi preconditioner from the 41 | ``objective_function``. 
42 | 43 | Returns 44 | ------- 45 | dia_array 46 | """ 47 | return get_jacobi_preconditioner(self.objective_function, model) 48 | 49 | 50 | def get_jacobi_preconditioner(objective_function: Objective, model: Model): 51 | r""" 52 | Obtain a Jacobi preconditioner from an objective function. 53 | 54 | Parameters 55 | ---------- 56 | objective_function : Objective 57 | Objective function from which the preconditioner will be built. 58 | model : (n_params) array 59 | Model used to build the preconditioner. 60 | 61 | Returns 62 | ------- 63 | diag_array 64 | Preconditioner as a sparse diagonal array. 65 | 66 | Notes 67 | ----- 68 | Given an objective function :math:`\phi(\mathbf{m})`, this function builds the 69 | Jacobi preconditioner :math:`\mathbf{P}(\mathbf{m})` as the inverse of the diagonal 70 | of the Hessian of :math:`\phi(\mathbf{m})`: 71 | 72 | .. math:: 73 | 74 | \mathbf{P}(\mathbf{m}) = \text{diag}[ \bar{\bar{\nabla}} \phi(\mathbf{m}) ]^{-1} 75 | 76 | where :math:`\bar{\bar{\nabla}} \phi(\mathbf{m})` is the Hessian of 77 | :math:`\phi(\mathbf{m})`. 78 | """ 79 | hessian_diag = objective_function.hessian_diagonal(model) 80 | 81 | # Compute inverse only for non-zero elements 82 | zeros = hessian_diag == 0.0 83 | hessian_diag[~zeros] **= -1 84 | 85 | return diags_array(hessian_diag) 86 | -------------------------------------------------------------------------------- /src/inversion_ideas/minimize/_functions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minimizer functions. 3 | 4 | Define functions that can be use to minimize an objective function in a single call. 
def conjugate_gradient(
    objective: Objective,
    initial_model: Model,
    preconditioner: Preconditioner | Callable[[Model], Preconditioner] | None = None,
    **kwargs,
) -> Model:
    r"""
    Minimize objective function with a Conjugate Gradient method.

    .. important::

        This minimizer should be used only for linear objective functions.

    Parameters
    ----------
    objective : Objective
        Objective function to be minimized.
    initial_model : (n_params) array
        Initial model used to start the minimization.
    preconditioner : (n_params, n_params) array, sparray or LinearOperator or Callable, optional
        Matrix used as preconditioner in the conjugate gradient algorithm.
        If None, no preconditioner will be used.
        A callable can be passed to build the preconditioner dynamically: such
        callable should take a single ``initial_model`` argument and return an
        array, `sparray` or a `LinearOperator`. A `LinearOperator` is always used
        as the preconditioner itself, never as a factory.
    kwargs : dict
        Extra arguments that will be passed to the :func:`scipy.sparse.linalg.cg`
        function.

    Returns
    -------
    inverted_model : (n_params) array
        Inverted model obtained after minimization.

    Raises
    ------
    ValueError
        If both ``preconditioner`` and the ``M`` keyword argument are passed.

    Warns
    -----
    ConvergenceWarning
        If :func:`scipy.sparse.linalg.cg` returns a non-zero ``info`` value.

    Notes
    -----
    Minimize the objective function :math:`\phi(\mathbf{m})` by solving the system:

    .. math::

        \bar{\bar{\nabla}} \phi \mathbf{m}^{*} = - \bar{\nabla} \phi

    through a Conjugate Gradient algorithm, where :math:`\bar{\bar{\nabla}} \phi`
    and :math:`\bar{\nabla} \phi` are the Hessian and the gradient of the
    objective function, respectively. The solution of the system is added to
    ``initial_model`` to obtain the inverted model.
    """
    # Local import: ``LinearOperator`` is not among the names imported at the top
    # of this module.
    from scipy.sparse.linalg import LinearOperator

    if preconditioner is not None:
        if "M" in kwargs:
            msg = "Cannot simultanously pass `preconditioner` and `M`."
            raise ValueError(msg)
        # A ``LinearOperator`` implements ``__call__`` (it applies the operator to
        # its argument), so ``callable()`` alone cannot tell a preconditioner
        # factory from an actual preconditioner.
        is_factory = callable(preconditioner) and not isinstance(
            preconditioner, LinearOperator
        )
        kwargs["M"] = preconditioner(initial_model) if is_factory else preconditioner

    # TODO: maybe it would be nice to add a `is_linear` attribute to the objective
    # functions for the ones that generate a linear problem.
    gradient = objective.gradient(initial_model)
    hessian = objective.hessian(initial_model)
    model_step, info = cg(hessian, -gradient, **kwargs)
    if info != 0:
        warnings.warn(
            "Conjugate gradient convergence to tolerance not achieved after "
            f"{info} number of iterations.",
            ConvergenceWarning,
            stacklevel=2,
        )
    return initial_model + model_step
class WrappedSimulation(Simulation):
    """
    Adapter that exposes a SimPEG simulation through the ``Simulation`` interface.

    Parameters
    ----------
    simulation : object
        Instance of a SimPEG simulation.
    store_jacobian : bool, optional
        If True, the ``jacobian`` method returns whatever the ``getJ`` method of
        the SimPEG simulation builds. If False (default), it returns
        a :class:`~scipy.sparse.linalg.LinearOperator` that calls the ``Jvec`` and
        ``Jtvec`` methods of the SimPEG simulation.

    Raises
    ------
    TypeError
        If ``store_jacobian`` is True and the simulation has no callable ``getJ``.
    """

    def __init__(self, simulation, *, store_jacobian=False):
        can_build_jacobian = callable(getattr(simulation, "getJ", None))
        if store_jacobian and not can_build_jacobian:
            msg = (
                "Not possible to set `store_jacobian` to True when wrapping the "
                f"`{type(simulation).__name__}`: the simulation doesn't have a "
                "`getJ` method to build the jacobian matrix."
            )
            raise TypeError(msg)

        self.simulation = simulation
        self.store_jacobian = store_jacobian

    @property
    def n_params(self) -> int:
        """
        Number of model parameters.

        Read from the ``nC`` attribute of the wrapped simulation when it exists,
        otherwise from the length of its ``model`` attribute.
        """
        wrapped = self.simulation

        # Potential field simulations have nC attribute with number of parameters
        if hasattr(wrapped, "nC"):
            return wrapped.nC

        # Cover other type of simulations
        if getattr(wrapped, "model", None) is not None:
            return len(wrapped.model)

        msg = f"Cannot obtain number of parameters for simulation '{self.simulation}'."
        raise AttributeError(msg)

    @property
    def n_data(self) -> int:
        """
        Number of data values, as reported by the survey of the simulation.
        """
        return self.simulation.survey.nD

    def __call__(self, model: Model) -> npt.NDArray[np.float64]:
        """
        Evaluate simulation for a given model (``dpred`` of the SimPEG simulation).
        """
        return self.simulation.dpred(model)

    def jacobian(self, model: Model) -> npt.NDArray[np.float64] | LinearOperator:
        """
        Jacobian matrix for a given model.

        Returns the result of ``getJ`` when ``store_jacobian`` is True, and
        a ``LinearOperator`` of shape ``(n_data, n_params)`` otherwise.
        """
        if self.store_jacobian:
            return self.simulation.getJ(model)

        def forward(vector):
            return self.simulation.Jvec(model, vector)

        def adjoint(vector):
            return self.simulation.Jtvec(model, vector)

        return LinearOperator(
            shape=(self.n_data, self.n_params),
            dtype=np.float64,
            matvec=forward,
            rmatvec=adjoint,
        )
45 | 46 | Returns 47 | ------- 48 | logger : :class:`logging.Logger` 49 | The logger object for SimPEG. 50 | """ 51 | return LOGGER 52 | 53 | 54 | def cache_on_model(func): 55 | """ 56 | Cache the last result of a method within the instance using the model hash. 57 | 58 | .. important:: 59 | 60 | Use this decorator only for methods that take the ``model`` as the first 61 | argument. 62 | 63 | .. important:: 64 | 65 | The instance needs to have a ``cache`` bool attribute. If True, the result 66 | of the decorated method will be cached. If False, no caching will be performed. 67 | 68 | Examples 69 | -------- 70 | >>> import numpy as np 71 | >>> 72 | >>> class MyClass: 73 | ... 74 | ... def __init__(self): 75 | ... self.cache = True 76 | ... 77 | ... @cache_on_model 78 | ... def squared(self, model) -> float: 79 | ... return (model ** 2).sum() 80 | >>> 81 | >>> sq = MyClass() 82 | >>> model = np.array([1.0, 2.0, 3.0]) 83 | >>> print(sq.squared(model)) # perform the computation 84 | 14.0 85 | >>> print(sq.squared(model)) # access the cached result 86 | 14.0 87 | 88 | >>> model_new = np.array([4.0, 5.0, 6.0]) 89 | >>> print(sq.squared(model_new)) # perform a new computation 90 | 77.0 91 | """ 92 | # Define attribute name for the model hash 93 | model_hash_attr = "_model_hash" 94 | 95 | # Define attribute name for the cached result using the hash of the function 96 | cache_attr = f"_cache_{hash(func)}" 97 | 98 | @functools.wraps(func) 99 | def wrapper(self, model, *args, **kwargs): 100 | if not hasattr(self, "cache"): 101 | msg = f"Missing 'cache' attribute in {self}" 102 | raise AttributeError(msg) 103 | 104 | if self.cache: 105 | model_hash = hashlib.sha256(model) 106 | if ( 107 | hasattr(self, model_hash_attr) 108 | and getattr(self, model_hash_attr).digest() == model_hash.digest() 109 | ): 110 | return getattr(self, cache_attr) 111 | 112 | result = func(self, model, *args, **kwargs) 113 | setattr(self, cache_attr, result) 114 | setattr(self, model_hash_attr, 
model_hash) 115 | else: 116 | result = func(self, model, *args, **kwargs) 117 | return result 118 | 119 | return wrapper 120 | 121 | 122 | def get_sensitivity_weights( 123 | jacobian: npt.NDArray[np.float64], 124 | *, 125 | data_weights: npt.NDArray[np.float64] | sparray | None = None, 126 | volumes: npt.NDArray[np.float64] | None = None, 127 | vmin: float | None = 1e-12, 128 | ): 129 | """ 130 | Compute sensitivity weights. 131 | 132 | Parameters 133 | ---------- 134 | jacobian : (n_data, n_params) array 135 | Jacobian matrix used to compute sensitivity weights. 136 | data_weights : (n_data, n_data) array or None, optional 137 | Data weights matrix used to compute the sensitivity weights. 138 | volumes : (n_params) array 139 | Array with the volumes of the active cells. Sensitivity weights are 140 | divided by the volumes to account for sensitivity changes due to cell sizes. 141 | vmin : float or None, optional 142 | Minimum value used for clipping. 143 | 144 | Notes 145 | ----- 146 | """ 147 | matrix = data_weights @ jacobian if data_weights is not None else jacobian 148 | sensitivty_weights = np.sqrt(np.sum(matrix**2, axis=0)) 149 | 150 | if volumes is not None: 151 | sensitivty_weights /= volumes 152 | 153 | # Normalize it by maximum value 154 | sensitivty_weights /= sensitivty_weights.max() 155 | 156 | # Clip to vmin 157 | if vmin is not None: 158 | sensitivty_weights[sensitivty_weights < vmin] = vmin 159 | 160 | return sensitivty_weights 161 | -------------------------------------------------------------------------------- /src/inversion_ideas/regularization/_general.py: -------------------------------------------------------------------------------- 1 | """ 2 | General purpose regularization classes. 
class TikhonovZero(Objective):
    r"""
    Zero-th order Tikhonov regularization.

    Parameters
    ----------
    n_params : int
        Number of elements in the ``model`` array.
    weights : (n_params) array or dict of (n_params) arrays or None, optional
        Regularization weights. Several sets of weights can be passed as
        a dictionary whose keys are strings and whose values are the weights
        arrays. If None, an array of ones is used.
    reference_model : (n_params) array or None, optional
        Reference model. If None, an array of zeros is used.

    Notes
    -----
    The regularization is defined as:

    .. math::

        \phi(\mathbf{m})
        = \sum\limits_{i=1}^M w_i |m_i - m_i^\text{ref}|^2
        = \lVert \mathbf{W} (\mathbf{m} - \mathbf{m}^\text{ref}) \rVert^2

    where :math:`\mathbf{W} = [\sqrt{w_1}, \dots, \sqrt{w_M}]` holds the square
    roots of the regularization weights, while
    :math:`\mathbf{m} = [m_1, \dots, m_M]` and :math:`\mathbf{m}^\text{ref}
    = [m_1^\text{ref}, \dots, m_M^\text{ref}]` are the model and the reference
    model vectors, respectively.
    """

    def __init__(
        self,
        n_params: int,
        weights: npt.NDArray | dict[str, npt.NDArray] | None = None,
        reference_model=None,
    ):
        self._n_params = n_params

        # Goes through the ``weights`` setter, which validates the type
        self.weights = (
            weights if weights is not None else np.ones(n_params, dtype=np.float64)
        )

        if reference_model is None:
            reference_model = np.zeros(n_params, dtype=np.float64)
        self.reference_model = reference_model

        self.set_name("0")

    def __call__(self, model: Model) -> float:
        """
        Evaluate the regularization on a given model.

        Parameters
        ----------
        model : (n_params) array
            Array with model values.
        """
        sqrt_weights = self.weights_matrix
        deviation = model - self.reference_model
        return deviation.T @ sqrt_weights.T @ sqrt_weights @ deviation

    def gradient(self, model: Model):
        """
        Gradient vector of the regularization.

        Parameters
        ----------
        model : (n_params) array
            Array with model values.
        """
        sqrt_weights = self.weights_matrix
        deviation = model - self.reference_model
        return 2 * sqrt_weights.T @ sqrt_weights @ deviation

    def hessian(self, model: Model):  # noqa: ARG002
        """
        Hessian matrix of the regularization.

        The ``model`` is not used: the Hessian depends only on the weights.

        Parameters
        ----------
        model : (n_params) array
            Array with model values.
        """
        sqrt_weights = self.weights_matrix
        return 2 * sqrt_weights.T @ sqrt_weights

    def hessian_diagonal(self, model: Model) -> npt.NDArray[np.float64]:
        """
        Diagonal of the Hessian, extracted from the ``hessian`` matrix.

        Parameters
        ----------
        model : (n_params) array
            Array with model values.
        """
        full_hessian = self.hessian(model)
        return full_hessian.diagonal()

    @property
    def n_params(self):
        """
        Number of model parameters.
        """
        return self._n_params

    @property
    def weights(self) -> npt.NDArray[np.float64] | dict[str, npt.NDArray[np.float64]]:
        """
        Regularization weights: a single array or a dictionary of arrays.
        """
        return self._weights

    @weights.setter
    def weights(
        self, value: npt.NDArray[np.float64] | dict[str, npt.NDArray[np.float64]]
    ):
        """
        Set the weights after checking they are an array or a dictionary.
        """
        if isinstance(value, (np.ndarray, dict)):
            self._weights = value
            return
        msg = (
            f"Invalid weights of type {type(value)}. "
            "It must be an array or a dictionary."
        )
        raise TypeError(msg)

    @property
    def weights_matrix(self) -> dia_array:
        """
        Diagonal matrix with the square root of the regularization weights.

        When the weights are a dictionary, its arrays are combined with
        ``prod_arrays`` before taking the square root.
        """
        current = self.weights
        if isinstance(current, dict):
            combined = prod_arrays(iter(current.values()))
        elif isinstance(current, np.ndarray):
            combined = current
        else:
            msg = f"Invalid weights of type '{type(self.weights)}'."
            raise TypeError(msg)
        return diags_array(np.sqrt(combined))
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # UV 101 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | #uv.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 
110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 119 | .pdm.toml 120 | .pdm-python 121 | .pdm-build/ 122 | 123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .venv 136 | env/ 137 | venv/ 138 | ENV/ 139 | env.bak/ 140 | venv.bak/ 141 | 142 | # Spyder project settings 143 | .spyderproject 144 | .spyproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | 149 | # mkdocs documentation 150 | /site 151 | 152 | # mypy 153 | .mypy_cache/ 154 | .dmypy.json 155 | dmypy.json 156 | 157 | # Pyre type checker 158 | .pyre/ 159 | 160 | # pytype static type analyzer 161 | .pytype/ 162 | 163 | # Cython debug symbols 164 | cython_debug/ 165 | 166 | # PyCharm 167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 169 | # and can be added to the global gitignore or merged into this file. For a more nuclear 170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 171 | #.idea/ 172 | 173 | # Abstra 174 | # Abstra is an AI-powered process automation framework. 175 | # Ignore directories containing user credentials, local state, and settings. 
176 | # Learn more at https://abstra.io/docs 177 | .abstra/ 178 | 179 | # Visual Studio Code 180 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 181 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 182 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 183 | # you could uncomment the following to ignore the enitre vscode folder 184 | # .vscode/ 185 | 186 | # Ruff stuff: 187 | .ruff_cache/ 188 | 189 | # PyPI configuration file 190 | .pypirc 191 | 192 | # Cursor 193 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 194 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data 195 | # refer to https://docs.cursor.com/context/ignore-files 196 | .cursorignore 197 | .cursorindexingignore 198 | -------------------------------------------------------------------------------- /src/inversion_ideas/base/conditions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base classes for defining conditions. 3 | 4 | Conditions are callable objects that return either a bool when either a certain 5 | condition is met or not. They are use to define abstract objects like stopping criteria 6 | for inversions. We can use binary operators (and, or, xor) to logically group multiple 7 | conditions together. 8 | """ 9 | 10 | from abc import ABC, abstractmethod 11 | 12 | from rich.panel import Panel 13 | from rich.tree import Tree 14 | 15 | from ..typing import Model 16 | 17 | 18 | def _get_info_title(condition, model) -> str: 19 | """ 20 | Generate title for condition's information. 
class Condition(ABC):
    """
    Abstract base class for conditions.

    A condition is a callable that takes a model and returns a bool. Conditions
    can be combined through the ``&``, ``|`` and ``^`` operators; the in-place
    versions of those operators are not supported.
    """

    @abstractmethod
    def __call__(self, model: Model) -> bool: ...

    def update(self, model: Model):  # noqa: B027
        """
        Update the condition.
        """
        # Deliberately not abstract: it is a no-op hook that children can override,
        # so every condition offers the same interface.

    def initialize(self):  # noqa: B027
        """
        Initialize the condition.
        """
        # Deliberately not abstract: it is a no-op hook that children can override,
        # so every condition offers the same interface.

    def info(self, model: Model) -> Tree:
        """
        Display information about the condition for a given model.
        """
        title = _get_info_title(self, model)
        return Tree(title)

    def __and__(self, other) -> "LogicalAnd":
        return LogicalAnd(self, other)

    def __or__(self, other) -> "LogicalOr":
        return LogicalOr(self, other)

    def __xor__(self, other) -> "LogicalXor":
        return LogicalXor(self, other)

    def __iand__(self, other):
        raise TypeError(
            "Inplace AND binary operation is not supported for conditions."
        )

    def __ior__(self, other):
        raise TypeError("Inplace OR binary operation is not supported for conditions.")

    def __ixor__(self, other):
        raise TypeError(
            "Inplace XOR binary operation is not supported for conditions."
        )


class _Mixin(ABC):
    """
    Base class for conditions that combine two other conditions.
    """

    def __init__(self, condition_a, condition_b):
        self.condition_a = condition_a
        self.condition_b = condition_b

    @abstractmethod
    def __call__(self, model: Model) -> bool: ...

    def update(self, model: Model):
        """
        Update the underlying conditions (those that have an ``update`` method).
        """
        for child in (self.condition_a, self.condition_b):
            if hasattr(child, "update"):
                child.update(model)

    def info(self, model: Model) -> Tree:
        """
        Build a tree with the status of this condition and of the underlying ones.
        """
        satisfied = self(model)
        color = "green" if satisfied else "red"
        checkbox = "x" if satisfied else " "
        title = rf"[bold {color}]\[{checkbox}] {type(self).__name__}[/bold {color}]"
        tree = Tree(title, guide_style=color)
        for child in (self.condition_a, self.condition_b):
            if not hasattr(child, "info"):
                raise NotImplementedError()
            subtree = child.info(model)
            if isinstance(child, _Mixin):
                # Nested combinations are added as they are
                tree.add(subtree)
                continue
            # Single conditions are framed in a panel colored after their status
            border = "green" if child(model) else "red"
            tree.add(Panel(subtree, border_style=border))
        return tree

    def initialize(self):
        """
        Initialize the underlying conditions (those with an ``initialize`` method).
        """
        for child in (self.condition_a, self.condition_b):
            if hasattr(child, "initialize"):
                child.initialize()


class LogicalAnd(_Mixin, Condition):
    """
    Condition that applies the AND operation to two other conditions.
    """

    def __call__(self, model: Model) -> bool:
        return self.condition_a(model) and self.condition_b(model)


class LogicalOr(_Mixin, Condition):
    """
    Condition that applies the OR operation to two other conditions.
    """

    def __call__(self, model: Model) -> bool:
        return self.condition_a(model) or self.condition_b(model)


class LogicalXor(_Mixin, Condition):
    """
    Condition that applies the XOR operation to two other conditions.
    """

    def __call__(self, model: Model) -> bool:
        return self.condition_a(model) ^ self.condition_b(model)
When a condition is passed, the Gauss-Newton iterations will finish if 39 | the condition is met, the Gauss-Newton converges (relative difference below 40 | ``rtol``), or if maximum number of iterations are reached. 41 | cg_kwargs : dict or None, optional 42 | Dictionary with extra arguments passed to the :func:`scipy.sparse.linalg.cg` 43 | function. 44 | """ 45 | 46 | def __init__( 47 | self, 48 | *, 49 | maxiter: int = 100, 50 | maxiter_line_search: int = 10, 51 | rtol=1e-5, 52 | stopping_criteria: Condition | Callable[[Model], bool] | None = None, 53 | cg_kwargs: dict[str, Any] | None = None, 54 | ): 55 | self.maxiter = maxiter 56 | self.maxiter_line_search = maxiter_line_search 57 | self.rtol = rtol 58 | self.stopping_criteria = stopping_criteria 59 | self.cg_kwargs = cg_kwargs if cg_kwargs is not None else {} 60 | 61 | def __call__( 62 | self, 63 | objective: Objective, 64 | initial_model: Model, 65 | preconditioner: Preconditioner 66 | | Callable[[Model], Preconditioner] 67 | | None = None, 68 | ) -> Generator[Model]: 69 | """ 70 | Create iterator over Gauss-Newton minimization. 71 | """ 72 | # Define a static preconditioner for all Gauss-Newton iterations 73 | cg_kwargs = self.cg_kwargs.copy() 74 | 75 | if preconditioner is not None: 76 | if "M" in self.cg_kwargs: 77 | msg = "Cannot simultanously pass `preconditioner` and `M`." 
78 | raise ValueError(msg) 79 | preconditioner = ( 80 | preconditioner 81 | if not callable(preconditioner) 82 | else preconditioner(initial_model) 83 | ) 84 | cg_kwargs["M"] = preconditioner 85 | 86 | # Perform Gauss-Newton iterations 87 | iteration = 0 88 | phi_prev_value = np.inf # value of the objective function on previous model 89 | model = initial_model.copy() 90 | 91 | # Yield initial model, so the generator is never empty 92 | yield model 93 | 94 | # Apply Gauss-Newton iterations 95 | while True: 96 | # Stop if reached max number of iterations 97 | if iteration >= self.maxiter: 98 | get_logger().info( 99 | "⚠️ Reached maximum number of Gauss-Newton iterations " 100 | f"({self.maxiter})." 101 | ) 102 | break 103 | 104 | # Check for convergence 105 | phi_value = objective(model) 106 | if ( 107 | not np.isinf(phi_prev_value) 108 | and np.abs(phi_value - phi_prev_value) <= phi_prev_value * self.rtol 109 | ): 110 | break 111 | 112 | # Check for stopping criteria 113 | if self.stopping_criteria is not None and self.stopping_criteria(model): 114 | break 115 | 116 | # Apply Conjugate Gradient to get search direction 117 | gradient, hessian = objective.gradient(model), objective.hessian(model) 118 | search_direction, info = cg(hessian, -gradient, **cg_kwargs) 119 | if info != 0: 120 | warnings.warn( 121 | "Conjugate gradient convergence to tolerance not achieved after " 122 | f"{info} number of iterations.", 123 | ConvergenceWarning, 124 | stacklevel=2, 125 | ) 126 | 127 | # Perform line search 128 | alpha, n_ls_iters = backtracking_line_search( 129 | objective, 130 | model, 131 | search_direction, 132 | phi_value=phi_value, 133 | phi_gradient=gradient, 134 | maxiter=self.maxiter_line_search, 135 | ) 136 | if alpha is None: 137 | msg = ( 138 | "Couldn't find a valid alpha, obtained None. " 139 | f"Ran {n_ls_iters} iterations." 
class DataMisfit(Objective):
    r"""
    L2 data misfit.

    Parameters
    ----------
    data : (n_data) array
        Array with observed data values.
    uncertainty : (n_data) array
        Array with data uncertainty.
    simulation : Simulation
        Instance of Simulation.
    cache : bool, optional
        Whether to cache the last result of the `__call__` method.
        Default to False.
    build_hessian : bool, optional
        If True, the ``hessian`` method will build the Hessian matrix and allocate it in
        memory. If False, the ``hessian`` method will return a linear operator that
        represents the Hessian matrix. Default to False.

        .. important::

            Hessian matrices are usually very large. Use ``build_hessian=True`` only if
            you need to build it.

    Notes
    -----
    The L2 data misfit objective function is defined as:

    .. math::

        \phi_d(\mathbf{m}) =
            \sum\limits_{i=1}^N
            \frac{\lvert d_i^\text{obs} - f_i(\mathbf{m}) \rvert^2}{\epsilon_i^2}

    where :math:`\mathbf{m}` is the model vector, :math:`d_i^\text{obs}` is the
    :math:`i`-th observed datum, :math:`f_i(\mathbf{m})` is the forward modelling
    function for the :math:`i`-th datum, and :math:`\epsilon_i` is the uncertainty of
    the :math:`i`-th datum.

    The data misfit term can be expressed in terms of weights :math:`w_i
    = 1 / \epsilon_i^2`:

    .. math::

        \phi_d(\mathbf{m}) =
            \sum\limits_{i=1}^N
            w_i \lvert d_i^\text{obs} - f_i(\mathbf{m}) \rvert^2

    And also in matrix form:

    .. math::

        \phi_d(\mathbf{m}) =
            \lVert
                \mathbf{W} \left[ \mathbf{d}^\text{obs} - f(\mathbf{m}) \right]
            \rVert^2

    where :math:`\mathbf{W}` is a diagonal matrix with the square root of the weights,
    :math:`\mathbf{d}^\text{obs}` is the vector of observed data, and
    :math:`f(\mathbf{m})` is the forward modelling vector.

    """

    def __init__(
        self,
        data: npt.NDArray[np.float64],
        uncertainty: npt.NDArray[np.float64],
        simulation,
        *,
        cache=False,
        build_hessian=False,
    ):
        # TODO: Check that the data and uncertainties have the size as ndata in the
        # simulation.
        self.data = data
        self.uncertainty = uncertainty
        self.simulation = simulation
        self.cache = cache
        self.build_hessian = build_hessian
        self.set_name("d")

    @cache_on_model
    def __call__(self, model: Model) -> float:
        """
        Evaluate the data misfit on a given model.

        Parameters
        ----------
        model : (n_params) array
            Array with model values.
        """
        # TODO:
        # Cache invalidation: we should clean the cache if data or uncertainties change.
        # Or they should be immutable.
        residual = self.residual(model)
        weights_matrix = self.weights_matrix
        return residual.T @ weights_matrix.T @ weights_matrix @ residual

    def gradient(self, model: Model) -> npt.NDArray[np.float64]:
        """
        Gradient vector.

        Parameters
        ----------
        model : (n_params) array
            Array with model values.
        """
        jac = self.simulation.jacobian(model)
        weights_matrix = self.weights_matrix
        return 2 * jac.T @ (weights_matrix.T @ weights_matrix @ self.residual(model))

    def hessian(
        self, model: Model
    ) -> npt.NDArray[np.float64] | sparray | LinearOperator:
        """
        Hessian matrix.

        Parameters
        ----------
        model : (n_params) array
            Array with model values.
        """
        jac = self.simulation.jacobian(model)
        if not self.build_hessian:
            jac = aslinearoperator(jac)
        weights_matrix = aslinearoperator(self.weights_matrix)
        return 2 * jac.T @ weights_matrix.T @ weights_matrix @ jac

    def hessian_diagonal(self, model: Model) -> npt.NDArray[np.float64]:
        r"""
        Diagonal of the Hessian.

        Parameters
        ----------
        model : (n_params) array
            Array with model values.

        Returns
        -------
        (n_params) array
            Diagonal of the Hessian matrix.

        Raises
        ------
        NotImplementedError
            If the simulation returns its jacobian as a ``LinearOperator``.

        Notes
        -----
        The :math:`j`-th element of the diagonal of the Hessian
        :math:`2 \mathbf{J}^T \mathbf{W}^T \mathbf{W} \mathbf{J}` is
        :math:`2 \sum_i w_i J_{ij}^2`, where :math:`w_i = 1 / \epsilon_i^2` are the
        data weights.
        """
        jac = self.simulation.jacobian(model)
        if isinstance(jac, LinearOperator):
            msg = (
                "`DataMisfit.hessian_diagonal()` is not implemented for simulations "
                "that return the jacobian as a LinearOperator."
            )
            raise NotImplementedError(msg)
        # Use the data weights (1 / uncertainty**2), i.e. the diagonal of W.T @ W,
        # and not the diagonal of W (1 / uncertainty), to agree with ``hessian``.
        jtj_diag = np.einsum("i,ij,ij->j", self.weights, jac, jac)
        return 2 * jtj_diag

    @property
    def n_params(self):
        """
        Number of model parameters.
        """
        return self.simulation.n_params

    @property
    def n_data(self):
        """
        Number of data values.
        """
        return self.data.size

    def residual(self, model: Model):
        r"""
        Residual vector.

        Parameters
        ----------
        model : (n_params) array
            Array with model values.

        Returns
        -------
        (n_data) array
            Array with residual vector.

        Notes
        -----
        Residual vector defined as:

        .. math::

            \mathbf{r} = \mathcal{F}(\mathbf{m}) - \mathbf{d}

        where :math:`\mathbf{d}` is the vector with observed data, :math:`\mathcal{F}`
        is the forward model, and :math:`\mathbf{m}` is the model vector.
        """
        return self.simulation(model) - self.data

    @property
    def weights(self) -> npt.NDArray[np.float64]:
        """
        Data weights: 1D array with the square of the inverse of the uncertainties.
        """
        return 1 / self.uncertainty**2

    @property
    def weights_matrix(self) -> dia_array:
        """
        Diagonal matrix with the square root of the data weights.

        Its diagonal holds the inverse of the uncertainties.
        """
        return diags_array(1 / self.uncertainty)

    def chi_factor(self, model: Model):
        """
        Compute chi factor.

        Parameters
        ----------
        model : (n_params) array
            Array with model values.

        Returns
        -------
        float
            Chi factor for the given model: the data misfit divided by the number
            of data values.
        """
        return self(model) / self.n_data
36 | 37 | Parametrize it to be either a function or a :class:`Condition`. 38 | """ 39 | if request.param == "function": 40 | 41 | def is_even(model) -> bool: 42 | return bool(np.all((model % 2) == 0)) 43 | 44 | return is_even 45 | return Even() 46 | 47 | def test_positive(self): 48 | is_positive = Positive() 49 | assert is_positive(1.0) 50 | assert is_positive(10.0) 51 | assert not is_positive(0.0) 52 | assert not is_positive(-2.0) 53 | 54 | def test_even(self, is_even): 55 | assert not is_even(1.0) 56 | assert is_even(2.0) 57 | assert not is_even(3.0) 58 | assert is_even(4.0) 59 | assert is_even(0.0) 60 | assert not is_even(-1.0) 61 | assert is_even(-2.0) 62 | assert not is_even(-3.0) 63 | assert is_even(-4.0) 64 | 65 | def test_and(self, is_even): 66 | is_even = Even() 67 | is_positive = Positive() 68 | condition = is_even & is_positive 69 | assert not condition(1.0) 70 | assert condition(2.0) 71 | assert not condition(3.0) 72 | assert condition(4.0) 73 | assert not condition(0.0) 74 | assert not condition(-1.0) 75 | assert not condition(-2.0) 76 | assert not condition(-3.0) 77 | assert not condition(-4.0) 78 | 79 | def test_or(self, is_even): 80 | is_even = Even() 81 | is_positive = Positive() 82 | condition = is_even | is_positive 83 | assert condition(1.0) 84 | assert condition(2.0) 85 | assert condition(3.0) 86 | assert condition(4.0) 87 | assert condition(0.0) 88 | assert not condition(-1.0) 89 | assert condition(-2.0) 90 | assert not condition(-3.0) 91 | assert condition(-4.0) 92 | 93 | def test_xor(self, is_even): 94 | is_even = Even() 95 | is_positive = Positive() 96 | condition = is_even ^ is_positive 97 | assert condition(1.0) 98 | assert not condition(2.0) 99 | assert condition(3.0) 100 | assert not condition(4.0) 101 | assert condition(0.0) 102 | assert not condition(-1.0) 103 | assert condition(-2.0) 104 | assert not condition(-3.0) 105 | assert condition(-4.0) 106 | 107 | 108 | class GreaterThan(Condition): 109 | def __init__(self, value): 110 | 
self.value = value 111 | 112 | def __call__(self, model) -> bool: 113 | return bool(np.all(model > self.value)) 114 | 115 | def update(self, model): 116 | self.value = model 117 | 118 | def initialize(self): 119 | self.value = None 120 | 121 | 122 | class UpdateMixin: 123 | """ 124 | Test updating conditions in mixins. 125 | """ 126 | 127 | def test_greater_than(self): 128 | condition = GreaterThan(2) 129 | assert condition(3) 130 | assert not condition(2) 131 | assert not condition(1) 132 | 133 | def test_update(self): 134 | condition = GreaterThan(2) 135 | new_value = 3 136 | condition.update(new_value) 137 | assert condition.value == new_value 138 | 139 | @pytest.mark.parametrize("operation", ["and", "or", "xor"]) 140 | def test_update_mixin(self, operation): 141 | condition_a = GreaterThan(2) 142 | condition_b = GreaterThan(3) 143 | match operation: 144 | case "and": 145 | condition = condition_a & condition_b 146 | case "or": 147 | condition = condition_a | condition_b 148 | case "xor": 149 | condition = condition_a ^ condition_b 150 | case _: 151 | msg = f"{operation}" 152 | raise ValueError(msg) 153 | new_value = 4 154 | condition.update(new_value) 155 | assert condition_a.value == new_value 156 | assert condition_b.value == new_value 157 | 158 | @pytest.mark.parametrize("operation", ["and", "or", "xor"]) 159 | def test_update_mixin_with_function(self, operation): 160 | """ 161 | Test if update works in case a condition is a function. 
162 | """ 163 | 164 | def is_even(model) -> bool: 165 | return bool(np.all((model % 2) == 0)) 166 | 167 | condition_a = GreaterThan(2) 168 | match operation: 169 | case "and": 170 | condition = condition_a & is_even 171 | case "or": 172 | condition = condition_a | is_even 173 | case "xor": 174 | condition = condition_a ^ is_even 175 | case _: 176 | msg = f"{operation}" 177 | raise ValueError(msg) 178 | new_value = 4 179 | condition.update(new_value) 180 | assert condition_a.value == new_value 181 | assert condition.condition_b is is_even 182 | 183 | 184 | class InitializeMixin: 185 | """ 186 | Test initializing conditions in mixins. 187 | """ 188 | 189 | def test_initialize(self): 190 | condition = GreaterThan(2) 191 | condition.initialize() 192 | assert condition.value is None 193 | 194 | @pytest.mark.parametrize("operation", ["and", "or", "xor"]) 195 | def test_initialize_mixin(self, operation): 196 | condition_a = GreaterThan(2) 197 | condition_b = GreaterThan(3) 198 | match operation: 199 | case "and": 200 | condition = condition_a & condition_b 201 | case "or": 202 | condition = condition_a | condition_b 203 | case "xor": 204 | condition = condition_a ^ condition_b 205 | case _: 206 | msg = f"{operation}" 207 | raise ValueError(msg) 208 | condition.initialize() 209 | assert condition_a.value is None 210 | assert condition_b.value is None 211 | 212 | @pytest.mark.parametrize("operation", ["and", "or", "xor"]) 213 | def test_initialize_mixin_with_function(self, operation): 214 | """ 215 | Test if initialize works in case a condition is a function. 
216 | """ 217 | 218 | def is_even(model) -> bool: 219 | return bool(np.all((model % 2) == 0)) 220 | 221 | condition_a = GreaterThan(2) 222 | match operation: 223 | case "and": 224 | condition = condition_a & is_even 225 | case "or": 226 | condition = condition_a | is_even 227 | case "xor": 228 | condition = condition_a ^ is_even 229 | case _: 230 | msg = f"{operation}" 231 | raise ValueError(msg) 232 | condition.initialize() 233 | assert condition_a.value is None 234 | assert condition.condition_b is is_even 235 | -------------------------------------------------------------------------------- /src/inversion_ideas/inversion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Handler to run an inversion. 3 | 4 | The :class:`Inversion` class is intended to simplify the process of running a full 5 | inversion, given an objective function, a minimizer, a set of directives that can 6 | modify the objective function after each iteration and optionally a logger. 7 | """ 8 | 9 | import typing 10 | from collections.abc import Callable 11 | 12 | from .base import Condition, Directive, Minimizer, Objective 13 | from .inversion_log import InversionLog, InversionLogRich 14 | from .typing import Model 15 | from .utils import get_logger 16 | 17 | 18 | class Inversion: 19 | """ 20 | Inversion runner. 21 | 22 | Parameters 23 | ---------- 24 | objective_function : Objective 25 | Objective function to minimize. 26 | initial_model : (n_params) array 27 | Starting model for the inversion. 28 | minimizer : Minimizer or callable 29 | Instance of :class:`Minimizer` or callable used to minimize the objective 30 | function during the inversion. It must take the objective function and a model 31 | as arguments. 32 | directives : list of Directive 33 | List of ``Directive``s used to modify the objective function after each 34 | iteration. 35 | stopping_criteria : Condition or callable 36 | Boolean function that takes the model as argument. 
If this function returns 37 | ``True``, then the inversion will stop. 38 | max_iterations : int, optional 39 | Max amount of iterations that will be performed. If ``None``, then there will be 40 | no limit on the total amount of iterations. 41 | cache_models : bool, optional 42 | Whether to cache each model after each iteration. 43 | log : InversionLog or bool, optional 44 | Instance of :class:`InversionLog` to store information about the inversion. 45 | If `True`, a default :class:`InversionLog` is going to be used. 46 | If `False`, no log will be assigned to the inversion, and :attr:`Inversion.log` 47 | will be ``None``. 48 | minimizer_kwargs : dict, optional 49 | Extra arguments that will be passed to the ``minimizer`` when called. 50 | """ 51 | 52 | def __init__( 53 | self, 54 | objective_function: Objective, 55 | initial_model: Model, 56 | minimizer: Minimizer | Callable[[Objective, Model], Model], 57 | *, 58 | directives: typing.Sequence[Directive], 59 | stopping_criteria: Condition | Callable[[Model], bool], 60 | max_iterations: int | None = None, 61 | cache_models=False, 62 | log: "InversionLog | bool" = True, 63 | minimizer_kwargs: dict | None = None, 64 | ): 65 | self.objective_function = objective_function 66 | self.initial_model = initial_model 67 | self.minimizer = minimizer 68 | self.directives = directives 69 | self.stopping_criteria = stopping_criteria 70 | self.max_iterations = max_iterations 71 | self.cache_models = cache_models 72 | if minimizer_kwargs is None: 73 | minimizer_kwargs = {} 74 | self.minimizer_kwargs = minimizer_kwargs 75 | 76 | # Assign log 77 | if log is False: 78 | self.log = None 79 | elif log is True: 80 | # TODO: this could fail if the objective function is not 81 | # phi_d + beta * phi_m. We should try-error here maybe... 
82 | self.log = InversionLogRich.create_from(self.objective_function) 83 | else: 84 | self.log = log 85 | 86 | # Assign model as a copy of the initial model 87 | self.model = initial_model.copy() 88 | 89 | def __next__(self): 90 | """ 91 | Run next iteration in the inversion. 92 | """ 93 | # Zeroth iteration 94 | if not hasattr(self, "_counter"): 95 | # Initialize counter to zero 96 | self._counter = 0 97 | 98 | # Add initial model to log (only on zeroth iteration) 99 | if self.log is not None: 100 | self.log.update(self.counter, self.model) 101 | 102 | # Initialize stopping criteria (if necessary) 103 | if hasattr(self.stopping_criteria, "initialize"): 104 | self.stopping_criteria.initialize() 105 | 106 | # Return the initial model in the zeroth iteration 107 | return self.model 108 | 109 | # Check for stopping criteria before trying to run the iteration 110 | if self.stopping_criteria(self.model): 111 | get_logger().info( 112 | "🎉 Inversion successfully finished due to stopping criteria." 113 | ) 114 | raise StopIteration 115 | 116 | # Check if maximum number of iterations have been reached 117 | if self.max_iterations is not None and self.counter >= self.max_iterations: 118 | get_logger().info( 119 | "⚠️ Inversion finished after reaching maximum number of iterations " 120 | f"({self.max_iterations})." 121 | ) 122 | raise StopIteration 123 | 124 | # Update stopping criteria (if necessary) 125 | if hasattr(self.stopping_criteria, "update"): 126 | self.stopping_criteria.update(self.model) 127 | 128 | # Increase counter by one 129 | self._counter += 1 130 | 131 | # Run directives (only after the first minimization). 132 | # We update the directives here (and not at the end of this method), so after 133 | # each iteration the objective function is still the same we passed to the 134 | # minimizer. 
135 | if self._counter > 1: 136 | for directive in self.directives: 137 | directive(self.model, self.counter) 138 | 139 | # Minimize objective function 140 | if isinstance(self.minimizer, Minimizer): 141 | # Keep only the last model of the minimizer iterator 142 | *_, model = self.minimizer( 143 | self.objective_function, self.model, **self.minimizer_kwargs 144 | ) 145 | else: 146 | model = self.minimizer( 147 | self.objective_function, self.model, **self.minimizer_kwargs 148 | ) 149 | 150 | # Cache model if required 151 | if self.cache_models: 152 | self.models.append(model) 153 | 154 | # Assign the model to self 155 | self.model = model 156 | 157 | # Update log 158 | if self.log is not None: 159 | self.log.update(self.counter, self.model) 160 | 161 | return self.model 162 | 163 | def __iter__(self): 164 | return self 165 | 166 | @property 167 | def counter(self) -> int: 168 | """ 169 | Iteration counter. 170 | """ 171 | return self._counter 172 | 173 | @property 174 | def models(self) -> list: 175 | """ 176 | Cached inverted models. 177 | 178 | The first model in the list is the initial model, the one that corresponds to 179 | the zeroth iteration. 180 | """ 181 | if not self.cache_models: 182 | msg = "Inversion doesn't have cached models since `cache_model` is `False`." 183 | raise AttributeError(msg) 184 | if not hasattr(self, "_models"): 185 | self._models = [self.initial_model] 186 | return self._models 187 | 188 | def run(self, show_log=True) -> Model: 189 | """ 190 | Run the inversion. 191 | 192 | Parameters 193 | ---------- 194 | show_log : bool, optional 195 | Whether to show the ``log`` (if it's defined) during the inversion. 
196 | """ 197 | if show_log and self.log is not None: 198 | if not hasattr(self.log, "live"): 199 | raise NotImplementedError() 200 | with self.log.live() as live: 201 | for _ in self: 202 | live.refresh() 203 | else: 204 | for _ in self: 205 | pass 206 | return self.model 207 | -------------------------------------------------------------------------------- /src/inversion_ideas/inversion_log.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classes for inversion logs. 3 | """ 4 | 5 | import numbers 6 | import typing 7 | from collections.abc import Callable, Iterable 8 | 9 | from rich.console import Console 10 | from rich.live import Live 11 | from rich.table import Table 12 | 13 | try: 14 | import pandas # noqa: ICN001 15 | except ImportError: 16 | pandas = None 17 | 18 | from .base import Combo 19 | from .typing import Model 20 | 21 | 22 | class Column(typing.NamedTuple): 23 | """ 24 | Column for the ``InversionLog``. 25 | """ 26 | 27 | title: str 28 | callable: Callable[[int, Model], typing.Any] 29 | fmt: str | None 30 | 31 | 32 | class InversionLog: 33 | """ 34 | Log the outputs of an inversion. 35 | 36 | Parameters 37 | ---------- 38 | columns : dict 39 | Dictionary with specification for the columns of the log table. 40 | The keys are the column titles as strings. The values can be callables that will 41 | be used to generate the value for each row and column, or ``Column``. Each 42 | callable should take two arguments: ``iteration`` (an integer with the number 43 | of the iteration) and ``model`` (the inverted model as a 1d array). 44 | """ 45 | 46 | def __init__( 47 | self, columns: typing.Mapping[str, Column | Callable[[int, Model], typing.Any]] 48 | ): 49 | for name, column in columns.items(): 50 | self.add_column(name, column) 51 | 52 | @property 53 | def has_records(self) -> bool: 54 | """ 55 | Whether the log has recorded values or not. 
56 | """ 57 | if not hasattr(self, "_log"): 58 | return False 59 | has_records = any(bool(c) for c in self.log.values()) 60 | return has_records 61 | 62 | def add_column( 63 | self, name: str, column: Column | Callable[[int, Model], typing.Any] 64 | ) -> typing.Self: 65 | """ 66 | Add column to the log. 67 | 68 | Parameters 69 | ---------- 70 | name : str 71 | Name of the column, used in the :attr:`InversionLog.log` dictionary to 72 | access the recorded values. 73 | column : Callable | Column 74 | A callable that takes the ``iteration`` and the ``model`` as arguments, or 75 | a ``Column``. 76 | 77 | Returns 78 | ------- 79 | self 80 | """ 81 | if self.has_records: 82 | msg = ( 83 | f"{type(self).__name__} has records. " 84 | "No column can be added after the log has already started " 85 | "recording values." 86 | ) 87 | raise TypeError(msg) 88 | 89 | if not hasattr(self, "_columns"): 90 | self._columns: dict[str, Column] = {} 91 | 92 | if callable(column): 93 | column = Column(title=name, callable=column, fmt=None) 94 | 95 | self._columns[name] = column 96 | return self 97 | 98 | @property 99 | def columns(self) -> dict[str, Column]: 100 | """ 101 | Column specifiers. 102 | """ 103 | return self._columns 104 | 105 | @property 106 | def log(self) -> dict[str, list]: 107 | """ 108 | Inversion log. 109 | """ 110 | if not hasattr(self, "_log"): 111 | self._log: dict[str, list] = {col: [] for col in self.columns} 112 | return self._log 113 | 114 | def update(self, iteration: int, model: Model): 115 | """ 116 | Update the log. 117 | """ 118 | for name, column in self.columns.items(): 119 | self.log[name].append(column.callable(iteration, model)) 120 | 121 | def to_pandas(self, index_col=0): 122 | """ 123 | Generate a ``pandas.DataFrame`` out of the log. 124 | """ 125 | if pandas is None: 126 | msg = "Pandas is missing." 
127 | raise ImportError(msg) 128 | index = list(self.log.keys())[index_col] 129 | return pandas.DataFrame(self.log).set_index(index) 130 | 131 | @classmethod 132 | def create_from(cls, objective_function: Combo) -> typing.Self: 133 | r""" 134 | Create the standard log for a classic inversion. 135 | 136 | Parameters 137 | ---------- 138 | objective_function : Combo 139 | Combo objective function with two elements: the data misfit and the 140 | regularization (including a trade-off parameter). 141 | 142 | Returns 143 | ------- 144 | Self 145 | 146 | Notes 147 | ----- 148 | The objective function should be of the type: 149 | 150 | .. math:: 151 | 152 | \phi(\mathbf{m}) = \phi_d(\mathbf{m}) + \beta \phi_m(\mathbf{m}) 153 | 154 | where :math:`\phi_d(m)` is the data misfit term, :math:`\phi_m(\mathbf{m})` is 155 | the model norm, and :math:`\beta` is the trade-off parameter. 156 | """ 157 | # TODO: write proper error messages 158 | assert len(objective_function) == 2 159 | data_misfit = objective_function[0] 160 | assert not hasattr(data_misfit, "multiplier") 161 | assert not isinstance(data_misfit, Iterable) 162 | regularization = objective_function[1] 163 | assert hasattr(regularization, "multiplier") 164 | 165 | columns = { 166 | "iter": Column( 167 | title="Iteration", callable=lambda iteration, _: iteration, fmt="d" 168 | ), 169 | "beta": Column( 170 | title="β", callable=lambda _, __: regularization.multiplier, fmt=".2e" 171 | ), 172 | "phi_d": Column( 173 | title="φ_d", callable=lambda _, model: data_misfit(model), fmt=".2e" 174 | ), 175 | "phi_m": Column( 176 | title="φ_m", 177 | callable=lambda _, model: regularization.function(model), 178 | fmt=".2e", 179 | ), 180 | "beta * phi_m": Column( 181 | title="β φ_m", 182 | callable=lambda _, model: regularization(model), 183 | fmt=".2e", 184 | ), 185 | "phi": Column( 186 | title="φ", 187 | callable=lambda _, model: objective_function(model), 188 | fmt=".2e", 189 | ), 190 | "chi": Column( 191 | title="χ", 192 | 
callable=lambda _, model: data_misfit(model) / data_misfit.n_data, 193 | fmt=".2e", 194 | ), 195 | } 196 | return cls(columns) 197 | 198 | 199 | class InversionLogRich(InversionLog): 200 | """ 201 | Log the outputs of an inversion. 202 | 203 | Parameters 204 | ---------- 205 | columns : dict 206 | Dictionary with specification for the columns of the log table. 207 | The keys are the column titles as strings. The values are callables that will be 208 | used to generate the value for each row and column. Each callable should take 209 | two arguments: ``iteration`` (an integer with the number of the iteration) and 210 | ``model`` (the inverted model as a 1d array). 211 | kwargs : 212 | Pass extra options to :class:`rich.table.Table`. 213 | """ 214 | 215 | def __init__(self, columns: dict[str, Callable | Column], **kwargs): 216 | super().__init__(columns) 217 | self.kwargs = kwargs 218 | 219 | @property 220 | def table(self) -> Table: 221 | """ 222 | Table for the inversion log. 223 | """ 224 | if not hasattr(self, "_table"): 225 | self._table = Table(**self.kwargs) 226 | for column in self.columns.values(): 227 | self._table.add_column(column.title) 228 | return self._table 229 | 230 | def show(self): 231 | """ 232 | Show table. 233 | """ 234 | console = Console() 235 | console.print(self.table) 236 | 237 | def live(self, **kwargs): 238 | """ 239 | Context manager for live update of the table. 240 | """ 241 | return Live(self.table, **kwargs) 242 | 243 | def update_table(self): 244 | """ 245 | Add row to the table given the latest inverted model. 
246 | 247 | Parameters 248 | ---------- 249 | model : (n_params) array 250 | """ 251 | row = [] 252 | for name, column in self.columns.items(): 253 | value = self.log[name][-1] # last element in the log 254 | fmt = column.fmt if column.fmt is not None else self._get_fmt(value) 255 | row.append(f"{value:{fmt}}") 256 | self.table.add_row(*row) 257 | 258 | def _get_fmt(self, value): 259 | if isinstance(value, bool): 260 | fmt = "" 261 | elif isinstance(value, numbers.Integral): 262 | fmt = "d" 263 | elif isinstance(value, numbers.Real): 264 | fmt = ".2e" 265 | else: 266 | fmt = "" 267 | return fmt 268 | 269 | def update(self, iteration: int, model: Model): 270 | """ 271 | Update the log. 272 | 273 | Update the table as well. 274 | """ 275 | super().update(iteration, model) 276 | self.update_table() 277 | -------------------------------------------------------------------------------- /src/inversion_ideas/conditions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions and callable classes that define conditions. 3 | 4 | Use these objects as stopping criteria for inversions. 5 | """ 6 | 7 | from collections.abc import Callable 8 | 9 | import numpy as np 10 | from rich.tree import Tree 11 | 12 | from inversion_ideas.data_misfit import DataMisfit 13 | 14 | from .base import Condition, Objective 15 | from .typing import Model 16 | 17 | 18 | class CustomCondition(Condition): 19 | """ 20 | Define a custom :class:`~inversion_ideas.base.Condition` object through a function. 21 | 22 | Parameters 23 | ---------- 24 | func : Callable 25 | Function to use as the condition. It should take a model as only argument 26 | and return a bool to evaluate whether the condition is valid or not for that 27 | particular model. 
28 | """ 29 | 30 | def __init__(self, func: Callable[[Model], bool]): 31 | self.func = func 32 | 33 | def __call__(self, model: Model) -> bool: 34 | return self.func(model) 35 | 36 | @classmethod 37 | def create(cls, func: Callable[[Model], bool]): 38 | """ 39 | Create a ``CustomCondition`` object directly from a function. 40 | 41 | Parameters 42 | ---------- 43 | func : Callable 44 | Function to use as the condition. It should take a model as only argument 45 | and return a bool to evaluate whether the condition is valid or not for that 46 | particular model. 47 | 48 | Returns 49 | ------- 50 | CustomCondition 51 | 52 | Examples 53 | -------- 54 | >>> import numpy as np 55 | >>> 56 | >>> def is_all_positive(model): 57 | ... return np.all(model > 0) 58 | >>> 59 | >>> condition = CustomCondition.create(is_all_positive) 60 | >>> model = np.array([1, 2, 3]) 61 | >>> print(condition(model)) 62 | True 63 | 64 | >>> model = np.array([-1, 2, 3]) 65 | >>> print(condition(model)) 66 | False 67 | """ 68 | return cls(func) 69 | 70 | 71 | class ChiTarget(Condition): 72 | """ 73 | Stopping criteria for when chi factor meets the target. 74 | 75 | Parameters 76 | ---------- 77 | data_misfit : DataMisfit 78 | Data misfit term to be evaluated. 79 | chi_target : float 80 | Target for the chi factor. 81 | """ 82 | 83 | def __init__(self, data_misfit: DataMisfit, chi_target=1.0): 84 | if not hasattr(data_misfit, "chi_factor"): 85 | msg = "Invalid `data_misfit`: missing `chi_factor` method." 86 | raise TypeError(msg) 87 | self.data_misfit = data_misfit 88 | self.chi_target = chi_target 89 | 90 | def __call__(self, model: Model) -> bool: 91 | """ 92 | Check if condition has been met. 
93 | """ 94 | chi = self.data_misfit.chi_factor(model) 95 | return float(chi) < self.chi_target 96 | 97 | def info(self, model: Model) -> Tree: 98 | tree = super().info(model) 99 | tree.add("Condition: chi < chi_target") 100 | tree.add(f"chi = {self.data_misfit.chi_factor(model):.2e}") 101 | tree.add(f"chi_target = {self.chi_target:.2e}") 102 | return tree 103 | 104 | 105 | class ModelChanged(Condition): 106 | r""" 107 | Stopping criteria for when model didn't changed above tolerance. 108 | 109 | Parameters 110 | ---------- 111 | rtol : float, optional 112 | Relative tolerance below which the model will be considered of not changing 113 | enough. 114 | atol : float, optional 115 | Absolute tolerance below which the model will be considered of not changing 116 | enough. 117 | 118 | Notes 119 | ----- 120 | The stopping criteria evaluates: 121 | 122 | .. math:: 123 | 124 | \frac{ 125 | \lVert \mathbf{m} - \mathbf{m}_\text{prev} \rVert_2 126 | }{ 127 | \lVert \mathbf{m}_\text{old} \rVert_2 128 | } 129 | \le \delta_r, 130 | 131 | and 132 | 133 | .. math:: 134 | 135 | \lVert \mathbf{m} - \mathbf{m}_\text{prev} \rVert_2 \le \delta_a, 136 | 137 | where :math:`\mathbf{m}` is the current model, :math:`\mathbf{m}_\text{prev}` is the 138 | previous model in the inversion, :math:`\lVert \cdot \rVert_2` represents an 139 | :math:`l_2` norm, and :math:`\delta_r` and :math:`\delta_a` are the relative and 140 | absolute tolerances whose values are given 141 | by ``rtol`` and ``atol``, respectively. 142 | 143 | When called, if any of those inequalities hold, the stopping criteria will return 144 | ``True``, and ``False`` otherwise. 
145 | """ 146 | 147 | def __init__(self, rtol: float = 1e-3, atol: float = 0.0): 148 | self.rtol = rtol 149 | self.atol = atol 150 | 151 | def __call__(self, model: Model) -> bool: 152 | if not hasattr(self, "previous"): 153 | return False 154 | diff = float(np.linalg.norm(model - self.previous)) 155 | previous = float(np.linalg.norm(self.previous)) 156 | return diff <= max(previous * self.rtol, self.atol) 157 | 158 | def update(self, model: Model): 159 | """ 160 | Cache model as the ``previous`` one. 161 | """ 162 | self.previous = model 163 | 164 | def info(self, model: Model) -> Tree: 165 | tree = super().info(model) 166 | diff = float(np.linalg.norm(model - self.previous)) 167 | previous = float(np.linalg.norm(self.previous)) 168 | tree.add("Condition: |m - m_prev| <= max(|m_prev| * rtol, atol)") 169 | tree.add(f"|m - m_prev| = {diff:.2e}") 170 | tree.add(f"|m_prev| = {previous:.2e}") 171 | tree.add(f"rtol = {self.rtol:.2e}") 172 | tree.add(f"atol = {self.atol:.2e}") 173 | return tree 174 | 175 | def initialize(self): 176 | """ 177 | Initialize condition and clean ``previous`` attribute. 178 | """ 179 | attr = "previous" 180 | if hasattr(self, attr): 181 | delattr(self, attr) 182 | 183 | 184 | class ObjectiveChanged(Condition): 185 | r""" 186 | Stopping criteria for when an objective function didn't changed above a tolerance. 187 | 188 | Parameters 189 | ---------- 190 | objective_function : Objective 191 | Objective function that will be evaluated. 192 | rtol : float, optional 193 | Relative tolerance below which the model will be considered of not changing 194 | enough. 195 | atol : float, optional 196 | Absolute tolerance below which the model will be considered of not changing 197 | enough. 198 | 199 | Notes 200 | ----- 201 | The stopping criteria evaluates: 202 | 203 | .. 
math:: 204 | 205 | \frac{ 206 | | \phi(\mathbf{m}) - \phi(\mathbf{m}_\text{old}) | 207 | }{ 208 | | \phi(\mathbf{m}_\text{old}) | 209 | } 210 | \le \delta_r, 211 | 212 | and 213 | 214 | .. math:: 215 | 216 | | \phi(\mathbf{m}) - \phi(\mathbf{m}_\text{old}) | \le \delta_a, 217 | 218 | where :math:`\phi`, is the objective function, :math:`\mathbf{m}` is the current 219 | model, :math:`\mathbf{m}_\text{old}` is the previous model in the inversion, 220 | and :math:`\delta_r` and :math:`\delta_a` are the relative and absolute tolerances 221 | whose values are given by ``rtol`` and ``atol``, respectively. 222 | 223 | When called, if any of those inequalities hold, the stopping criteria will return 224 | ``True``, and ``False`` otherwise. 225 | """ 226 | 227 | def __init__(self, objective_function: Objective, rtol: float = 1e-3, atol=0.0): 228 | self.objective_function = objective_function 229 | self.rtol = rtol 230 | self.atol = atol 231 | 232 | def __call__(self, model: Model) -> bool: 233 | if not hasattr(self, "previous"): 234 | return False 235 | diff = abs(self.objective_function(model) - self.previous) 236 | previous = abs(self.previous) 237 | return diff <= max(previous * self.rtol, self.atol) 238 | 239 | def update(self, model: Model): 240 | """ 241 | Cache value of objective function with model as the ``previous`` one. 242 | """ 243 | self.previous: float = float(self.objective_function(model)) 244 | 245 | def info(self, model: Model) -> Tree: 246 | tree = super().info(model) 247 | diff = abs(self.objective_function(model) - self.previous) 248 | previous = abs(self.previous) 249 | tree.add("Condition: |φ(m) - φ(m_prev)| <= max(|φ(m_prev)| * rtol, atol)") 250 | tree.add(f"|φ(m) - φ(m_prev)| = {diff:.2e}") 251 | tree.add(f"|φ(m_prev)| = {previous:.2e}") 252 | tree.add(f"rtol = {self.rtol:.2e}") 253 | tree.add(f"atol = {self.atol:.2e}") 254 | return tree 255 | 256 | def ratio(self, model: Model) -> float: 257 | """ 258 | Ratio ``|φ(m) - φ(m_prev)|/|φ(m_prev)|``. 
259 | """ 260 | if not hasattr(self, "previous"): 261 | return np.nan 262 | diff = abs(self.objective_function(model) - self.previous) 263 | previous = abs(self.previous) 264 | if previous == 0.0: 265 | return np.inf 266 | return diff / previous 267 | 268 | def initialize(self): 269 | """ 270 | Initialize condition and clean ``previous`` attribute. 271 | """ 272 | attr = "previous" 273 | if hasattr(self, attr): 274 | delattr(self, attr) 275 | -------------------------------------------------------------------------------- /notebooks/05_caching-data-misfit-values.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "0707778e-d975-46b6-b670-d461e53e5b10", 7 | "metadata": { 8 | "execution": { 9 | "iopub.execute_input": "2025-10-07T20:46:21.008082Z", 10 | "iopub.status.busy": "2025-10-07T20:46:21.007198Z", 11 | "iopub.status.idle": "2025-10-07T20:46:22.891408Z", 12 | "shell.execute_reply": "2025-10-07T20:46:22.890708Z", 13 | "shell.execute_reply.started": "2025-10-07T20:46:21.007999Z" 14 | } 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "import numpy as np\n", 19 | "from regressor import LinearRegressor\n", 20 | "\n", 21 | "import inversion_ideas as ii" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "6ca1c237-f423-4f85-b19c-2d8fc00ac6e1", 27 | "metadata": {}, 28 | "source": [ 29 | "## Create a true model and synthetic data for a linear regressor" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "179250ca-3da3-4a82-8573-7a285246147b", 36 | "metadata": { 37 | "execution": { 38 | "iopub.execute_input": "2025-10-07T20:46:22.892763Z", 39 | "iopub.status.busy": "2025-10-07T20:46:22.892423Z", 40 | "iopub.status.idle": "2025-10-07T20:46:22.899942Z", 41 | "shell.execute_reply": "2025-10-07T20:46:22.899275Z", 42 | "shell.execute_reply.started": "2025-10-07T20:46:22.892741Z" 43 | } 44 | }, 45 | 
"outputs": [ 46 | { 47 | "data": { 48 | "text/plain": [ 49 | "array([0.78225148, 0.67148671, 0.2373809 , 0.17946133, 0.34662367,\n", 50 | " 0.15210999, 0.31142952, 0.23900652, 0.54355731, 0.91770851])" 51 | ] 52 | }, 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "output_type": "execute_result" 56 | } 57 | ], 58 | "source": [ 59 | "n_params = 10\n", 60 | "rng = np.random.default_rng(seed=4242)\n", 61 | "true_model = rng.uniform(size=10)\n", 62 | "true_model" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "id": "a477ebd0-8ca7-4cd2-b17d-30c21dda08c6", 69 | "metadata": { 70 | "execution": { 71 | "iopub.execute_input": "2025-10-07T20:46:22.900936Z", 72 | "iopub.status.busy": "2025-10-07T20:46:22.900641Z", 73 | "iopub.status.idle": "2025-10-07T20:46:22.913564Z", 74 | "shell.execute_reply": "2025-10-07T20:46:22.912862Z", 75 | "shell.execute_reply.started": "2025-10-07T20:46:22.900914Z" 76 | } 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "# Build the X array\n", 81 | "n_data = 25\n", 82 | "shape = (n_data, n_params)\n", 83 | "X = rng.uniform(size=n_data * n_params).reshape(shape)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 4, 89 | "id": "7cc7a7bc-804b-48ea-bdb0-ebf4fa9a8e2e", 90 | "metadata": { 91 | "execution": { 92 | "iopub.execute_input": "2025-10-07T20:46:22.915218Z", 93 | "iopub.status.busy": "2025-10-07T20:46:22.914855Z", 94 | "iopub.status.idle": "2025-10-07T20:46:22.924316Z", 95 | "shell.execute_reply": "2025-10-07T20:46:22.923633Z", 96 | "shell.execute_reply.started": "2025-10-07T20:46:22.915183Z" 97 | } 98 | }, 99 | "outputs": [ 100 | { 101 | "data": { 102 | "text/plain": [ 103 | "array([2.83840696, 2.18091081, 2.00623242, 2.08333039, 2.01694883,\n", 104 | " 2.7826232 , 2.10564027, 1.27333506, 2.08859855, 1.94177648,\n", 105 | " 1.88492037, 2.92394733, 2.17231952, 3.08009275, 1.61670886,\n", 106 | " 1.77403753, 2.67305005, 1.91413882, 2.42117827, 2.13991628,\n", 107 | " 2.0153805 , 
2.71388471, 2.65944255, 2.44416121, 3.14217523])" 108 | ] 109 | }, 110 | "execution_count": 4, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | } 114 | ], 115 | "source": [ 116 | "synthetic_data = X @ true_model\n", 117 | "maxabs = np.max(np.abs(synthetic_data))\n", 118 | "noise = rng.normal(scale=1e-2 * maxabs, size=synthetic_data.size)\n", 119 | "synthetic_data += noise\n", 120 | "synthetic_data" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "b47ed2d3-0d71-4ead-8230-6df11d5833a9", 126 | "metadata": {}, 127 | "source": [ 128 | "## Define objective function\n", 129 | "\n", 130 | "Enable caching in the data misfit." 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 5, 136 | "id": "6ec35263-f754-4931-b577-b3d4f0fc803a", 137 | "metadata": { 138 | "execution": { 139 | "iopub.execute_input": "2025-10-07T20:46:24.335249Z", 140 | "iopub.status.busy": "2025-10-07T20:46:24.334360Z", 141 | "iopub.status.idle": "2025-10-07T20:46:24.345904Z", 142 | "shell.execute_reply": "2025-10-07T20:46:24.343849Z", 143 | "shell.execute_reply.started": "2025-10-07T20:46:24.335172Z" 144 | } 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "uncertainty = 1e-2 * maxabs * np.ones_like(synthetic_data)\n", 149 | "simulation = LinearRegressor(X)\n", 150 | "data_misfit = ii.DataMisfit(synthetic_data, uncertainty, simulation, cache=True)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 6, 156 | "id": "050b9c20-4b0a-49af-a900-1bb964588408", 157 | "metadata": { 158 | "execution": { 159 | "iopub.execute_input": "2025-10-07T20:46:25.260465Z", 160 | "iopub.status.busy": "2025-10-07T20:46:25.259017Z", 161 | "iopub.status.idle": "2025-10-07T20:46:25.266536Z", 162 | "shell.execute_reply": "2025-10-07T20:46:25.265094Z", 163 | "shell.execute_reply.started": "2025-10-07T20:46:25.260410Z" 164 | } 165 | }, 166 | "outputs": [], 167 | "source": [ 168 | "smallness = ii.TikhonovZero(n_params)" 169 | ] 170 | }, 171 | { 
172 | "cell_type": "code", 173 | "execution_count": 7, 174 | "id": "2cf726db-4849-46c3-9fca-41b5650b88b8", 175 | "metadata": { 176 | "execution": { 177 | "iopub.execute_input": "2025-10-07T20:46:25.693809Z", 178 | "iopub.status.busy": "2025-10-07T20:46:25.692963Z", 179 | "iopub.status.idle": "2025-10-07T20:46:25.705581Z", 180 | "shell.execute_reply": "2025-10-07T20:46:25.704708Z", 181 | "shell.execute_reply.started": "2025-10-07T20:46:25.693741Z" 182 | } 183 | }, 184 | "outputs": [ 185 | { 186 | "data": { 187 | "text/plain": [ 188 | "np.float64(18.545642205146727)" 189 | ] 190 | }, 191 | "execution_count": 7, 192 | "metadata": {}, 193 | "output_type": "execute_result" 194 | } 195 | ], 196 | "source": [ 197 | "data_misfit(true_model)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 8, 203 | "id": "8d201eea-294d-490f-8b76-12024acc0808", 204 | "metadata": { 205 | "execution": { 206 | "iopub.execute_input": "2025-10-07T20:46:26.020342Z", 207 | "iopub.status.busy": "2025-10-07T20:46:26.019645Z", 208 | "iopub.status.idle": "2025-10-07T20:46:26.027836Z", 209 | "shell.execute_reply": "2025-10-07T20:46:26.026962Z", 210 | "shell.execute_reply.started": "2025-10-07T20:46:26.020281Z" 211 | } 212 | }, 213 | "outputs": [ 214 | { 215 | "data": { 216 | "text/plain": [ 217 | "np.float64(18.545642205146727)" 218 | ] 219 | }, 220 | "execution_count": 8, 221 | "metadata": {}, 222 | "output_type": "execute_result" 223 | } 224 | ], 225 | "source": [ 226 | "data_misfit(true_model)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 9, 232 | "id": "da1db943-236d-43da-b0b0-e776d21931af", 233 | "metadata": { 234 | "execution": { 235 | "iopub.execute_input": "2025-10-07T20:46:26.500687Z", 236 | "iopub.status.busy": "2025-10-07T20:46:26.500211Z", 237 | "iopub.status.idle": "2025-10-07T20:46:26.508770Z", 238 | "shell.execute_reply": "2025-10-07T20:46:26.507740Z", 239 | "shell.execute_reply.started": "2025-10-07T20:46:26.500642Z" 240 | } 
241 | }, 242 | "outputs": [ 243 | { 244 | "data": { 245 | "text/latex": [ 246 | "$ \\phi_{d} (m) + 1.00 \\cdot 10^{-3} \\, \\phi_{0} (m) $" 247 | ], 248 | "text/plain": [ 249 | "φd(m) + 0.00 φ0(m)" 250 | ] 251 | }, 252 | "execution_count": 9, 253 | "metadata": {}, 254 | "output_type": "execute_result" 255 | } 256 | ], 257 | "source": [ 258 | "phi = data_misfit + 1e-3 * smallness\n", 259 | "phi" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 10, 265 | "id": "9032d9d4-9c23-461b-bbc8-f172afdfd9ac", 266 | "metadata": { 267 | "execution": { 268 | "iopub.execute_input": "2025-10-07T20:46:27.676505Z", 269 | "iopub.status.busy": "2025-10-07T20:46:27.676244Z", 270 | "iopub.status.idle": "2025-10-07T20:46:27.682336Z", 271 | "shell.execute_reply": "2025-10-07T20:46:27.681603Z", 272 | "shell.execute_reply.started": "2025-10-07T20:46:27.676483Z" 273 | } 274 | }, 275 | "outputs": [ 276 | { 277 | "data": { 278 | "text/plain": [ 279 | "np.float64(18.54822861431746)" 280 | ] 281 | }, 282 | "execution_count": 10, 283 | "metadata": {}, 284 | "output_type": "execute_result" 285 | } 286 | ], 287 | "source": [ 288 | "phi(true_model)" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 11, 294 | "id": "00c6e189-9deb-4cdf-8b78-887f0d8e44b9", 295 | "metadata": { 296 | "execution": { 297 | "iopub.execute_input": "2025-10-07T20:46:30.833424Z", 298 | "iopub.status.busy": "2025-10-07T20:46:30.832446Z", 299 | "iopub.status.idle": "2025-10-07T20:46:30.850636Z", 300 | "shell.execute_reply": "2025-10-07T20:46:30.849001Z", 301 | "shell.execute_reply.started": "2025-10-07T20:46:30.833348Z" 302 | } 303 | }, 304 | "outputs": [ 305 | { 306 | "data": { 307 | "text/plain": [ 308 | "np.float64(5671.840107322676)" 309 | ] 310 | }, 311 | "execution_count": 11, 312 | "metadata": {}, 313 | "output_type": "execute_result" 314 | } 315 | ], 316 | "source": [ 317 | "new_model = rng.uniform(size=n_params)\n", 318 | "\n", 319 | "phi(new_model)" 320 | ] 321 | } 
class TestSum:
    """
    Test custom sum for operators.

    Test cases:

    - All arrays, should return array.
    - One sparse array, should return array.
    - All sparse arrays, should return sparse array.
    - One linear operator, should return linear operator.
    - Put the special objects ones in the beginning and in the middle of the
      generator.
    """

    shape = (25, 10)

    @pytest.fixture
    def matrices(self):
        """Three dense random matrices, reproducible through fixed seeds."""
        seeds = (40, 41, 42)
        a, b, c = tuple(
            np.random.default_rng(seed=seed).uniform(size=self.shape) for seed in seeds
        )
        return a, b, c

    @pytest.fixture
    def sparse_arrays(self):
        """Three sparse arrays with random values on their main diagonal."""
        seeds = (40, 41, 42)
        a, b, c = tuple(
            diags_array(
                np.random.default_rng(seed=seed).uniform(size=self.shape[0]),
                shape=self.shape,
            )
            for seed in seeds
        )
        return a, b, c

    @pytest.fixture
    def vector(self):
        """Random vector that can be multiplied by the operators."""
        return np.random.default_rng(seed=43).uniform(size=self.shape[1])

    def test_all_arrays(self, matrices):
        # Get the sum
        result = _sum(op for op in matrices)

        # We should recover a dense array
        assert isinstance(result, np.ndarray)

        # Check if result is correct
        a, b, c = matrices
        expected = a + b + c
        np.testing.assert_allclose(result, expected)

    @pytest.mark.parametrize("index", [0, 1])
    def test_one_sparse_array(self, matrices, index):
        # Put a sparse array in the list of operators
        operators = list(matrices)
        operators[index] = diags_array(np.arange(self.shape[1]), shape=self.shape)

        # Get the sum
        result = _sum(op for op in operators)

        # We should recover a dense array
        assert isinstance(result, np.ndarray)

        # Check if result is correct
        a, b, c = operators
        expected = a + b + c
        np.testing.assert_allclose(result, expected)

    def test_all_sparse_arrays(self, sparse_arrays):
        result = _sum(op for op in sparse_arrays)

        # We should recover a sparse array
        assert isinstance(result, sparray)

        # Check if result is correct
        a, b, c = sparse_arrays
        expected = a + b + c
        np.testing.assert_allclose(result.toarray(), expected.toarray())

    @pytest.mark.parametrize("index", [0, 1])
    def test_one_linear_operator(self, matrices, vector, index):
        # Put a linear operator in the list of operators
        operators = list(matrices)
        factor = 5.1
        operators[index] = factor * aslinearoperator(operators[index])

        # Get the sum
        result = _sum(op for op in operators)

        # We should recover a linear operator
        assert isinstance(result, LinearOperator)

        # Check if result is correct
        a, b, c = matrices
        expected = factor * a + b + c if index == 0 else a + factor * b + c
        np.testing.assert_allclose(result @ vector, expected @ vector)


class Dummy(Objective):
    """
    Minimal objective function with constant value and derivatives.

    It evaluates to ``2.0`` for any model, its gradient and Hessian diagonal are
    vectors of ones, and its Hessian is the identity matrix.
    """

    def __init__(self, n_params):
        self._n_params = n_params

    @property
    def n_params(self):
        return self._n_params

    def __call__(self, model):  # noqa: ARG002
        return 2.0

    def gradient(self, model):  # noqa: ARG002
        return np.ones(self.n_params)

    def hessian(self, model):  # noqa: ARG002
        return np.eye(self.n_params)

    def hessian_diagonal(self, model):  # noqa: ARG002
        return np.ones(self.n_params)


class TestObjectiveOperations:
    """
    Test objective functions operations.

    Test cases:
    - Sum two objective functions, should obtain Combo.
    - Scalar times objective function, should get Scaled.
    - Sum two combos, should generate a Combo (without unpacking).
    - Sum combo and Scaled, should return another Combo (without unpacking).
    - Test iadd and imul:
        - Errors on Objective.
        - Error on imul for Combo.
        - Error on iadd for Scaled.
        - imul works ok for Scaled.
        - idiv works ok for Scaled.
        - iadd works ok for Combo.
    """

    n_params = 5

    def test_add(self):
        a, b = Dummy(self.n_params), Dummy(self.n_params)
        combo = a + b
        assert isinstance(combo, Combo)
        assert len(combo) == 2
        assert a in combo
        assert b in combo
        assert combo[0] is a
        assert combo[1] is b

    def test_add_n(self):
        """
        Test addition of multiple objective functions into nested Combos.

        Since Combos are not unpacked by default, adding together more than 2 objective
        functions create a nested structure of Combos.
        """
        a, b, c, d = tuple(Dummy(self.n_params) for _ in range(4))
        full_combo = a + b + c + d
        assert isinstance(full_combo, Combo)
        assert len(full_combo) == 2  # combo with (a + b + c) and d
        assert a not in full_combo
        assert b not in full_combo
        assert c not in full_combo
        assert full_combo[1] is d

        # First level
        combo = full_combo[0]  # combo with (a + b) and c
        assert isinstance(combo, Combo)
        assert len(combo) == 2
        assert a not in combo
        assert b not in combo
        assert combo[1] is c

        # Second level
        combo = full_combo[0][0]  # combo with a and b
        assert isinstance(combo, Combo)
        assert len(combo) == 2
        assert combo[0] is a
        assert combo[1] is b

    def test_mul(self):
        a = Dummy(self.n_params)
        scalar = 3.14
        scaled = scalar * a
        assert isinstance(scaled, Scaled)
        assert scaled.function is a
        assert scaled.multiplier == scalar

    def test_add_combos(self):
        a, b, c, d = tuple(Dummy(self.n_params) for _ in range(4))
        combo_a = a + b
        combo_b = c + d
        combo = combo_a + combo_b
        assert isinstance(combo, Combo)
        assert len(combo) == 2
        assert combo_a in combo
        assert combo_b in combo
        assert combo[0] is combo_a
        assert combo[1] is combo_b

    def test_add_scaled_and_combo(self):
        a, b, c = tuple(Dummy(self.n_params) for _ in range(3))
        combo = a + b
        scaled = 3.14 * c
        new_combo = combo + scaled
        assert isinstance(new_combo, Combo)
        assert len(new_combo) == 2
        assert combo in new_combo
        assert scaled in new_combo
        assert new_combo[0] is combo
        assert new_combo[1] is scaled

    def test_iadd_combo(self):
        a, b, c = tuple(Dummy(self.n_params) for _ in range(3))
        combo = a + b
        combo_bkp = combo
        combo += c
        assert isinstance(combo, Combo)
        assert combo is combo_bkp  # assert inplace operation
        assert len(combo) == 3
        assert combo[0] is a
        assert combo[1] is b
        assert combo[2] is c

    @pytest.mark.parametrize("function_type", ["objective", "scaled"])
    def test_iadd_error(self, function_type):
        phi, other = Dummy(self.n_params), Dummy(self.n_params)
        if function_type == "scaled":
            phi = 3.5 * phi
        with pytest.raises(TypeError):
            phi += other

    def test_imul_scaled(self):
        a = Dummy(self.n_params)
        scalar = 3.14
        scaled = scalar * a
        scaled_bkp = scaled
        new_scalar = 4.0
        scaled *= new_scalar
        assert isinstance(scaled, Scaled)
        assert scaled is scaled_bkp  # assert inplace operation
        assert scaled.function is a
        assert scaled.multiplier == scalar * new_scalar

    @pytest.mark.parametrize("function_type", ["objective", "combo"])
    def test_imul_error(self, function_type):
        phi = Dummy(self.n_params)
        if function_type == "combo":
            other = Dummy(self.n_params)
            phi = phi + other
        with pytest.raises(TypeError):
            phi *= 2.71

    def test_idiv_scaled(self):
        a = Dummy(self.n_params)
        scalar = 3.14
        scaled = scalar * a
        scaled_bkp = scaled
        new_scalar = 4.0
        scaled /= new_scalar
        assert isinstance(scaled, Scaled)
        assert scaled is scaled_bkp  # assert inplace operation
        assert scaled.function is a
        assert scaled.multiplier == scalar / new_scalar

    @pytest.mark.parametrize("function_type", ["objective", "combo"])
    def test_idiv_error(self, function_type):
        phi = Dummy(self.n_params)
        if function_type == "combo":
            other = Dummy(self.n_params)
            phi = phi + other
        with pytest.raises(TypeError):
            phi /= 2.71


def test_combo_flatten():
    """
    Test flattening of a Combo.
    """
    a, b, c, d, e = tuple(Dummy(3) for _ in range(5))
    f = 2.5 * c
    g = d + e

    # build combo: (((a + b) + 2.5 * c) + (d + e))
    combo = a + b + f + g
    assert len(combo) == 2

    # Flatten it into: a + b + 2.5 * c + d + e
    flat_combo = combo.flatten()

    # Check the result of the operation
    assert len(flat_combo) == 5
    assert flat_combo[0] is a
    assert flat_combo[1] is b
    assert flat_combo[2] is f
    assert flat_combo[3] is d
    assert flat_combo[4] is e


class TestComboMethods:
    """
    Test ``__call__``, ``gradient`` and ``hessian`` for a ``Combo``.

    The expected values follow from ``Dummy``: every term evaluates to ``2.0``, has
    a gradient and Hessian diagonal of ones, and an identity Hessian.
    """

    n_params = 4

    @pytest.fixture
    def combo(self):
        """Sum of two ``Dummy`` objective functions."""
        return Dummy(self.n_params) + Dummy(self.n_params)

    @pytest.fixture
    def model(self):
        """Model passed to the objective functions (ignored by ``Dummy``)."""
        return np.zeros(self.n_params)

    def test_n_params(self, combo):
        assert combo.n_params == self.n_params

    def test_call(self, combo, model):
        assert combo(model) == pytest.approx(4.0)

    def test_gradient(self, combo, model):
        expected = 2.0 * np.ones(self.n_params)
        np.testing.assert_allclose(combo.gradient(model), expected)

    def test_hessian(self, combo, model):
        expected = 2.0 * np.eye(self.n_params)
        np.testing.assert_allclose(combo.hessian(model), expected)

    def test_hessian_diagonal(self, combo, model):
        expected = 2.0 * np.ones(self.n_params)
        np.testing.assert_allclose(combo.hessian_diagonal(model), expected)


class TestScaledMethods:
    """
    Test ``__call__``, ``gradient`` and ``hessian`` for a ``Scaled``.

    The expected values are the ones of ``Dummy`` times the multiplier.
    """

    n_params = 4
    multiplier = 3.0

    @pytest.fixture
    def scaled(self):
        """``Dummy`` objective function scaled by ``multiplier``."""
        return self.multiplier * Dummy(self.n_params)

    @pytest.fixture
    def model(self):
        """Model passed to the objective functions (ignored by ``Dummy``)."""
        return np.zeros(self.n_params)

    def test_n_params(self, scaled):
        assert scaled.n_params == self.n_params

    def test_call(self, scaled, model):
        assert scaled(model) == pytest.approx(self.multiplier * 2.0)

    def test_gradient(self, scaled, model):
        expected = self.multiplier * np.ones(self.n_params)
        np.testing.assert_allclose(scaled.gradient(model), expected)

    def test_hessian(self, scaled, model):
        expected = self.multiplier * np.eye(self.n_params)
        np.testing.assert_allclose(scaled.hessian(model), expected)

    def test_hessian_diagonal(self, scaled, model):
        expected = self.multiplier * np.ones(self.n_params)
        np.testing.assert_allclose(scaled.hessian_diagonal(model), expected)
3 | """ 4 | 5 | from abc import ABC, abstractmethod 6 | from collections.abc import Iterable, Iterator 7 | from copy import copy 8 | from numbers import Real 9 | from typing import Self 10 | 11 | import numpy as np 12 | import numpy.typing as npt 13 | from scipy.sparse import sparray 14 | from scipy.sparse.linalg import LinearOperator, aslinearoperator 15 | 16 | from ..typing import Model 17 | 18 | 19 | class Objective(ABC): 20 | """ 21 | Abstract representation of an objective function. 22 | """ 23 | 24 | _base_str = "φ" 25 | _base_latex = r"\phi" 26 | name = None 27 | 28 | @abstractmethod 29 | def __init__(self): 30 | pass 31 | 32 | @property 33 | @abstractmethod 34 | def n_params(self) -> int: 35 | """ 36 | Number of model parameters. 37 | """ 38 | 39 | @abstractmethod 40 | def __call__(self, model: Model) -> float: 41 | """ 42 | Evaluate the objective function for a given model. 43 | """ 44 | 45 | @abstractmethod 46 | def gradient(self, model: Model) -> npt.NDArray[np.float64]: 47 | """ 48 | Evaluate the gradient of the objective function for a given model. 49 | """ 50 | 51 | @abstractmethod 52 | def hessian( 53 | self, model: Model 54 | ) -> npt.NDArray[np.float64] | sparray | LinearOperator: 55 | """ 56 | Evaluate the hessian of the objective function for a given model. 57 | """ 58 | 59 | @abstractmethod 60 | def hessian_diagonal(self, model: Model) -> npt.NDArray[np.float64]: 61 | """ 62 | Diagonal of the Hessian. 63 | """ 64 | 65 | def set_name(self, value): 66 | """ 67 | Set name for the objective function. 68 | """ 69 | if not (isinstance(value, str) or value is None): 70 | msg = ( 71 | f"Invalid name '{value}' of type {type(value)}. " 72 | "Please provide a string or None." 
73 | ) 74 | raise TypeError(msg) 75 | self.name = value 76 | # Return self so we can pipe this method 77 | return self 78 | 79 | def __repr__(self): 80 | repr_ = f"{self._base_str}" 81 | if self.name is not None: 82 | repr_ += f"{self.name}" 83 | return f"{repr_}(m)" 84 | 85 | def _repr_latex_(self): 86 | repr_ = f"{self._base_latex}" 87 | if self.name is not None: 88 | repr_ += rf"_{{{self.name}}}" 89 | return f"${repr_} (m)$" 90 | 91 | def __add__(self, other) -> "Combo": 92 | return Combo([self, other]) 93 | 94 | def __radd__(self, other) -> "Combo": 95 | return Combo([other, self]) 96 | 97 | def __mul__(self, value) -> "Scaled": 98 | return Scaled(value, self) 99 | 100 | def __rmul__(self, value): 101 | return self.__mul__(value) 102 | 103 | def __truediv__(self, denominator): 104 | return self * (1.0 / denominator) 105 | 106 | def __floordiv__(self, denominator): 107 | msg = "Floor division is not implemented for objective functions." 108 | raise TypeError(msg) 109 | 110 | def __iadd__(self, other) -> Self: 111 | msg = "Inplace addition is not implemented for this class." 112 | raise TypeError(msg) 113 | 114 | def __imul__(self, other) -> Self: 115 | msg = "Inplace multiplication is not implemented for this class." 116 | raise TypeError(msg) 117 | 118 | def __itruediv__(self, value) -> Self: 119 | msg = "Inplace division is not implemented for this class." 120 | raise TypeError(msg) 121 | 122 | 123 | class Scaled(Objective): 124 | """ 125 | Scaled objective function. 126 | """ 127 | 128 | def __init__(self, multiplier, function): 129 | self.multiplier = multiplier 130 | self.function = function 131 | 132 | @property 133 | def n_params(self) -> int: 134 | """ 135 | Number of model parameters. 136 | """ 137 | return self.function.n_params 138 | 139 | def __call__(self, model: Model): 140 | """ 141 | Evaluate the objective function. 
142 | """ 143 | return self.multiplier * self.function(model) 144 | 145 | def gradient(self, model: Model) -> npt.NDArray[np.float64]: 146 | """ 147 | Evaluate the gradient of the objective function for a given model. 148 | """ 149 | return self.multiplier * self.function.gradient(model) 150 | 151 | def hessian( 152 | self, model: Model 153 | ) -> npt.NDArray[np.float64] | sparray | LinearOperator: 154 | """ 155 | Evaluate the hessian of the objective function for a given model. 156 | """ 157 | return self.multiplier * self.function.hessian(model) 158 | 159 | def hessian_diagonal(self, model: Model) -> npt.NDArray[np.float64]: 160 | """ 161 | Diagonal of the Hessian. 162 | """ 163 | return self.multiplier * self.function.hessian_diagonal(model) 164 | 165 | def __repr__(self): 166 | fmt = ".2e" if np.abs(self.multiplier) > 1e3 else ".2f" 167 | phi_repr = f"{self.function}" 168 | # Add brackets in case that the function has a multiplier or is a Combo 169 | if isinstance(self.function, Iterable) or hasattr(self.function, "multiplier"): 170 | phi_repr = f"[{phi_repr}]" 171 | return f"{self.multiplier:{fmt}} {phi_repr}" 172 | 173 | def _repr_latex_(self): 174 | fmt = ( 175 | ".2e" 176 | if np.abs(self.multiplier) > 1e2 or np.abs(self.multiplier) < 1e-2 177 | else ".2f" 178 | ) 179 | multiplier_str = f"{self.multiplier:{fmt}}" 180 | if "e" in multiplier_str: 181 | base, exp = multiplier_str.split("e") 182 | exp = exp.replace("+", "") 183 | exp = str(int(exp)) 184 | multiplier_str = rf"{base} \cdot 10^{{{exp}}}" 185 | phi_str = self.function._repr_latex_().strip("$") 186 | # Add brackets in case that the function has a multiplier or is a Combo 187 | if isinstance(self.function, Iterable) or hasattr(self.function, "multiplier"): 188 | phi_str = f"[ {phi_str} ]" 189 | return rf"${multiplier_str} \, {phi_str}$" 190 | 191 | def __imul__(self, value: Real) -> Self: 192 | self.multiplier *= value 193 | return self 194 | 195 | def __itruediv__(self, value: Real) -> Self: 196 | 
self.multiplier /= value 197 | return self 198 | 199 | 200 | class Combo(Objective): 201 | """ 202 | Sum of objective functions. 203 | """ 204 | 205 | def __init__(self, functions: list[Objective]): 206 | _get_n_params(functions) # check if functions have the same n_params 207 | self._functions = functions 208 | 209 | def __iter__(self): 210 | return (f for f in self.functions) 211 | 212 | def __len__(self): 213 | return len(self._functions) 214 | 215 | def __getitem__(self, index): 216 | return self.functions[index] 217 | 218 | @property 219 | def functions(self) -> list[Objective]: 220 | """ 221 | List of objective functions in the sum. 222 | """ 223 | return self._functions 224 | 225 | @property 226 | def n_params(self) -> int: 227 | """ 228 | Number of model parameters. 229 | """ 230 | return _get_n_params(self.functions) 231 | 232 | def __call__(self, model: Model): 233 | """ 234 | Evaluate the objective function. 235 | """ 236 | return sum(f(model) for f in self.functions) 237 | 238 | def gradient(self, model: Model) -> npt.NDArray[np.float64]: 239 | """ 240 | Evaluate the gradient of the objective function for a given model. 241 | """ 242 | return sum(f.gradient(model) for f in self.functions) 243 | 244 | def hessian( 245 | self, model: Model 246 | ) -> npt.NDArray[np.float64] | sparray | LinearOperator: 247 | """ 248 | Evaluate the hessian of the objective function for a given model. 249 | """ 250 | return _sum(f.hessian(model) for f in self.functions) 251 | 252 | def hessian_diagonal(self, model: Model) -> npt.NDArray[np.float64]: 253 | """ 254 | Diagonal of the Hessian. 255 | """ 256 | if not self.functions: 257 | msg = "Invalid empty Combo when summing." 258 | raise ValueError(msg) 259 | return _sum_arrays(f.hessian_diagonal(model) for f in self.functions) 260 | 261 | def flatten(self) -> "Combo": 262 | """ 263 | Create a new flattened combo. 264 | 265 | Create a new ``Combo`` object by unpacking nested ``Combo``s in the current one. 
266 | """ 267 | return Combo(_unpack_combo(self.functions)) 268 | 269 | def contains(self, objective) -> bool: 270 | """ 271 | Check if the ``Combo`` contains the given objective function, recursively. 272 | """ 273 | return _contains(self, objective) 274 | 275 | def __repr__(self): 276 | functions = [] 277 | for function in self.functions: 278 | function_str = repr(function) 279 | if isinstance(function, Iterable): 280 | function_str = f"[ {function_str} ]" 281 | functions.append(function_str) 282 | return " + ".join(functions) 283 | 284 | def _repr_latex_(self): 285 | functions = [] 286 | for function in self.functions: 287 | function_str = function._repr_latex_().strip("$") 288 | if isinstance(function, Iterable): 289 | function_str = f"[ {function_str} ]" 290 | functions.append(function_str) 291 | phi_str = " + ".join(functions) 292 | return f"$ {phi_str} $" 293 | 294 | def __iadd__(self, other) -> Self: 295 | if other.n_params != self.n_params: 296 | msg = ( 297 | f"Trying to add objective function '{other}' with invalid " 298 | f"n_params ({other.n_params}) different from the one of " 299 | f"'{self}' ({self.n_params})." 300 | ) 301 | raise ValueError(msg) 302 | self._functions.append(other) 303 | return self 304 | 305 | 306 | def _unpack_combo(functions: Iterable) -> list: 307 | """ 308 | Unpack combo objective functions. 309 | """ 310 | unpacked = [] 311 | for f in functions: 312 | if isinstance(f, Iterable): 313 | unpacked.extend(_unpack_combo(f)) 314 | else: 315 | unpacked.append(f) 316 | return unpacked 317 | 318 | 319 | def _contains(combo: Combo, objective: Objective) -> bool: 320 | """ 321 | Check if combo contains a given objective function, recursively. 
322 | """ 323 | for f in combo.functions: 324 | if f is objective: 325 | return True 326 | if isinstance(f, Combo) and _contains(f, objective): 327 | return True 328 | if isinstance(f, Scaled): 329 | if f.function is objective: 330 | return True 331 | if isinstance(f.function, Combo) and _contains(f.function, objective): 332 | return True 333 | return False 334 | 335 | 336 | def _get_n_params(functions: list) -> int: 337 | """ 338 | Get number of parameters of a list of objective functions. 339 | 340 | Parameters 341 | ---------- 342 | functions : list of Objective 343 | List of objective functions. 344 | 345 | Returns 346 | ------- 347 | int 348 | Number of parameters of every objective function in the list. 349 | 350 | Raises 351 | ------ 352 | ValueError 353 | If any of the objective functions in the list have different number of 354 | parameters. 355 | """ 356 | n_params_list = [f.n_params for f in functions] 357 | n_params = n_params_list[0] 358 | if not all(p == n_params for p in n_params_list): 359 | msg = "Invalid objective functions with different n_params." 360 | raise ValueError(msg) 361 | return n_params 362 | 363 | 364 | def _sum( 365 | operators: Iterator[npt.NDArray | sparray | LinearOperator], 366 | ) -> npt.NDArray | sparray | LinearOperator: 367 | """ 368 | Sum objects within an iterator. 369 | 370 | This function supports summing together ``LinearOperators`` with Numpy arrays and 371 | sparse arrays. 372 | """ 373 | if not operators: 374 | msg = "Invalid empty 'operators' array when summing." 375 | raise ValueError(msg) 376 | 377 | result = copy(next(operators)) 378 | for operator in operators: 379 | if isinstance(operator, LinearOperator) or isinstance(result, LinearOperator): 380 | result = aslinearoperator(result) 381 | result += aslinearoperator(operator) 382 | else: 383 | result += operator 384 | return result 385 | 386 | 387 | def _sum_arrays(arrays: Iterator[npt.NDArray]) -> npt.NDArray: 388 | """ 389 | Sum arrays within an iterator. 
def create_l2_inversion(
    data_misfit: Objective,
    model_norm: Objective,
    *,
    starting_beta: float,
    initial_model: Model,
    minimizer: Minimizer | Callable[[Objective, Model], Model],
    beta_cooling_factor: float = 2.0,
    beta_cooling_rate: int = 1,
    chi_target: float = 1.0,
    max_iterations: int | None = None,
    cache_models: bool = True,
    preconditioner: (
        Preconditioner | Callable[[Model], Preconditioner] | str | None
    ) = None,
) -> Inversion:
    r"""
    Create inversion of the form :math:`\phi_d + \beta \phi_m`.

    Build an inversion with a beta cooling schedule and a stopping criteria for a chi
    factor target.

    Parameters
    ----------
    data_misfit : Objective
        Data misfit term :math:`\phi_d`.
    model_norm : Objective
        Model norm :math:`\phi_m`.
    starting_beta : float
        Starting value for the trade-off parameter :math:`\beta`.
    initial_model : (n_params) array
        Initial model to use in the inversion.
    minimizer : Minimizer or callable
        Instance of :class:`Minimizer`, or callable that takes an objective function
        and a model and returns a model, used to minimize the objective function
        during the inversion.
    beta_cooling_factor : float, optional
        Cooling factor for the trade-off parameter :math:`\beta`. Every
        ``beta_cooling_rate`` iterations, the :math:`\beta` will be *cooled down* by
        dividing it by the ``beta_cooling_factor``.
    beta_cooling_rate : int, optional
        Cooling rate for the trade-off parameter :math:`\beta`. The trade-off parameter
        will be cooled down every ``beta_cooling_rate`` iterations.
    chi_target : float, optional
        Target for the chi factor. The inversion will finish after the data misfit
        reaches a :math:`\chi` factor lower or equal to ``chi_target``.
    max_iterations : int, optional
        Max amount of iterations that will be performed. If ``None``, then there will be
        no limit on the total amount of iterations.
    cache_models : bool, optional
        Whether to cache models after each iteration in the inversion.
    preconditioner : {"jacobi"} or 2d array or sparray or LinearOperator or callable or None, optional
        Preconditioner that will be passed to the ``minimizer`` on every call during the
        inversion. The preconditioner can be a predefined 2d array, a sparse array or
        a LinearOperator. Alternatively, it can be a callable that takes the ``model``
        as argument and returns a preconditioner matrix (same types listed before). If
        ``"jacobi"``, a default Jacobi preconditioner that will get updated on every
        iteration will be defined for the inversion. If None, no preconditioner will be
        passed.

    Returns
    -------
    Inversion

    Raises
    ------
    ValueError
        If ``preconditioner`` is a string other than ``"jacobi"``.
    """
    # Define objective function. Keep a handle on the scaled regularization so the
    # cooler acts on the very same object that lives inside the objective function.
    regularization = starting_beta * model_norm
    objective_function = data_misfit + regularization

    # Define directives
    directives = [
        MultiplierCooler(
            regularization,
            cooling_factor=beta_cooling_factor,
            cooling_rate=beta_cooling_rate,
        ),
    ]

    # Stopping criteria
    stopping_criteria = ChiTarget(data_misfit, chi_target=chi_target)

    # Preconditioner: only forwarded to the minimizer when one was requested
    minimizer_kwargs = {}
    if preconditioner is not None:
        if isinstance(preconditioner, str):
            # The only supported string is the shortcut for a Jacobi preconditioner
            if preconditioner != "jacobi":
                msg = f"Invalid preconditioner '{preconditioner}'."
                raise ValueError(msg)
            preconditioner = JacobiPreconditioner(objective_function)
        minimizer_kwargs["preconditioner"] = preconditioner

    # Define inversion
    return Inversion(
        objective_function,
        initial_model,
        minimizer,
        directives=directives,
        stopping_criteria=stopping_criteria,
        cache_models=cache_models,
        max_iterations=max_iterations,
        log=True,
        minimizer_kwargs=minimizer_kwargs,
    )
143 | An IRLS algorithm will be applied, split in two stages. 144 | The inversion will stop when the following inequality holds: 145 | 146 | .. math:: 147 | 148 | \frac{|\phi_m^{(k)} - \phi_m^{(k-1)}|}{|\phi_m^{(k-1)}|} < \eta_{\phi_m} 149 | 150 | where :math:`\eta_{\phi_m}` is the ``model_norm_rtol``. 151 | 152 | Parameters 153 | ---------- 154 | data_misfit : Objective 155 | Data misfit term :math:`\phi_d`. 156 | model_norm : Objective 157 | Model norm :math:`\phi_m`. It can be a single objective function term or a combo 158 | containing multiple ones. At least one of them should be a sparse regularization 159 | term. 160 | starting_beta : float 161 | Starting value for the trade-off parameter :math:`\beta`. 162 | initial_model : (n_params) array 163 | Initial model to use in the inversion. 164 | minimizer : Minimizer 165 | Instance of :class:`Minimizer` used to minimize the objective function during 166 | the inversion. 167 | beta_cooling_factor : float, optional 168 | Cooling factor for the trade-off parameter :math:`\beta`. Every 169 | ``beta_cooling_rate`` iterations, the :math:`\beta` will be _cooled down_ by 170 | dividing it by the ``beta_cooling_factor``. 171 | data_misfit_rtol : float, optional 172 | Tolerance for the data misfit. This value is used to determine whether to cool 173 | down the IRLS threshold or beta. See eq. 21 in Fournier and Oldenburg (2019). 174 | chi_l2_target : float, optional 175 | Chi factor target for the stage one (the L2 inversion). Once this chi target is 176 | reached, the second stage starts. 177 | model_norm_rtol : float, optional 178 | Tolerance for the model norm. This value is used to determine if the inversion 179 | should stop. See eq. 22 in Fournier and Oldenburg (2019). 180 | max_iterations : int, optional 181 | Max amount of iterations that will be performed. If ``None``, then there will be 182 | no limit on the total amount of iterations. 
183 | cache_models : bool, optional 184 | Whether to cache models after each iteration in the inversion. 185 | preconditioner : {"jacobi"} or 2d array or sparray or LinearOperator or callable or None, optional 186 | Preconditioner that will be passed to the ``minimizer`` on every call during the 187 | inversion. The preconditioner can be a predefined 2d array, a sparse array or 188 | a LinearOperator. Alternatively, it can be a callable that takes the ``model`` 189 | as argument and returns a preconditioner matrix (same types listed before). If 190 | ``"jacobi"``, a default Jacobi preconditioner that will get updated on every 191 | iteration will be defined for the inversion. If None, no preconditioner will be 192 | passed. 193 | 194 | Returns 195 | ------- 196 | Inversion 197 | """ 198 | # Define objective function 199 | regularization = starting_beta * model_norm 200 | objective_function = data_misfit + regularization 201 | 202 | # Define IRLS directive 203 | directives = [ 204 | Irls( 205 | regularization, 206 | data_misfit=data_misfit, 207 | chi_l2_target=chi_l2_target, 208 | beta_cooling_factor=beta_cooling_factor, 209 | data_misfit_rtol=data_misfit_rtol, 210 | ) 211 | ] 212 | 213 | # Stopping criteria 214 | smallness_not_changing = ObjectiveChanged(model_norm, rtol=model_norm_rtol) 215 | 216 | # Preconditioner 217 | minimizer_kwargs = {} 218 | if preconditioner is not None: 219 | if isinstance(preconditioner, str): 220 | if preconditioner == "jacobi": 221 | preconditioner = JacobiPreconditioner(objective_function) 222 | else: 223 | msg = f"Invalid preconditioner '{preconditioner}'." 
def create_tikhonov_regularization(
    mesh,
    *,
    active_cells: npt.NDArray[np.bool] | None = None,
    cell_weights: npt.NDArray | dict[str, npt.NDArray] | None = None,
    reference_model: Model | None = None,
    alpha_s: float | None = None,
    alpha_x: float | None = None,
    alpha_y: float | None = None,
    alpha_z: float | None = None,
    reference_model_in_flatness: bool = False,
) -> Combo:
    """
    Create a linear combination of Tikhonov (L2) regularization terms.

    Define a :class:`inversion_ideas.base.Combo` with L2 smallness and flatness
    regularization terms.

    Parameters
    ----------
    mesh : discretize.base.BaseMesh
        Mesh to use in the regularization.
    active_cells : (n_params) array or None, optional
        Array full of bools that indicate the active cells in the mesh.
    cell_weights : (n_params) array or dict of (n_params) arrays or None, optional
        Array with cell weights.
        For multiple cell weights, pass a dictionary where keys are strings and values
        are the different weights arrays.
        If None, no cell weights are going to be used.
    reference_model : (n_params) array or None, optional
        Array with values for the reference model.
    alpha_s : float or None, optional
        Multiplier for the smallness term.
        If None, the smallness term is added without a multiplier.
    alpha_x, alpha_y, alpha_z : float or None, optional
        Multipliers for the flatness terms.
        If None, the corresponding flatness term is added without a multiplier.
    reference_model_in_flatness : bool, optional
        Whether to pass the ``reference_model`` to the flatness terms as well.
        If False, the reference model is only passed to the smallness term.

    Returns
    -------
    inversion_ideas.base.Combo
        Combo of L2 regularization terms.

    Raises
    ------
    ValueError
        If the mesh doesn't have 1, 2 or 3 dimensions.
    TypeError
        If a multiplier is passed for a direction that doesn't exist in the mesh
        (``alpha_z`` for 2D meshes, ``alpha_y`` or ``alpha_z`` for 1D meshes).

    Notes
    -----
    The regularization is built as the sum of one smallness term plus one flatness
    term per dimension of the mesh (``x``, then ``y``, then ``z``). The resulting
    combo is flattened before it is returned.
    """
    ndims = mesh.dim
    # Fail early (and with a meaningful message) on unsupported meshes, before
    # building any regularization term.
    if ndims not in (1, 2, 3):
        msg = (
            f"Invalid mesh with {ndims} dimensions. "
            "Only meshes with 1, 2 or 3 dimensions are supported."
        )
        raise ValueError(msg)
    if ndims == 2 and alpha_z is not None:
        msg = f"Cannot pass 'alpha_z' when mesh has {ndims} dimensions."
        raise TypeError(msg)
    if ndims == 1 and (alpha_y is not None or alpha_z is not None):
        msg = "Cannot pass 'alpha_y' nor 'alpha_z' when mesh has 1 dimension."
        raise TypeError(msg)

    smallness = Smallness(
        mesh,
        active_cells=active_cells,
        cell_weights=cell_weights,
        reference_model=reference_model,
    )
    if alpha_s is not None:
        smallness = alpha_s * smallness

    # Arguments shared by every flatness term
    kwargs = {
        "active_cells": active_cells,
        "cell_weights": cell_weights,
    }
    if reference_model_in_flatness:
        kwargs["reference_model"] = reference_model

    # Keep only the directions (and their multipliers) that exist in the mesh
    directions = ("x", "y", "z")[:ndims]
    alphas = (alpha_x, alpha_y, alpha_z)[:ndims]

    regularization = smallness
    for direction, alpha in zip(directions, alphas, strict=True):
        phi = Flatness(mesh, **kwargs, direction=direction)
        if alpha is not None:
            phi = alpha * phi
        regularization = regularization + phi

    return regularization.flatten()
class Irls(Directive):
    """
    Apply iterative reweighted least squares (IRLS).

    This directive is intended to work with a single inversion that performs the two
    stages:

    1. While IRLS is not active on every sparse regularization, the multiplier of
       ``regularization_with_beta`` is cooled down on each call, until the chi
       factor of the data misfit falls below ``chi_l2_target``. When that happens,
       IRLS gets activated on the sparse regularizations.
    2. Once IRLS is active, each call either cools down the multiplier (if the data
       misfit is quite different from the one obtained at the end of the first
       stage) or updates the IRLS on the sparse regularizations.

    .. note::

        This directive can only be applied to sparse (lp norm) regularizations. In
        summary they should:

        1. have a ``irls`` bool attribute,
        2. have a ``update_irls`` and a ``activate_irls`` methods.

    Parameters
    ----------
    *args : Objective
        Sparse regularizations that will get IRLS updated.
        It can be a single regularization object
        (e.g. :class:`inversion_ideas.regularization.SparseSmallness`), a
        :class:`inversion_ideas.base.Combo`, or a :class:`inversion_ideas.base.Scaled`,
        or multiple of them.
        :class:`inversion_ideas.base.Combo` and
        :class:`inversion_ideas.base.Scaled` regularizations will be explored
        recursively to locate the sparse regularization terms they contain.
    data_misfit : DataMisfit
        Data misfit function that will be evaluated to decide whether to update the
        IRLS on the sparse regularizations, or to cool the multiplier of
        ``regularization_with_beta``.
    regularization_with_beta : Scaled or None, optional
        Regularization that will get its multiplier cooled down.
        If a single ``arg`` is passed, it will be used as the regularization that will
        get its multiplier cooled down. Pass a ``regularization_with_beta`` if another
        regularization's multiplier should be cooled down, or if multiple ``args`` are
        passed.
    chi_l2_target : float, optional
        Target for the chi factor used in the first stage (L2 inversion). Once this
        target is reached, the IRLS will be activated.
    beta_cooling_factor : float, optional
        Cooling factor used to cool down the multiplier of
        ``regularization_with_beta``.
    data_misfit_rtol : float, optional
        Relative tolerance for the data misfit.
        Used to compare the current value of the data misfit with its value after the
        stage one is finished.
    cool_beta : bool, optional
        Whether to cool down beta during the IRLS process.
        If False, make sure you handle beta cooling in other way, like through other
        directive.

        .. warning::
            If False, the Irls directive won't cool down beta during the inversions.
            This might prevent from reaching convergence.
            Make sure you handle beta cooling in other way, like through other
            directive.

    Raises
    ------
    TypeError
        If no regularization is passed, if multiple ones are passed while
        ``regularization_with_beta`` is None, if the single regularization that
        should be used as ``regularization_with_beta`` is not a ``Scaled``, if no
        sparse regularization can be found in ``args``, or if ``data_misfit`` has no
        ``chi_factor`` attribute.
    """

    def __init__(
        self,
        *args: Objective,
        data_misfit: DataMisfit,
        regularization_with_beta: Scaled | None = None,
        chi_l2_target: float = 1.0,
        beta_cooling_factor: float = 2.0,
        data_misfit_rtol: float = 1e-1,
        cool_beta: bool = True,
    ):
        if not args:
            msg = (
                "Missing sparse regularization. "
                "Pass at least one to the IRLS directive."
            )
            raise TypeError(msg)

        if regularization_with_beta is None:
            # Raise error if multiple sparse regs and regularization_with_beta as None
            if len(args) > 1:
                msg = (
                    "Cannot pass multiple sparse regularizations and leave "
                    "'regularization_with_beta' as None. "
                )
                raise TypeError(msg)
            # Assign the sparse regularization passed through args as the regularization
            # with beta
            (_reg,) = args
            if not isinstance(_reg, Scaled):
                # Report the offending regularization: ``regularization_with_beta``
                # is always None in this branch.
                msg = (
                    f"Cannot use {_reg} as the "
                    "'regularization_with_beta' since it doesn't have a multiplier "
                    "that can be cooled down. "
                    "Pass a value to 'regularization_with_beta' or pass a scaled "
                    "regularization through the 'args'."
                )
                raise TypeError(msg)
            regularization_with_beta = _reg

        self.regularization_with_beta: Scaled = regularization_with_beta
        self.sparse_regs: list[SparseRegularization] = (
            self._extract_sparse_regularizations(args)
        )
        if not self.sparse_regs:
            msg = (
                "Invalid regularizations passed through the `args` argument. "
                "Couldn't locate any sparse regularization term in them."
            )
            raise TypeError(msg)

        if not hasattr(data_misfit, "chi_factor"):
            msg = "Invalid `data_misfit` object without `chi_factor` method."
            raise TypeError(msg)
        self.data_misfit = data_misfit

        self.data_misfit_rtol = data_misfit_rtol
        self.chi_l2_target = chi_l2_target

        # Define a beta cooler (None if beta cooling is handled somewhere else)
        self._beta_cooler: MultiplierCooler | None = (
            MultiplierCooler(
                self.regularization_with_beta, cooling_factor=beta_cooling_factor
            )
            if cool_beta
            else None
        )

        # Define a condition for the data misfit.
        # Compare it always with the data misfit obtained with the model from l2
        # inversion.
        self._dmisfit_below_threshold = ObjectiveChanged(
            data_misfit, rtol=self.data_misfit_rtol
        )

    @property
    def beta_cooling_factor(self) -> float | None:
        """
        Current beta cooling factor.

        It is ``None`` if the directive was created with ``cool_beta=False``.
        """
        if self._beta_cooler is None:
            return None
        return self._beta_cooler.cooling_factor

    def __call__(self, model: Model, iteration: int):
        """
        Apply IRLS.

        Cool down beta or update IRLS depending on the values of the data misfit.

        Parameters
        ----------
        model : (n_params) array
            Current model in the inversion.
        iteration : int
            Current iteration of the inversion.
        """
        # Stay in the first stage until IRLS gets activated on every sparse reg
        if not all(sparse_reg.irls for sparse_reg in self.sparse_regs):
            self._stage_one(model, iteration)
        else:
            self._stage_two(model, iteration)

    def _stage_one(self, model: Model, iteration: int):
        """
        Implement first stage of the IRLS inversion.

        Activate IRLS if the chi factor of the data misfit is below the
        ``chi_l2_target``, cool down beta otherwise.
        """
        if self.data_misfit.chi_factor(model) < self.chi_l2_target:
            # Activate IRLS if chi target has been met
            for sparse_reg in self.sparse_regs:
                sparse_reg.activate_irls(model)
            # Cache the l2 model and its data misfit: they are the reference for
            # the comparisons done in the second stage
            self._model_l2 = model
            self._dmisfit_l2 = self.data_misfit(self._model_l2)
            self._dmisfit_below_threshold.previous = self._dmisfit_l2
            return

        # Cool down beta otherwise
        if self._beta_cooler is not None:
            self._beta_cooler(model, iteration)

    def _stage_two(self, model: Model, iteration: int):
        """
        Implement second stage of the IRLS inversion.

        Cool down beta (adapting the cooling factor) or update the IRLS on the
        sparse regularizations, depending on the data misfit condition.
        """
        # NOTE(review): ``self._dmisfit_l2`` is only assigned in ``_stage_one``, so
        # this stage assumes the IRLS was activated through this directive.
        if not self._dmisfit_below_threshold(model):
            # Cool beta if the data misfit is quite different from the l2 one
            phi_d = self.data_misfit(model)
            # Adjust the cooling factor
            # (following current implementation of UpdateIRLS)
            if self._beta_cooler is not None:
                # A cooling factor of exactly one is left untouched
                if self._beta_cooler.cooling_factor != 1:
                    ratio = self._dmisfit_l2 / phi_d
                    if phi_d > self._dmisfit_l2:
                        self._beta_cooler.cooling_factor = float(
                            1 / np.mean([0.75, ratio])
                        )
                    else:
                        self._beta_cooler.cooling_factor = float(
                            1 / np.mean([2.0, ratio])
                        )
                self._beta_cooler(model, iteration)
        else:
            # Update the IRLS
            for sparse_reg in self.sparse_regs:
                sparse_reg.update_irls(model)

    def _extract_sparse_regularizations(
        self, args: tuple[Objective, ...]
    ) -> list[SparseRegularization]:
        """
        Select sparse regularizations recursively from the passed args.

        ``Scaled`` and ``Combo`` objects are explored with ``extract_from_combo``;
        any other objective is kept only if it looks like a sparse regularization
        (it has ``irls``, ``update_irls`` and ``activate_irls`` attributes).
        """

        def is_sparse(regularization: Objective) -> bool:
            return (
                hasattr(regularization, "irls")
                and hasattr(regularization, "update_irls")
                and hasattr(regularization, "activate_irls")
            )

        sparse_regs = []
        for objective in args:
            if isinstance(objective, Scaled | Combo):
                extracted_regs = extract_from_combo(objective, is_sparse)
                for reg in extracted_regs:
                    get_logger().debug(
                        f"Sparse regularization {reg} will get IRLS managed "
                        f"by the {self} directive."
                    )
                sparse_regs += extracted_regs
            elif is_sparse(objective):
                get_logger().debug(
                    f"Sparse regularization {objective} will get IRLS managed "
                    f"by the {self} directive."
                )
                sparse_regs.append(objective)

        return sparse_regs
note:: 288 | 289 | This directive can only be applied to regularizations that: 290 | 1. have a ``cell_weights`` attribute, 291 | 2. the ``cell_weights`` attribute is a dictionary, 292 | 3. the ``cell_weights`` attribute contains weights under the key specified 293 | through the ``weights_key`` argument ("sensitivity" by default). 294 | 295 | Parameters 296 | ---------- 297 | *args : Objective 298 | Regularizations to which the sensitivity weights will be updated. 299 | If a :class:`inversion_ideas.base.Combo` or 300 | a :class:`inversion_ideas.base.Scaled` are passed, they will be explored 301 | recursively to use regularizations that have sensitivity weights that can be 302 | updated. 303 | simulation : Simulation 304 | Simulation used to get the jacobian matrix that will be used while updating the 305 | sensitivity weights. 306 | weights_key : str, optional 307 | Key used to store the sensitivity weights on the regularization's 308 | ``cell_weights`` dictionary. Only the weights under this key will be updated. 309 | **kwargs 310 | Extra arguments passed to the 311 | :func:`inversion_ideas.utils.get_sensitivity_weights` function. 312 | 313 | See Also 314 | -------- 315 | inversion_ideas.utils.get_sensitivity_weights 316 | """ 317 | 318 | def __init__( 319 | self, 320 | *args: Objective, 321 | simulation: Simulation, 322 | weights_key: str = "sensitivity", 323 | **kwargs, 324 | ): 325 | if not args: 326 | msg = "Missing regularization. Pass at least one." 327 | raise TypeError(msg) 328 | 329 | self.weights_key = weights_key 330 | self.simulation = simulation 331 | self.kwargs = kwargs 332 | self.regularizations: list[Objective] = self._extract_regularizations(args) 333 | 334 | if not self.regularizations: 335 | msg = ( 336 | "Invalid regularizations passed through the `args` argument. " 337 | "Couldn't locate any regularization term to update " 338 | "their sensitivity weights." 
339 | ) 340 | raise TypeError(msg) 341 | 342 | def __call__(self, model: Model, iteration: int): # noqa: ARG002 343 | """ 344 | Update sensitivity weights. 345 | """ 346 | # Compute the jacobian and the new sensitivity weights 347 | jacobian = self.simulation.jacobian(model) 348 | self._check_jacobian_type(jacobian) 349 | new_sensitivity_weights = get_sensitivity_weights(jacobian, **self.kwargs) 350 | 351 | # Update sensitivity weights on regularizations 352 | for regularization in self.regularizations: 353 | self._check_cell_weights(regularization) 354 | regularization.cell_weights[self.weights_key] = new_sensitivity_weights 355 | 356 | def _extract_regularizations(self, args: tuple[Objective, ...]) -> list[Objective]: 357 | """ 358 | Select regularizations to update their sensitivity weights. 359 | 360 | Extract a selection of the regularizations passed as arguments to build the 361 | ``self.regularizations`` attribute. Follow this criteria: 362 | 363 | - Any objective function that is not a ``Combo`` or a ``Scaled`` will be added 364 | as is. We'll check if the regularization has sensitivity weights (see below). 365 | - Any ``Combo`` or ``Scaled`` will be recursively explored to extract any 366 | regularization function contained by them that has sensitivity weights. 367 | 368 | A regularization is considered to have sensitivity weights if: 369 | 370 | 1. Has a ``cell_weights`` attribute. 371 | 2. Its ``cell_weights`` attribute is a dictionary. 372 | 3. Its ``cell_weights`` attribute has a key equal to ``self.weights_key``. 
373 | """ 374 | 375 | def has_sensitivity_weights(regularization: Objective) -> bool: 376 | return ( 377 | hasattr(regularization, "cell_weights") 378 | and isinstance(regularization.cell_weights, dict) 379 | and self.weights_key in regularization.cell_weights 380 | ) 381 | 382 | regularizations = [] 383 | for objective in args: 384 | if isinstance(objective, Scaled | Combo): 385 | extracted_regs = extract_from_combo(objective, has_sensitivity_weights) 386 | for reg in extracted_regs: 387 | get_logger().debug( 388 | f"Sensitivity weights of {reg} will be updated " 389 | f"by the {self} directive." 390 | ) 391 | regularizations += extracted_regs 392 | else: 393 | self._check_cell_weights(objective) 394 | regularizations.append(objective) 395 | 396 | return regularizations 397 | 398 | def _check_jacobian_type(self, jacobian): 399 | """Check if jacobian is a dense array.""" 400 | if not isinstance(jacobian, np.ndarray): 401 | msg = ( 402 | "Cannot compute sensitivity weights for simulation " 403 | f"{self.simulation} : its jacobian is a {type(jacobian)}. " 404 | "It must be a dense array." 405 | ) 406 | raise TypeError(msg) 407 | 408 | def _check_cell_weights(self, regularization: Objective): 409 | """Sanity checks for cell_weights in regularization.""" 410 | # Check if regularization have cell_weights attribute 411 | if not hasattr(regularization, "cell_weights"): 412 | msg = ( 413 | "Missing `cell-weights` attribute in regularization " 414 | f"'{regularization}'." 415 | ) 416 | raise AttributeError(msg) 417 | 418 | if not isinstance(regularization.cell_weights, dict): 419 | msg = ( 420 | f"Invalid `cell_weights` attribute of type '{type(regularization)}' " 421 | f"for the '{regularization}'. It must be a dictionary." 422 | ) 423 | raise TypeError(msg) 424 | if self.weights_key not in regularization.cell_weights: 425 | msg = ( 426 | f"Missing '{self.weights_key}' weights in " 427 | f"{regularization}.cell_weights. 
" 428 | ) 429 | raise KeyError(msg) 430 | -------------------------------------------------------------------------------- /notebooks/08_conditions.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b3b20935", 6 | "metadata": {}, 7 | "source": [ 8 | "# Experiment with conditions" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "1404792a", 15 | "metadata": { 16 | "execution": { 17 | "iopub.execute_input": "2025-10-07T20:48:58.956945Z", 18 | "iopub.status.busy": "2025-10-07T20:48:58.956672Z", 19 | "iopub.status.idle": "2025-10-07T20:49:00.692455Z", 20 | "shell.execute_reply": "2025-10-07T20:49:00.691779Z", 21 | "shell.execute_reply.started": "2025-10-07T20:48:58.956918Z" 22 | } 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "from inversion_ideas import CustomCondition" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "id": "197da35b", 33 | "metadata": { 34 | "execution": { 35 | "iopub.execute_input": "2025-10-07T20:49:00.693326Z", 36 | "iopub.status.busy": "2025-10-07T20:49:00.693011Z", 37 | "iopub.status.idle": "2025-10-07T20:49:00.700075Z", 38 | "shell.execute_reply": "2025-10-07T20:49:00.699406Z", 39 | "shell.execute_reply.started": "2025-10-07T20:49:00.693304Z" 40 | } 41 | }, 42 | "outputs": [ 43 | { 44 | "data": { 45 | "text/plain": [ 46 | "" 47 | ] 48 | }, 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "output_type": "execute_result" 52 | } 53 | ], 54 | "source": [ 55 | "is_greater_than_one = CustomCondition.create(lambda model: model > 1)\n", 56 | "is_greater_than_one" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "id": "35b4f1a6", 63 | "metadata": { 64 | "execution": { 65 | "iopub.execute_input": "2025-10-07T20:49:00.701198Z", 66 | "iopub.status.busy": "2025-10-07T20:49:00.700894Z", 67 | "iopub.status.idle": "2025-10-07T20:49:00.725594Z", 68 | 
"shell.execute_reply": "2025-10-07T20:49:00.723464Z", 69 | "shell.execute_reply.started": "2025-10-07T20:49:00.701168Z" 70 | } 71 | }, 72 | "outputs": [ 73 | { 74 | "data": { 75 | "text/plain": [ 76 | "" 77 | ] 78 | }, 79 | "execution_count": 3, 80 | "metadata": {}, 81 | "output_type": "execute_result" 82 | } 83 | ], 84 | "source": [ 85 | "is_positive = CustomCondition.create(lambda model: model > 0)\n", 86 | "is_positive" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 4, 92 | "id": "e36397f9", 93 | "metadata": { 94 | "execution": { 95 | "iopub.execute_input": "2025-10-07T20:49:00.729820Z", 96 | "iopub.status.busy": "2025-10-07T20:49:00.727815Z", 97 | "iopub.status.idle": "2025-10-07T20:49:00.739591Z", 98 | "shell.execute_reply": "2025-10-07T20:49:00.738196Z", 99 | "shell.execute_reply.started": "2025-10-07T20:49:00.729691Z" 100 | } 101 | }, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "2.0 True\n", 108 | "0.5 False\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "x = 2.0\n", 114 | "print(x, is_greater_than_one(x))\n", 115 | "\n", 116 | "x = 0.5\n", 117 | "print(x, is_greater_than_one(x))" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 5, 123 | "id": "1b090f90", 124 | "metadata": { 125 | "execution": { 126 | "iopub.execute_input": "2025-10-07T20:49:00.741706Z", 127 | "iopub.status.busy": "2025-10-07T20:49:00.741148Z", 128 | "iopub.status.idle": "2025-10-07T20:49:00.755038Z", 129 | "shell.execute_reply": "2025-10-07T20:49:00.753391Z", 130 | "shell.execute_reply.started": "2025-10-07T20:49:00.741657Z" 131 | } 132 | }, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "2.0 True\n", 139 | "0.5 True\n" 140 | ] 141 | } 142 | ], 143 | "source": [ 144 | "x = 2.0\n", 145 | "print(x, is_positive(x))\n", 146 | "\n", 147 | "x = 0.5\n", 148 | "print(x, is_positive(x))" 149 | ] 150 | }, 151 | { 152 | 
"cell_type": "code", 153 | "execution_count": 6, 154 | "id": "78b5b6fe", 155 | "metadata": { 156 | "execution": { 157 | "iopub.execute_input": "2025-10-07T20:49:00.760585Z", 158 | "iopub.status.busy": "2025-10-07T20:49:00.759437Z", 159 | "iopub.status.idle": "2025-10-07T20:49:00.767266Z", 160 | "shell.execute_reply": "2025-10-07T20:49:00.766445Z", 161 | "shell.execute_reply.started": "2025-10-07T20:49:00.760520Z" 162 | } 163 | }, 164 | "outputs": [ 165 | { 166 | "data": { 167 | "text/plain": [ 168 | "" 169 | ] 170 | }, 171 | "execution_count": 6, 172 | "metadata": {}, 173 | "output_type": "execute_result" 174 | } 175 | ], 176 | "source": [ 177 | "logical_and = is_greater_than_one & is_positive\n", 178 | "logical_and" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 7, 184 | "id": "427c432f", 185 | "metadata": { 186 | "execution": { 187 | "iopub.execute_input": "2025-10-07T20:49:00.768965Z", 188 | "iopub.status.busy": "2025-10-07T20:49:00.768502Z", 189 | "iopub.status.idle": "2025-10-07T20:49:00.778820Z", 190 | "shell.execute_reply": "2025-10-07T20:49:00.777979Z", 191 | "shell.execute_reply.started": "2025-10-07T20:49:00.768915Z" 192 | } 193 | }, 194 | "outputs": [ 195 | { 196 | "name": "stdout", 197 | "output_type": "stream", 198 | "text": [ 199 | "2.0 True\n", 200 | "0.5 False\n", 201 | "-0.5 False\n" 202 | ] 203 | } 204 | ], 205 | "source": [ 206 | "x = 2.0\n", 207 | "print(x, logical_and(x))\n", 208 | "\n", 209 | "x = 0.5\n", 210 | "print(x, logical_and(x))\n", 211 | "\n", 212 | "x = -0.5\n", 213 | "print(x, logical_and(x))" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 8, 219 | "id": "b1945817", 220 | "metadata": { 221 | "execution": { 222 | "iopub.execute_input": "2025-10-07T20:49:00.779918Z", 223 | "iopub.status.busy": "2025-10-07T20:49:00.779608Z", 224 | "iopub.status.idle": "2025-10-07T20:49:00.787320Z", 225 | "shell.execute_reply": "2025-10-07T20:49:00.786469Z", 226 | "shell.execute_reply.started": 
"2025-10-07T20:49:00.779884Z" 227 | } 228 | }, 229 | "outputs": [ 230 | { 231 | "data": { 232 | "text/plain": [ 233 | "" 234 | ] 235 | }, 236 | "execution_count": 8, 237 | "metadata": {}, 238 | "output_type": "execute_result" 239 | } 240 | ], 241 | "source": [ 242 | "logical_or = is_greater_than_one | is_positive\n", 243 | "logical_or" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 9, 249 | "id": "4d82595d", 250 | "metadata": { 251 | "execution": { 252 | "iopub.execute_input": "2025-10-07T20:49:00.788658Z", 253 | "iopub.status.busy": "2025-10-07T20:49:00.788244Z", 254 | "iopub.status.idle": "2025-10-07T20:49:00.795156Z", 255 | "shell.execute_reply": "2025-10-07T20:49:00.794368Z", 256 | "shell.execute_reply.started": "2025-10-07T20:49:00.788620Z" 257 | } 258 | }, 259 | "outputs": [ 260 | { 261 | "name": "stdout", 262 | "output_type": "stream", 263 | "text": [ 264 | "2.0 True\n", 265 | "0.5 True\n", 266 | "-0.5 False\n" 267 | ] 268 | } 269 | ], 270 | "source": [ 271 | "x = 2.0\n", 272 | "print(x, logical_or(x))\n", 273 | "\n", 274 | "x = 0.5\n", 275 | "print(x, logical_or(x))\n", 276 | "\n", 277 | "x = -0.5\n", 278 | "print(x, logical_or(x))" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 10, 284 | "id": "57694109", 285 | "metadata": { 286 | "execution": { 287 | "iopub.execute_input": "2025-10-07T20:49:00.796359Z", 288 | "iopub.status.busy": "2025-10-07T20:49:00.796035Z", 289 | "iopub.status.idle": "2025-10-07T20:49:00.967667Z", 290 | "shell.execute_reply": "2025-10-07T20:49:00.966870Z", 291 | "shell.execute_reply.started": "2025-10-07T20:49:00.796322Z" 292 | } 293 | }, 294 | "outputs": [ 295 | { 296 | "data": { 297 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAjkAAAGdCAYAAADwjmIIAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAAIkpJREFUeJzt3XtU1HXi//HXIDB4gcldFVA5at7R1EIjLC/lipfsXloa5clsrTRdV2uxbcXaUis9ponlrmKpFe1am6espFLTxFJ/oKZpul7gJCSaMWjKRd+/P/o664QSyGWGd8/HOZ+zzWfen/fn/XGOznM/M6jDGGMEAABgmQBfLwAAAKA6EDkAAMBKRA4AALASkQMAAKxE5AAAACsROQAAwEpEDgAAsBKRAwAArBTo6wXUFmfPntXhw4cVGhoqh8Ph6+UAAPyYMUYFBQVq2rSpAgKq737C6dOnVVRUVOl5goODFRISUgUr8i9ETjkdPnxYUVFRvl4GAKAWyc7OVvPmzatl7tOnT6tViwbKPXKm0nNFRETowIED1oUOkVNOoaGhkqRD/6+lwhrwKR8A4OLcJ86qxVUHPe8d1aGoqEi5R87o0NaWCgu99Pcld8FZtYg5qKKiIiLnt+rcR1RhDQIUFlrHx6sBANQGNfH1hrBQ3pcuhlsSAADASkQOAACwEpEDAACsROQAAAArETkAAMBKRA4AALASkQMAAKxE5AAAACsROQAAwEpEDgAAsBKRAwAArETkAAAAKxE5AADASkQOAACwEpEDAACsROQAAAArETkAAMBKRA4AALASkQMAAKxE5AAAACsROQAAwEpEDgAAsBKRAwAArETkAAAAKxE5AADASkQOAACwEpEDAACsROQAAAArETkAAMBKRA4AALASkQMAAKxE5AAAACsROQAAwEpEDgAAsBKRAwAArETkAAAAKxE5AADASkQOAACwEpEDAACsROQAAAArETkAAMBKRA4AALASkQMAAKxE5AAAACsROQAAwEpEDgAAsBKRAwAArETkAAAAKxE5AADASkQOAACwEpEDAACsROQAAAArETkAAMBKRA4AALCSX0XO9OnT1aNHD4WGhqpJkya69dZbtWfPnl89bt26dYqJiVFISIguv/xyvfLKK6XGrFixQtHR0XI6nYqOjta7775bHZcAAAD8hF9Fzrp16/Too49q06ZNSktLU0lJieLj43Xy5MmLHnPgwAENHjxYvXr1UkZGhqZMmaLHHntMK1as8IxJT0/XsGHDlJCQoG3btikhIUFDhw7Vl19+WROXBQAAfMBhjDG+XsTF5OXlqUmTJlq3bp169+59wTFPPPGEVq5cqW+++cazb8yYMdq2bZvS09MlScOGDZPb7daHH37oGTNw4EA1bNhQb775ZrnW4na75XK5dPzbyxUWWqcSVwUAsJ274Iwattuv/Px8hYWFVc85quh9qSbW6it+dSfnl/Lz8yVJv/vd7y46Jj09XfHx8V77BgwYoC1btqi4uLjMMRs3brzovIWFhXK73V4bAACoPfw2cowxmjhxoq677jp17tz5ouNyc3MVHh7utS88PFwlJSU6evRomWNyc3MvOu/06dPlcrk8W1RUVCWuBgAA1DS/jZyxY8dq+/bt5fo4yeFweD0+9wnc+fsvNOaX+86XmJio/Px8z5adnV2R5QMAAB8L9PUCLmTcuHFauXKlPv/8czVv3rzMsREREaXuyBw5ckSBgYH6/e9/X+aYX97dOZ/T6ZTT6bzEKwAAAL7mV3dyjDEaO3as3nnnHX322Wdq1arVrx4TFxentLQ0r32rV69W9+7dFRQUVOaYnj17Vt3iAQCAX/GryHn00Ue1bNkyvfHGGwoNDVVubq5yc3N16tQpz5jExETdd999nsdjxozRoUOHNHHiRH3zzTdavHixFi1apEmTJnnGjB8/XqtXr9bMmTO1e/duzZw5U5988okmTJhQk5cHAABqkF9FzoIFC5Sfn6++ffsqMjLSs6W
mpnrG5OTkKCsry/O4VatWWrVqldauXatu3brpmWee0dy5c3XHHXd4xvTs2VNvvfWWUlJS1KVLFy1ZskSpqamKjY2t0esDAAA1x6//nhx/wt+TAwAoL/6eHP/gV3dyAAAAqgqRAwAArETkAAAAKxE5AADASkQOAACwEpEDAACsROQAAAArETkAAMBKRA4AALASkQMAAKxE5AAAACsROQAAwEpEDgAAsBKRAwAArETkAAAAKxE5AADASkQOAACwEpEDAACsROQAAAArETkAAMBKRA4AACgXh8NR5jZy5EhfL9FLoK8XAAAAaoecnBzPf6empupvf/ub9uzZ49lXt25dr/HFxcUKCgqqsfX9EndyAABAuURERHg2l8slh8PheXz69Glddtllevvtt9W3b1+FhIRo2bJlSkpKUrdu3bzmmTNnjlq2bOm1LyUlRR07dlRISIg6dOig5OTkSq+XOzkAAEBut9vrsdPplNPprPA8TzzxhGbNmqWUlBQ5nU4tXLjwV4/5xz/+oalTp+rll1/WlVdeqYyMDI0ePVr169fX/fffX+E1nEPkAAAARUVFeT2eOnWqkpKSKjzPhAkTdPvtt1fomGeeeUazZs3yHNeqVSvt2rVLr776KpEDAAAqJzs7W2FhYZ7Hl3IXR5K6d+9eofF5eXnKzs7WqFGjNHr0aM/+kpISuVyuS1rDOUQOAABQWFiYV+Rcqvr163s9DggIkDHGa19xcbHnv8+ePSvp54+sYmNjvcbVqVOnUmshcgAAQLVp3LixcnNzZYyRw+GQJGVmZnqeDw8PV7NmzbR//36NGDGiSs9N5AAAgGrTt29f5eXl6fnnn9edd96pjz76SB9++KHXXaOkpCQ99thjCgsL06BBg1RYWKgtW7bo+PHjmjhx4iWfmx8hBwAA1aZjx45KTk7W/Pnz1bVrV3311VeaNGmS15gHH3xQ//znP7VkyRJdccUV6tOnj5YsWaJWrVpV6twO88sPynBBbrdbLpdLx7+9XGGhlfuMEABgN3fBGTVst1/5+flV8j2XC56jit6XamKtvsKdHAAAYCUiBwAAWInIAQAAViJyAACAlYgcAABgJSIHAABYicgBAABWInIAAICViBwAAGAlIgcAAFiJyAEAAFYicgAAgJWIHAAAYCUiBwAAWInIAQAAViJyAACAlYgcAABgJSIHAABYicgBAABWInIAAICViBwAAGAlIgcAAFiJyAEAAFYicgAAgJUCfb2A2ua2dlco0BHk62X85n18eJuvl4DzDGja1ddLwP/h9wbwP9zJAQAAViJyAACAlYgcAABgJSIHAABYicgBAABWInIAAICViBwAAGAlIgcAAFiJyAEAAFYicgAAgJWIHAAAYCUiBwAAWInIAQAAViJyAACAlYgcAABgJSIHAABYicgBAABWInIAAICViBwAAGAlIgcAAFiJyAEAAFYicgAAgJWIHAAAYCUiBwAAWInIAQAAViJyAACAlYgcAABgJSIHAABYicgBAABWInIAAICViBwAAGAlIgcAAFiJyAEAAFYicgAAgJWIHAAAYCUiBwAAWInIAQAAViJyAACAlYgcAABgJSIHAABYicgBAABWInIAAICViBwAAGAlIgcAAFiJyAEAAFYicgAAgJWIHAAAYCUiBwAAWInIAQAAViJyAACAlYgcAABgJSIHAABYicgBAABWInIAAICViBwAAGAlIgcAAFjJLyMnOTlZrVq1UkhIiGJiYrR+/fqLjl27dq0cDkepbffu3V7jVqxYoejoaDmdTkVHR+vdd9+t7ssAAAA+5HeRk5qaqgkTJujJJ59URkaGevXqpUGDBikrK6vM4/bs2aOcnBzP1rZtW89z6enpGjZsmBISErRt2zYlJCRo6NCh+vLLL6v7cgAAgI9UOHI++eSTiz736quvVmoxkjR79myNGjVKDz74oDp27Kg5c+YoKipKCxYsKPO4Jk2aKCIiwrPVqVPH89ycOXPUv39/JSYmqkO
HDkpMTFS/fv00Z86cSq8XAAD4pwpHzo033qg///nPKioq8uzLy8vTTTfdpMTExEotpqioSFu3blV8fLzX/vj4eG3cuLHMY6+88kpFRkaqX79+WrNmjddz6enppeYcMGBAmXMWFhbK7XZ7bQAAoPYIrOgBn3/+uRISEvTJJ5/ojTfe0MGDB/XAAw8oOjpa27Ztq9Rijh49qjNnzig8PNxrf3h4uHJzcy94TGRkpBYuXKiYmBgVFhZq6dKl6tevn9auXavevXtLknJzcys0pyRNnz5d06ZNq9T1AABQ3W5rd4UCHUGXfHyJKZa0v+oW5EcqHDmxsbHKyMjQmDFjFBMTo7Nnz+rvf/+7Jk+eLIfDUSWL+uU8xpiLzt2+fXu1b9/e8zguLk7Z2dl68cUXPZFT0TklKTExURMnTvQ8drvdioqKqtB1AAAA37mkLx7v2bNHmzdvVvPmzRUYGKjdu3frp59+qvRiGjVqpDp16pS6w3LkyJFSd2LKcs0112jv3r2exxERERWe0+l0KiwszGsDAAC1R4UjZ8aMGYqLi1P//v319ddfa/PmzcrIyFCXLl2Unp5eqcUEBwcrJiZGaWlpXvvT0tLUs2fPcs+TkZGhyMhIz+O4uLhSc65evbpCcwIAgNqlwh9XvfTSS/rPf/6jQYMGSZI6deqkr776SlOmTFHfvn1VWFhYqQVNnDhRCQkJ6t69u+Li4rRw4UJlZWVpzJgxkn7+GOm7777T66+/Lunnn5xq2bKlOnXqpKKiIi1btkwrVqzQihUrPHOOHz9evXv31syZM3XLLbfovffe0yeffKINGzZUaq0AAMB/VThyduzYoUaNGnntCwoK0gsvvKAhQ4ZUekHDhg3TsWPH9PTTTysnJ0edO3fWqlWr1KJFC0lSTk6O19+ZU1RUpEmTJum7775T3bp11alTJ33wwQcaPHiwZ0zPnj311ltv6a9//aueeuoptW7dWqmpqYqNja30egEAgH9yGGOMrxdRG7jdbrlcLvXVLZX6FjuqxseHK/eTfKhaA5p29fUS8H/4veEf3AVn1LDdfuXn51fbdzqr6n2pxBRrrd6r1rX6it/9jccAAABVgcgBAABWInIAAICViBwAAGAlIgcAAFiJyAEAAFYicgAAgJWIHAAAYCUiBwAAWInIAQAAViJyAACAlYgcAABgJSIHAABYicgBAABWInIAAICViBwAAGAlIgcAAFiJyAEAAFYicgAAgJWIHAAAYCUiBwAAWInIAQAAViJyAACAlYgcAABgJSIHAABYicgBAABWInIAAICViBwAAGAlIgcAAFiJyAEAAFYicgAAgJWIHAAAYCUiBwAAWInIAQAAViJyAACAlYgcAABgJSIHAABYicgBAABWInIAAICViBwAAGAlIgcAAFiJyAEAAFYicgAAgJWIHAAAYCUiBwAAWInIAQAAViJyAACAlYgcAABgJSIHAABYicgBAABWInIAAICViBwAAGAlIgcAAFiJyAEAAFYicgAAgJUCfb0A4FIMaNrV10sA/BK/N/xDiSmWtN/Xy/jN404OAACwEpEDAACsROQAAAArETkAAMBKRA4AALASkQMAAKxE5AAAACsROQAAwEpEDgAAsBKRAwAArETkAAAAKxE5AADASkQOAACwEpEDAACsROQAAAArETkAAMBKRA4AALASkQMAAKxE5AAAACsROQAAwEpEDgAAsBKRAwAArETkAAAAKxE5AADASkQOAACwEpEDAACq3ZIlS3TZZZfV6DmJHAAAUG4jR46Uw+Eote3bt8/XSysl0NcLAAAAtcvAgQOVkpLita9x48Y+Ws3FcScHAABUiNPpVEREhNf20ksv6YorrlD9+vUVFRWlRx55RCdOnLjoHNu2bdP111+v0NBQhYWFKSYmRlu2bPE8v3HjRvXu3Vt169ZVVFSUHnvsMZ08ebJC6yRyAACA3G6311ZYWFih4wMCAjR37lx9/fXXeu211/TZZ5/p8ccfv+j4ESN
GqHnz5tq8ebO2bt2qv/zlLwoKCpIk7dixQwMGDNDtt9+u7du3KzU1VRs2bNDYsWMrtCY+rgIAAIqKivJ6PHXqVCUlJV1w7Pvvv68GDRp4Hg8aNEj/+te/PI9btWqlZ555Rg8//LCSk5MvOEdWVpYmT56sDh06SJLatm3ree6FF17Q8OHDNWHCBM9zc+fOVZ8+fbRgwQKFhISU65qIHAAAoOzsbIWFhXkeO53Oi469/vrrtWDBAs/j+vXra82aNXruuee0a9cuud1ulZSU6PTp0zp58qTq169fao6JEyfqwQcf1NKlS/WHP/xBd911l1q3bi1J2rp1q/bt26fly5d7xhtjdPbsWR04cEAdO3Ys1zXxcRUAAFBYWJjXVlbk1K9fX23atPFsRUVFGjx4sDp37qwVK1Zo69atmj9/viSpuLj4gnMkJSVp586duvHGG/XZZ58pOjpa7777riTp7Nmz+uMf/6jMzEzPtm3bNu3du9cTQuXBnRwAAFApW7ZsUUlJiWbNmqWAgJ/vn7z99tu/ely7du3Url07/elPf9I999yjlJQU3Xbbbbrqqqu0c+dOtWnTplLr4k4OAAColNatW6ukpETz5s3T/v37tXTpUr3yyisXHX/q1CmNHTtWa9eu1aFDh/TFF19o8+bNno+hnnjiCaWnp+vRRx9VZmam9u7dq5UrV2rcuHEVWheRAwAAKqVbt26aPXu2Zs6cqc6dO2v58uWaPn36RcfXqVNHx44d03333ad27dpp6NChGjRokKZNmyZJ6tKli9atW6e9e/eqV69euvLKK/XUU08pMjKyQutyGGNMpa7sN8LtdsvlcqmvblGgI8jXywEA+LESU6y1ek/5+fleX+atSlX1vlQTa/UV7uQAAAArETkAAMBKRA4AALASkQMAAKxE5AAAACsROQAAwEpEDgAAsBKRAwAArETkAAAAKxE5AADASkQOAACwEpEDAACsROQAAAArETkAAMBKRA4AALASkQMAAKxE5AAAACsROQAAwEpEDgAAsBKRAwAArORXkWOMUVJSkpo2baq6deuqb9++2rlzZ5nHLFmyRA6Ho9R2+vRpr3HJyclq1aqVQkJCFBMTo/Xr11fnpQAAAB/zq8h5/vnnNXv2bL388svavHmzIiIi1L9/fxUUFJR5XFhYmHJycry2kJAQz/OpqamaMGGCnnzySWVkZKhXr14aNGiQsrKyqvuSAACAj/hN5BhjNGfOHD355JO6/fbb1blzZ7322mv66aef9MYbb5R5rMPhUEREhNd2vtmzZ2vUqFF68MEH1bFjR82ZM0dRUVFasGBBdV4SAADwIb+JnAMHDig3N1fx8fGefU6nU3369NHGjRvLPPbEiRNq0aKFmjdvriFDhigjI8PzXFFRkbZu3eo1ryTFx8eXOW9hYaHcbrfXBgAAag+/iZzc3FxJUnh4uNf+8PBwz3MX0qFDBy1ZskQrV67Um2++qZCQEF177bXau3evJOno0aM6c+ZMheedPn26XC6XZ4uKirrUSwMAAD7gs8hZvny5GjRo4NmKi4sl/fzR0/mMMaX2ne+aa67Rvffeq65du6pXr156++231a5dO82bN89rXEXnTUxMVH5+vmfLzs6u6CUCAAAfCvTViW+++WbFxsZ6HhcWFkr6+Y5OZGSkZ/+RI0dK3YUpS0BAgHr06OG5k9OoUSPVqVOn1F2bX5vX6XTK6XSW+7wAAMC/+OxOTmhoqNq0aePZoqOjFRERobS0NM+YoqIirVu3Tj179iz3vMYYZWZmekIpODhYMTExXvNKUlpaWoXmBQAAtYvP7uT8ksPh0IQJE/Tcc8+pbdu2atu2rZ577jnVq1dPw4cP94y777771KxZM02fPl2SNG3aNF1zzTVq27at3G635s6dq8zMTM2fP99zzMSJE5WQkKDu3bsrLi5OCxcuVFZWlsaMGVPj1wkAAGqG30SOJD3++OM6deqUHnnkER0/flyxsbFavXq1QkNDPWOysrI
UEPC/G1A//vijHnroIeXm5srlcunKK6/U559/rquvvtozZtiwYTp27Jiefvpp5eTkqHPnzlq1apVatGhRo9cHAABqjsMYY3y9iNrA7XbL5XKpr25RoCPI18sBAPixElOstXpP+fn5CgsLq5ZzVNX7Uk2s1Vf85kfIAQAAqhKRAwAArETkAAAAKxE5AADASkQOAACwEpEDAACsROQAAAArETkAAMBKRA4AALASkQMAAKxE5AAAACsROQAAwEpEDgAAsBKRAwAArETkAAAAKxE5AADASkQOAACwEpEDAACsROQAAAArETkAAMBKRA4AALASkQMAAKxE5AAAACsROQAAwEpEDgAAsBKRAwAArETkAAAAKxE5AADASkQOAACwEpEDAACsROQAAAArETkAAMBKRA4AALASkQMAAKxE5AAAACsROQAAwEpEDgAAsBKRAwAArETkAAAAKxE5AADASkQOAACwEpEDAACsROQAAAArETkAAMBKRA4AALASkQMAAKxE5AAAACsROQAAwEpEDgAAsBKRAwAArETkAAAAKxE5AADASkQOAACwUqCvF1BbGGMkSSUqloyPFwMA8GslKpb0v/eOaj9XJU5zbq02InLKqaCgQJK0Qat8vBIAQG1RUFAgl8tVLXMHBwcrIiJCG3Ir/74UERGh4ODgKliVf3GYmshMC5w9e1aHDx9WaGioHA6Hr5dzydxut6KiopSdna2wsDBfL+c3jdfCf/Ba+A9bXgtjjAoKCtS0aVMFBFTfN0NOnz6toqKiSs8THByskJCQKliRf+FOTjkFBASoefPmvl5GlQkLC6vVf4DYhNfCf/Ba+A8bXovquoNzvpCQECvjpKrwxWMAAGAlIgcAAFiJyPmNcTqdmjp1qpxOp6+X8pvHa+E/eC38B68FqhJfPAYAAFbiTg4AALASkQMAAKxE5AAAACsROQAAwEpEDrw8++yz6tmzp+rVq6fLLrus1PPbt2/XbbfdpiZNmsjlcunOO+/U0aNHa36hFjt+/LgSEhLkcrnkcrmUkJCgH3/88aLjT58+rSlTpqht27aqV6+eunbtqrS0tJpbsJ/auHGj6tSpo4EDB3rtP3jwoBwOh5o0aeL551rO6datm5KSkjyP+/btK4fDIYfDIafTqWbNmummm27SO++8UxOXYLXs7GyNGjVKTZs2VXBwsFq0aKHx48fr2LFjnjHn//oHBwerdevWSkxMVGFhoQ9XjtqEyLFMVlZWpY4vKirSXXfdpYcffviCz69fv17XXnut1qxZo9WrV2vHjh2aPHlypc5pm8q+BsOHD1dmZqY++ugjffTRR8rMzFRCQsJFxx85ckRZWVlatGiRduzYobi4ON122206efJkpdZR2y1evFjjxo3Thg0bLviaFBQU6MUXX/zVeUaPHq2cnBzt27dPK1asUHR0tO6++2499NBD1bHs34T9+/ere/fu+vbbb/Xmm29q3759euWVV/Tpp58qLi5OP/zwg2fs+b/+zz//vObPn+8VokCZDPzSkSNHTHh4uHn22Wc9+zZt2mSCgoLMxx9/fNHjWrZsaWJjY01ycrL54YcfLvn8KSkpxuVy/eq4cePGmX79+l3yefyZL16DXbt2GUlm06ZNnn3p6elGktm9e3e55ti6dauRZLKzsyt0bpucOHHChIaGmt27d5thw4aZadOmeZ47cOCAkWQmT55sGjRoYL7//nvPc127djVTp071PO7Tp48ZP358qfkXL15sJJm0tLTqvAxrDRw40DRv3tz89NNPXvtzcnJMvXr1zJgxY4wxF/71v/32281VV11VU0tFLcedHD/VuHFjLV68WElJSdqyZYtOnDihe++9V4888oji4+Mvetznn3+um2++WXPnzlVkZKSGDh2qDz74QCUlJVW+xm3btun111/XAw88UOVz+wNfvAbp6elyuVyKjY317Lvmmmvkcrm0cePGXz2+sLBQiYmJ6t+/v1X/1lpFpaamqn379mrfvr3uvfd
epaSkyPzirwS755571KZNGz399NMVnv/+++9Xw4YN+djqEvzwww/6+OOP9cgjj6hu3bpez0VERGjEiBFKTU0t9XpJP/+Z88UXXygoKKimlotajsjxY4MHD9bo0aM1YsQIjRkzRiEhIZoxY0aZx0RFRWnKlCn65ptvtH79eoWHh2vkyJGKiorSpEmT9PXXX1fJ2rZv367rr79ef/3rXzV8+PAqmdMf1fRrkJubqyZNmpTa36RJE+Xm5pZ53pKSEt188806ceKE/v3vf5fvAi21aNEi3XvvvZKkgQMH6sSJE/r000+9xjgcDs2YMUMLFy7Uf//73wrNHxAQoHbt2ungwYNVteTfjL1798oYo44dO17w+Y4dO+r48ePKy8uTJCUnJ6tBgwZyOp3q1q2b8vLy+Igc5Ubk+LkXX3xRJSUlevvtt7V8+fIK/WuzPXr00Lx58/Tdd99p+PDhmj17tucP/spKSkrSgAEDNGnSpCqZz5/V9GvgcDhK7TPGXHD/+VatWqUNGzbo/fffr/X/enNl7NmzR1999ZXuvvtuSVJgYKCGDRumxYsXlxo7YMAAXXfddXrqqacqfJ7yvCaouHN3cM792o4YMUKZmZlKT0/X0KFD9cADD+iOO+7w5RJRiwT6egEo2/79+3X48GGdPXtWhw4dUpcuXcp97J49e7R06VItW7ZM+fn5Gj16tEaNGlUl6zpw4IBuueWWKpnL39XkaxAREaHvv/++1P68vDyFh4eXea4DBw6ocePGatiwYbnXZ6NFixappKREzZo18+wzxigoKEjHjx8vNX7GjBmKi4ur0N2BM2fOaO/everRo0eVrPm3pE2bNnI4HNq1a5duvfXWUs/v3r1bDRs2VKNGjSRJLpdLbdq0kSQtW7ZMnTp10qJFi6rszzJYzoffB8KvKCwsNF27djX333+/mT59umncuLHJzc0t85i8vDwzb948c/XVV5s6deqYgQMHmjfffNOcOnWqQuf+tS8e79y503z33XcVmrM2qunX4NwXj7/88kvPvk2bNpXri8fff/+92b59e/kuzFLFxcUmPDzczJo1y+zYscNra9eunZk3b57ni8cZGRme4+68804THx9f7i8eL1q0yEgyn332WfVflIXi4+NNs2bNLumLxykpKSYiIsKcPHmyppaLWozI8WOTJk0yLVu2NPn5+ebMmTOmd+/e5sYbbyzzmMsvv9xER0ebmTNnmsOHD1f4nIcOHTIZGRlm2rRppkGDBiYjI8NkZGSYgoICr3E33HCDmTdvXoXnr2188RoMHDjQdOnSxaSnp5v09HRzxRVXmCFDhvzqcfPmzTM33HBDhc9nk3fffdcEBwebH3/8sdRzU6ZMMd26dbtg5OzZs8cEBgaakJCQUpEzevRok5OTY7Kzs82mTZvM448/boKCgszDDz9cA1dkp2+//dY0atTI9OrVy6xbt85kZWWZDz/80HTu3Nm0bdvWHDt2zBhz4cgpLCw0kZGR5oUXXvDBylHbEDl+as2aNSYwMNCsX7/es+/QoUPG5XKZ5OTkix73zTffVOq8999/v5FUaluzZo3XuBYtWni9GdjIV6/BsWPHzIgRI0xoaKgJDQ01I0aMMMePH//V46ZOnWpatGhRqXPXdkOGDDGDBw++4HPnfrT+3P+eHznGGPPQQw8ZSaUi59zvgeDgYBMZGWmGDBli3nnnnWq8it+GgwcPmpEjR5qIiAgTFBRkoqKizLhx48zRo0c9Yy52J+3ZZ581jRs3LvV/voBfchhzgZ/TAwAAqOX46SoAAGAlIgcAAFiJyAEAAFYicgAAgJWIHAAAYCUiBwAAWInIAQAAViJyAACAlYgcAABgJSIHAABYicgBAABWInIAAICV/j9jXjjVmOTnNwAAAABJRU5ErkJggg==", 298 | "text/plain": [ 299 | "
" 300 | ] 301 | }, 302 | "metadata": {}, 303 | "output_type": "display_data" 304 | } 305 | ], 306 | "source": [ 307 | "import matplotlib.pyplot as plt\n", 308 | "\n", 309 | "x = [-0.5, 0.5, 2.0]\n", 310 | "\n", 311 | "\n", 312 | "x_labels = [\"x > 1?\", \"x > 0 ?\", \"AND\", \"OR\"]\n", 313 | "y_labels = [f\"{xi:.2f}\" for xi in x]\n", 314 | "\n", 315 | "grid = []\n", 316 | "for xi in x:\n", 317 | " row = [\n", 318 | " is_greater_than_one(xi),\n", 319 | " is_positive(xi),\n", 320 | " logical_and(xi),\n", 321 | " logical_or(xi),\n", 322 | " ]\n", 323 | " grid.append(row)\n", 324 | "\n", 325 | "\n", 326 | "tmp = plt.pcolormesh(x_labels, y_labels, grid, cmap=plt.get_cmap(\"viridis\", 2))\n", 327 | "formatter = plt.FuncFormatter(lambda _val, loc: str(bool(loc)))\n", 328 | "plt.colorbar(ticks=[0.25, 0.75], format=formatter)\n", 329 | "plt.ylabel(\"x\")\n", 330 | "plt.show()" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "id": "03cfc4b2", 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [] 340 | } 341 | ], 342 | "metadata": { 343 | "kernelspec": { 344 | "display_name": "Python [conda env:inversion_ideas]", 345 | "language": "python", 346 | "name": "conda-env-inversion_ideas-py" 347 | }, 348 | "language_info": { 349 | "codemirror_mode": { 350 | "name": "ipython", 351 | "version": 3 352 | }, 353 | "file_extension": ".py", 354 | "mimetype": "text/x-python", 355 | "name": "python", 356 | "nbconvert_exporter": "python", 357 | "pygments_lexer": "ipython3", 358 | "version": "3.13.5" 359 | } 360 | }, 361 | "nbformat": 4, 362 | "nbformat_minor": 5 363 | } 364 | -------------------------------------------------------------------------------- /notebooks/02_linear-regressor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "0ed66bfb-683e-4b0a-a4ec-50050f295070", 6 | "metadata": {}, 7 | "source": [ 8 | "# Use inversion framework 
to fit a linear regressor" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "bba9bf32-da93-4100-9ee1-0983a1450bb4", 15 | "metadata": { 16 | "execution": { 17 | "iopub.execute_input": "2025-10-07T20:41:41.072339Z", 18 | "iopub.status.busy": "2025-10-07T20:41:41.071897Z", 19 | "iopub.status.idle": "2025-10-07T20:41:42.803075Z", 20 | "shell.execute_reply": "2025-10-07T20:41:42.802417Z", 21 | "shell.execute_reply.started": "2025-10-07T20:41:41.072297Z" 22 | } 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "import numpy as np\n", 27 | "from regressor import LinearRegressor\n", 28 | "\n", 29 | "import inversion_ideas as ii" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "afa3d214-09d3-4a0f-a402-5793d8b308b2", 36 | "metadata": { 37 | "execution": { 38 | "iopub.execute_input": "2025-10-07T20:41:42.804180Z", 39 | "iopub.status.busy": "2025-10-07T20:41:42.803747Z", 40 | "iopub.status.idle": "2025-10-07T20:41:42.813215Z", 41 | "shell.execute_reply": "2025-10-07T20:41:42.812577Z", 42 | "shell.execute_reply.started": "2025-10-07T20:41:42.804150Z" 43 | } 44 | }, 45 | "outputs": [ 46 | { 47 | "data": { 48 | "text/plain": [ 49 | "array([0.78225148, 0.67148671, 0.2373809 , 0.17946133, 0.34662367,\n", 50 | " 0.15210999, 0.31142952, 0.23900652, 0.54355731, 0.91770851])" 51 | ] 52 | }, 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "output_type": "execute_result" 56 | } 57 | ], 58 | "source": [ 59 | "n_params = 10\n", 60 | "rng = np.random.default_rng(seed=4242)\n", 61 | "true_model = rng.uniform(size=10)\n", 62 | "true_model" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "id": "5485a55c-9727-4260-b590-1ea243dba484", 69 | "metadata": { 70 | "execution": { 71 | "iopub.execute_input": "2025-10-07T20:41:42.814045Z", 72 | "iopub.status.busy": "2025-10-07T20:41:42.813828Z", 73 | "iopub.status.idle": "2025-10-07T20:41:42.827155Z", 74 | "shell.execute_reply": 
"2025-10-07T20:41:42.826516Z", 75 | "shell.execute_reply.started": "2025-10-07T20:41:42.814025Z" 76 | } 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "# Build the X array\n", 81 | "n_data = 25\n", 82 | "shape = (n_data, n_params)\n", 83 | "X = rng.uniform(size=n_data * n_params).reshape(shape)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 4, 89 | "id": "66458a9c-5e01-4a75-8fd0-fd149f447992", 90 | "metadata": { 91 | "execution": { 92 | "iopub.execute_input": "2025-10-07T20:41:42.827962Z", 93 | "iopub.status.busy": "2025-10-07T20:41:42.827756Z", 94 | "iopub.status.idle": "2025-10-07T20:41:42.836338Z", 95 | "shell.execute_reply": "2025-10-07T20:41:42.835624Z", 96 | "shell.execute_reply.started": "2025-10-07T20:41:42.827943Z" 97 | } 98 | }, 99 | "outputs": [ 100 | { 101 | "data": { 102 | "text/plain": [ 103 | "array([2.83840696, 2.18091081, 2.00623242, 2.08333039, 2.01694883,\n", 104 | " 2.7826232 , 2.10564027, 1.27333506, 2.08859855, 1.94177648,\n", 105 | " 1.88492037, 2.92394733, 2.17231952, 3.08009275, 1.61670886,\n", 106 | " 1.77403753, 2.67305005, 1.91413882, 2.42117827, 2.13991628,\n", 107 | " 2.0153805 , 2.71388471, 2.65944255, 2.44416121, 3.14217523])" 108 | ] 109 | }, 110 | "execution_count": 4, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | } 114 | ], 115 | "source": [ 116 | "synthetic_data = X @ true_model\n", 117 | "maxabs = np.max(np.abs(synthetic_data))\n", 118 | "noise = rng.normal(scale=1e-2 * maxabs, size=synthetic_data.size)\n", 119 | "synthetic_data += noise\n", 120 | "synthetic_data" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 5, 126 | "id": "ad91f721-72a9-485b-9043-d85a2a220b7b", 127 | "metadata": { 128 | "execution": { 129 | "iopub.execute_input": "2025-10-07T20:41:42.837858Z", 130 | "iopub.status.busy": "2025-10-07T20:41:42.837191Z", 131 | "iopub.status.idle": "2025-10-07T20:41:42.844920Z", 132 | "shell.execute_reply": "2025-10-07T20:41:42.844113Z", 133 | 
"shell.execute_reply.started": "2025-10-07T20:41:42.837810Z" 134 | } 135 | }, 136 | "outputs": [], 137 | "source": [ 138 | "simulation = LinearRegressor(X)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 6, 144 | "id": "04636169-320a-4f1e-ab42-1b083b0111dd", 145 | "metadata": { 146 | "execution": { 147 | "iopub.execute_input": "2025-10-07T20:41:44.346944Z", 148 | "iopub.status.busy": "2025-10-07T20:41:44.346279Z", 149 | "iopub.status.idle": "2025-10-07T20:41:44.351596Z", 150 | "shell.execute_reply": "2025-10-07T20:41:44.350684Z", 151 | "shell.execute_reply.started": "2025-10-07T20:41:44.346903Z" 152 | } 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "uncertainty = 1e-2 * maxabs * np.ones_like(synthetic_data)\n", 157 | "data_misfit = ii.DataMisfit(synthetic_data, uncertainty, simulation)\n", 158 | "smallness = ii.TikhonovZero(n_params)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 7, 164 | "id": "8da79647-eb17-4177-afb9-776f3cc4ffc5", 165 | "metadata": { 166 | "execution": { 167 | "iopub.execute_input": "2025-10-07T20:41:44.935217Z", 168 | "iopub.status.busy": "2025-10-07T20:41:44.933833Z", 169 | "iopub.status.idle": "2025-10-07T20:41:44.952204Z", 170 | "shell.execute_reply": "2025-10-07T20:41:44.950839Z", 171 | "shell.execute_reply.started": "2025-10-07T20:41:44.935133Z" 172 | } 173 | }, 174 | "outputs": [ 175 | { 176 | "data": { 177 | "text/latex": [ 178 | "$ \\phi_{d} (m) + 1.00 \\cdot 10^{-3} \\, \\phi_{0} (m) $" 179 | ], 180 | "text/plain": [ 181 | "φd(m) + 0.00 φ0(m)" 182 | ] 183 | }, 184 | "execution_count": 7, 185 | "metadata": {}, 186 | "output_type": "execute_result" 187 | } 188 | ], 189 | "source": [ 190 | "phi = data_misfit + 1e-3 * smallness\n", 191 | "phi" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 8, 197 | "id": "d7bb890a-2958-4ea4-aa64-413f57c7b63e", 198 | "metadata": { 199 | "execution": { 200 | "iopub.execute_input": "2025-10-07T20:41:46.116856Z", 
201 | "iopub.status.busy": "2025-10-07T20:41:46.115970Z", 202 | "iopub.status.idle": "2025-10-07T20:41:46.125063Z", 203 | "shell.execute_reply": "2025-10-07T20:41:46.124045Z", 204 | "shell.execute_reply.started": "2025-10-07T20:41:46.116782Z" 205 | } 206 | }, 207 | "outputs": [ 208 | { 209 | "data": { 210 | "text/plain": [ 211 | "array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])" 212 | ] 213 | }, 214 | "execution_count": 8, 215 | "metadata": {}, 216 | "output_type": "execute_result" 217 | } 218 | ], 219 | "source": [ 220 | "initial_model = np.zeros(n_params)\n", 221 | "initial_model" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "id": "5cc54476-cd31-4bf6-bc0f-bfd6254af7ec", 227 | "metadata": {}, 228 | "source": [ 229 | "## Minimize manually with `scipy.sparse.linalg.cg`" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 9, 235 | "id": "d4bac3a2-49d5-4592-a793-72789ec31a5c", 236 | "metadata": { 237 | "execution": { 238 | "iopub.execute_input": "2025-10-07T20:41:47.356818Z", 239 | "iopub.status.busy": "2025-10-07T20:41:47.355950Z", 240 | "iopub.status.idle": "2025-10-07T20:41:47.364506Z", 241 | "shell.execute_reply": "2025-10-07T20:41:47.362877Z", 242 | "shell.execute_reply.started": "2025-10-07T20:41:47.356743Z" 243 | } 244 | }, 245 | "outputs": [], 246 | "source": [ 247 | "from scipy.sparse.linalg import cg" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 10, 253 | "id": "c3369162-abe9-4fb9-9294-69277e50ef13", 254 | "metadata": { 255 | "execution": { 256 | "iopub.execute_input": "2025-10-07T20:41:48.113164Z", 257 | "iopub.status.busy": "2025-10-07T20:41:48.111819Z", 258 | "iopub.status.idle": "2025-10-07T20:41:48.121557Z", 259 | "shell.execute_reply": "2025-10-07T20:41:48.120420Z", 260 | "shell.execute_reply.started": "2025-10-07T20:41:48.113095Z" 261 | } 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "grad = phi.gradient(initial_model)\n", 266 | "hess = phi.hessian(initial_model)" 267 | ] 
268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 11, 272 | "id": "c1fed88c-9c34-4ee3-8208-e7b50c2ec678", 273 | "metadata": { 274 | "execution": { 275 | "iopub.execute_input": "2025-10-07T20:41:48.527230Z", 276 | "iopub.status.busy": "2025-10-07T20:41:48.526459Z", 277 | "iopub.status.idle": "2025-10-07T20:41:48.540500Z", 278 | "shell.execute_reply": "2025-10-07T20:41:48.537985Z", 279 | "shell.execute_reply.started": "2025-10-07T20:41:48.527156Z" 280 | } 281 | }, 282 | "outputs": [ 283 | { 284 | "data": { 285 | "text/plain": [ 286 | "(array([0.81328886, 0.65927515, 0.24729371, 0.19624752, 0.3237346 ,\n", 287 | " 0.14720343, 0.3194468 , 0.25235983, 0.52215485, 0.92181019]),\n", 288 | " 0)" 289 | ] 290 | }, 291 | "execution_count": 11, 292 | "metadata": {}, 293 | "output_type": "execute_result" 294 | } 295 | ], 296 | "source": [ 297 | "model_step, info = cg(hess, -grad)\n", 298 | "model_step, info" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 12, 304 | "id": "54a8ef70-6e5f-4000-ad2f-658dbffd83f1", 305 | "metadata": { 306 | "execution": { 307 | "iopub.execute_input": "2025-10-07T20:41:49.657165Z", 308 | "iopub.status.busy": "2025-10-07T20:41:49.656672Z", 309 | "iopub.status.idle": "2025-10-07T20:41:49.662335Z", 310 | "shell.execute_reply": "2025-10-07T20:41:49.661276Z", 311 | "shell.execute_reply.started": "2025-10-07T20:41:49.657122Z" 312 | } 313 | }, 314 | "outputs": [], 315 | "source": [ 316 | "inverted_model = initial_model + model_step" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 13, 322 | "id": "b05f913f-dcfa-4bca-a277-b9cf5c9de6b5", 323 | "metadata": { 324 | "execution": { 325 | "iopub.execute_input": "2025-10-07T20:41:49.860135Z", 326 | "iopub.status.busy": "2025-10-07T20:41:49.858845Z", 327 | "iopub.status.idle": "2025-10-07T20:41:49.874309Z", 328 | "shell.execute_reply": "2025-10-07T20:41:49.872575Z", 329 | "shell.execute_reply.started": "2025-10-07T20:41:49.860017Z" 330 | } 
331 | }, 332 | "outputs": [ 333 | { 334 | "name": "stdout", 335 | "output_type": "stream", 336 | "text": [ 337 | "Result:\n", 338 | "[0.81328886 0.65927515 0.24729371 0.19624752 0.3237346 0.14720343\n", 339 | " 0.3194468 0.25235983 0.52215485 0.92181019]\n", 340 | "\n", 341 | "True model:\n", 342 | "[0.78225148 0.67148671 0.2373809 0.17946133 0.34662367 0.15210999\n", 343 | " 0.31142952 0.23900652 0.54355731 0.91770851]\n" 344 | ] 345 | } 346 | ], 347 | "source": [ 348 | "print(\"Result:\")\n", 349 | "print(inverted_model)\n", 350 | "print()\n", 351 | "print(\"True model:\")\n", 352 | "print(true_model)" 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "id": "ad0f7d22-77d6-410c-bca1-e3d765006aff", 358 | "metadata": {}, 359 | "source": [ 360 | "## Minimize with SciPy's `minimize`" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 14, 366 | "id": "5f36ec3b-0380-4eeb-b01a-3191c12b70f5", 367 | "metadata": { 368 | "execution": { 369 | "iopub.execute_input": "2025-10-07T20:41:50.978736Z", 370 | "iopub.status.busy": "2025-10-07T20:41:50.977541Z", 371 | "iopub.status.idle": "2025-10-07T20:41:50.985613Z", 372 | "shell.execute_reply": "2025-10-07T20:41:50.984378Z", 373 | "shell.execute_reply.started": "2025-10-07T20:41:50.978663Z" 374 | } 375 | }, 376 | "outputs": [], 377 | "source": [ 378 | "from scipy.optimize import minimize" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 15, 384 | "id": "ba337ea8-0afb-46aa-b107-07bcf86a8324", 385 | "metadata": { 386 | "execution": { 387 | "iopub.execute_input": "2025-10-07T20:41:51.899274Z", 388 | "iopub.status.busy": "2025-10-07T20:41:51.898891Z", 389 | "iopub.status.idle": "2025-10-07T20:41:52.082491Z", 390 | "shell.execute_reply": "2025-10-07T20:41:52.081784Z", 391 | "shell.execute_reply.started": "2025-10-07T20:41:51.899236Z" 392 | } 393 | }, 394 | "outputs": [ 395 | { 396 | "data": { 397 | "text/plain": [ 398 | " message: Optimization terminated successfully.\n", 
399 | " success: True\n", 400 | " status: 0\n", 401 | " fun: 11.719746150008518\n", 402 | " x: [ 8.133e-01 6.593e-01 2.473e-01 1.963e-01 3.237e-01\n", 403 | " 1.472e-01 3.195e-01 2.524e-01 5.222e-01 9.218e-01]\n", 404 | " nit: 15\n", 405 | " jac: [ 4.768e-07 -4.530e-06 -2.384e-06 1.907e-06 1.192e-06\n", 406 | " -2.146e-06 -1.431e-06 -4.292e-06 -2.027e-06 0.000e+00]\n", 407 | " hess_inv: [[ 3.378e-04 -1.043e-04 ... 1.536e-05 -9.322e-05]\n", 408 | " [-1.043e-04 4.045e-04 ... -1.009e-04 -1.335e-05]\n", 409 | " ...\n", 410 | " [ 1.536e-05 -1.009e-04 ... 2.159e-04 -5.467e-05]\n", 411 | " [-9.322e-05 -1.335e-05 ... -5.467e-05 2.625e-04]]\n", 412 | " nfev: 253\n", 413 | " njev: 23" 414 | ] 415 | }, 416 | "execution_count": 15, 417 | "metadata": {}, 418 | "output_type": "execute_result" 419 | } 420 | ], 421 | "source": [ 422 | "result = minimize(phi, initial_model)\n", 423 | "result" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": 16, 429 | "id": "d5cf83b5-6cbd-42d1-9513-d3ca38464b72", 430 | "metadata": { 431 | "execution": { 432 | "iopub.execute_input": "2025-10-07T20:41:52.177045Z", 433 | "iopub.status.busy": "2025-10-07T20:41:52.176423Z", 434 | "iopub.status.idle": "2025-10-07T20:41:52.183531Z", 435 | "shell.execute_reply": "2025-10-07T20:41:52.182450Z", 436 | "shell.execute_reply.started": "2025-10-07T20:41:52.176995Z" 437 | } 438 | }, 439 | "outputs": [], 440 | "source": [ 441 | "# The minimize already gives you the minimum model\n", 442 | "inverted_model = result.x" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": 17, 448 | "id": "d25be88b-3134-447a-b003-c80dddff3166", 449 | "metadata": { 450 | "execution": { 451 | "iopub.execute_input": "2025-10-07T20:41:52.895646Z", 452 | "iopub.status.busy": "2025-10-07T20:41:52.894809Z", 453 | "iopub.status.idle": "2025-10-07T20:41:52.904128Z", 454 | "shell.execute_reply": "2025-10-07T20:41:52.903103Z", 455 | "shell.execute_reply.started": "2025-10-07T20:41:52.895565Z" 456 
| } 457 | }, 458 | "outputs": [ 459 | { 460 | "name": "stdout", 461 | "output_type": "stream", 462 | "text": [ 463 | "Result:\n", 464 | "[0.8132614 0.65925411 0.24726513 0.19626776 0.32373026 0.14719295\n", 465 | " 0.3194738 0.25235066 0.52216838 0.92184154]\n", 466 | "\n", 467 | "True model:\n", 468 | "[0.78225148 0.67148671 0.2373809 0.17946133 0.34662367 0.15210999\n", 469 | " 0.31142952 0.23900652 0.54355731 0.91770851]\n" 470 | ] 471 | } 472 | ], 473 | "source": [ 474 | "print(\"Result:\")\n", 475 | "print(inverted_model)\n", 476 | "print()\n", 477 | "print(\"True model:\")\n", 478 | "print(true_model)" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 18, 484 | "id": "50cb98a1-1fc8-4378-a517-09ed910ce1c3", 485 | "metadata": { 486 | "execution": { 487 | "iopub.execute_input": "2025-10-07T20:41:53.646890Z", 488 | "iopub.status.busy": "2025-10-07T20:41:53.645552Z", 489 | "iopub.status.idle": "2025-10-07T20:41:53.694048Z", 490 | "shell.execute_reply": "2025-10-07T20:41:53.693375Z", 491 | "shell.execute_reply.started": "2025-10-07T20:41:53.646831Z" 492 | } 493 | }, 494 | "outputs": [ 495 | { 496 | "data": { 497 | "text/plain": [ 498 | " message: Optimization terminated successfully.\n", 499 | " success: True\n", 500 | " status: 0\n", 501 | " fun: 11.719746150007744\n", 502 | " x: [ 8.133e-01 6.593e-01 2.473e-01 1.963e-01 3.237e-01\n", 503 | " 1.472e-01 3.195e-01 2.524e-01 5.222e-01 9.218e-01]\n", 504 | " nit: 12\n", 505 | " jac: [ 1.599e-09 1.157e-09 9.655e-10 1.721e-09 1.377e-09\n", 506 | " 1.441e-09 1.387e-09 8.740e-10 1.384e-09 1.061e-09]\n", 507 | " nfev: 24\n", 508 | " njev: 24" 509 | ] 510 | }, 511 | "execution_count": 18, 512 | "metadata": {}, 513 | "output_type": "execute_result" 514 | } 515 | ], 516 | "source": [ 517 | "result = minimize(phi, initial_model, jac=phi.gradient, method=\"CG\")\n", 518 | "result" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": 19, 524 | "id": 
"45a7c13a-79ea-445f-b63d-b551787ebc20", 525 | "metadata": { 526 | "execution": { 527 | "iopub.execute_input": "2025-10-07T20:41:54.076506Z", 528 | "iopub.status.busy": "2025-10-07T20:41:54.075029Z", 529 | "iopub.status.idle": "2025-10-07T20:41:54.081513Z", 530 | "shell.execute_reply": "2025-10-07T20:41:54.080491Z", 531 | "shell.execute_reply.started": "2025-10-07T20:41:54.076438Z" 532 | } 533 | }, 534 | "outputs": [], 535 | "source": [ 536 | "# The minimize already gives you the minimum model\n", 537 | "inverted_model = result.x" 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "execution_count": 20, 543 | "id": "094fbf19-b676-4aba-8e23-b93abb8b2431", 544 | "metadata": { 545 | "execution": { 546 | "iopub.execute_input": "2025-10-07T20:41:54.898297Z", 547 | "iopub.status.busy": "2025-10-07T20:41:54.897568Z", 548 | "iopub.status.idle": "2025-10-07T20:41:54.905389Z", 549 | "shell.execute_reply": "2025-10-07T20:41:54.904327Z", 550 | "shell.execute_reply.started": "2025-10-07T20:41:54.898240Z" 551 | } 552 | }, 553 | "outputs": [ 554 | { 555 | "name": "stdout", 556 | "output_type": "stream", 557 | "text": [ 558 | "Result:\n", 559 | "[0.8132614 0.65925411 0.24726513 0.19626776 0.32373025 0.14719295\n", 560 | " 0.3194738 0.25235066 0.52216838 0.92184154]\n", 561 | "\n", 562 | "True model:\n", 563 | "[0.78225148 0.67148671 0.2373809 0.17946133 0.34662367 0.15210999\n", 564 | " 0.31142952 0.23900652 0.54355731 0.91770851]\n" 565 | ] 566 | } 567 | ], 568 | "source": [ 569 | "print(\"Result:\")\n", 570 | "print(result.x)\n", 571 | "print()\n", 572 | "print(\"True model:\")\n", 573 | "print(true_model)" 574 | ] 575 | }, 576 | { 577 | "cell_type": "code", 578 | "execution_count": 21, 579 | "id": "65235b83-fd51-4f52-b7ba-3c0dc2a35fe2", 580 | "metadata": { 581 | "execution": { 582 | "iopub.execute_input": "2025-10-07T20:41:55.278124Z", 583 | "iopub.status.busy": "2025-10-07T20:41:55.276367Z", 584 | "iopub.status.idle": "2025-10-07T20:41:55.319827Z", 585 | 
"shell.execute_reply": "2025-10-07T20:41:55.319010Z", 586 | "shell.execute_reply.started": "2025-10-07T20:41:55.278034Z" 587 | } 588 | }, 589 | "outputs": [ 590 | { 591 | "data": { 592 | "text/plain": [ 593 | " message: Optimization terminated successfully.\n", 594 | " success: True\n", 595 | " status: 0\n", 596 | " fun: 11.71974631421735\n", 597 | " x: [ 8.133e-01 6.592e-01 2.473e-01 1.963e-01 3.237e-01\n", 598 | " 1.472e-01 3.195e-01 2.524e-01 5.222e-01 9.218e-01]\n", 599 | " nit: 10\n", 600 | " jac: [-3.047e-02 -4.461e-02 1.683e-02 2.203e-02 -8.361e-03\n", 601 | " 7.088e-02 -2.080e-02 3.307e-02 -1.339e-02 -1.339e-02]\n", 602 | " nfev: 11\n", 603 | " njev: 11\n", 604 | " nhev: 26" 605 | ] 606 | }, 607 | "execution_count": 21, 608 | "metadata": {}, 609 | "output_type": "execute_result" 610 | } 611 | ], 612 | "source": [ 613 | "result = minimize(\n", 614 | " phi, initial_model, jac=phi.gradient, hess=phi.hessian, method=\"Newton-CG\"\n", 615 | ")\n", 616 | "result" 617 | ] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "execution_count": 22, 622 | "id": "45df89b6-58fc-472c-bfa6-96f1d3c7cc5c", 623 | "metadata": { 624 | "execution": { 625 | "iopub.execute_input": "2025-10-07T20:41:56.044519Z", 626 | "iopub.status.busy": "2025-10-07T20:41:56.043752Z", 627 | "iopub.status.idle": "2025-10-07T20:41:56.049919Z", 628 | "shell.execute_reply": "2025-10-07T20:41:56.048694Z", 629 | "shell.execute_reply.started": "2025-10-07T20:41:56.044439Z" 630 | } 631 | }, 632 | "outputs": [], 633 | "source": [ 634 | "# The minimize already gives you the minimum model\n", 635 | "inverted_model = result.x" 636 | ] 637 | }, 638 | { 639 | "cell_type": "code", 640 | "execution_count": 23, 641 | "id": "2e9d2ee1-bfd9-4c3d-9865-f63d08a25b07", 642 | "metadata": { 643 | "execution": { 644 | "iopub.execute_input": "2025-10-07T20:41:56.408201Z", 645 | "iopub.status.busy": "2025-10-07T20:41:56.407200Z", 646 | "iopub.status.idle": "2025-10-07T20:41:56.417991Z", 647 | "shell.execute_reply": 
"2025-10-07T20:41:56.417054Z", 648 | "shell.execute_reply.started": "2025-10-07T20:41:56.408127Z" 649 | } 650 | }, 651 | "outputs": [ 652 | { 653 | "name": "stdout", 654 | "output_type": "stream", 655 | "text": [ 656 | "Result:\n", 657 | "[0.81326764 0.65924856 0.24725963 0.19626682 0.32372755 0.14719478\n", 658 | " 0.31947594 0.25235815 0.52216987 0.92183788]\n", 659 | "\n", 660 | "True model:\n", 661 | "[0.78225148 0.67148671 0.2373809 0.17946133 0.34662367 0.15210999\n", 662 | " 0.31142952 0.23900652 0.54355731 0.91770851]\n" 663 | ] 664 | } 665 | ], 666 | "source": [ 667 | "print(\"Result:\")\n", 668 | "print(result.x)\n", 669 | "print()\n", 670 | "print(\"True model:\")\n", 671 | "print(true_model)" 672 | ] 673 | }, 674 | { 675 | "cell_type": "markdown", 676 | "id": "ded8a6fd-c419-4b13-bf3e-9d221c78a368", 677 | "metadata": {}, 678 | "source": [ 679 | "## Use custom minimizers" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": 24, 685 | "id": "f3f52071-c2c5-4016-9d61-cd2f3b781698", 686 | "metadata": { 687 | "execution": { 688 | "iopub.execute_input": "2025-10-07T20:42:14.351501Z", 689 | "iopub.status.busy": "2025-10-07T20:42:14.350925Z", 690 | "iopub.status.idle": "2025-10-07T20:42:14.361047Z", 691 | "shell.execute_reply": "2025-10-07T20:42:14.360170Z", 692 | "shell.execute_reply.started": "2025-10-07T20:42:14.351448Z" 693 | } 694 | }, 695 | "outputs": [], 696 | "source": [ 697 | "inverted_model = ii.conjugate_gradient(phi, initial_model)" 698 | ] 699 | }, 700 | { 701 | "cell_type": "code", 702 | "execution_count": 25, 703 | "id": "e41d3e15-1a74-42b9-9583-cacb8aeceadc", 704 | "metadata": { 705 | "execution": { 706 | "iopub.execute_input": "2025-10-07T20:42:15.098929Z", 707 | "iopub.status.busy": "2025-10-07T20:42:15.097559Z", 708 | "iopub.status.idle": "2025-10-07T20:42:15.106580Z", 709 | "shell.execute_reply": "2025-10-07T20:42:15.105356Z", 710 | "shell.execute_reply.started": "2025-10-07T20:42:15.098849Z" 711 | } 712 | }, 713 | 
"outputs": [ 714 | { 715 | "name": "stdout", 716 | "output_type": "stream", 717 | "text": [ 718 | "Result:\n", 719 | "[0.81328886 0.65927515 0.24729371 0.19624752 0.3237346 0.14720343\n", 720 | " 0.3194468 0.25235983 0.52215485 0.92181019]\n", 721 | "\n", 722 | "True model:\n", 723 | "[0.78225148 0.67148671 0.2373809 0.17946133 0.34662367 0.15210999\n", 724 | " 0.31142952 0.23900652 0.54355731 0.91770851]\n" 725 | ] 726 | } 727 | ], 728 | "source": [ 729 | "print(\"Result:\")\n", 730 | "print(inverted_model)\n", 731 | "print()\n", 732 | "print(\"True model:\")\n", 733 | "print(true_model)" 734 | ] 735 | } 736 | ], 737 | "metadata": { 738 | "kernelspec": { 739 | "display_name": "Python [conda env:inversion_ideas]", 740 | "language": "python", 741 | "name": "conda-env-inversion_ideas-py" 742 | }, 743 | "language_info": { 744 | "codemirror_mode": { 745 | "name": "ipython", 746 | "version": 3 747 | }, 748 | "file_extension": ".py", 749 | "mimetype": "text/x-python", 750 | "name": "python", 751 | "nbconvert_exporter": "python", 752 | "pygments_lexer": "ipython3", 753 | "version": "3.13.5" 754 | } 755 | }, 756 | "nbformat": 4, 757 | "nbformat_minor": 5 758 | } 759 | --------------------------------------------------------------------------------