├── tests ├── __init__.py ├── conftest.py ├── test_io.py ├── test_decomposition.py ├── test_kernels.py ├── test_rbf.py └── test_core.py ├── examples ├── lid-driven-cavity │ ├── eq-snapshot.png │ ├── results-re-450.png │ ├── results-re-50.png │ ├── generatePODdata.sim │ ├── eq-snapshot-matrix.png │ └── lid-driven-cavity.ipynb ├── heat-conduction │ ├── 2d-heat.py │ └── design-optimization.py └── shape-optimization │ └── optimize_shape.py ├── .github ├── codecov.yml └── workflows │ ├── docs.yml │ ├── tests.yml │ └── publish.yml ├── docs ├── api │ └── index.md ├── examples.md ├── getting-started.md ├── user-guide │ ├── inference.md │ ├── io.md │ ├── autodiff.md │ └── training.md └── index.md ├── pyproject.toml ├── pod_rbf ├── __init__.py ├── types.py ├── kernels.py ├── shape_optimization.py ├── decomposition.py ├── io.py ├── core.py └── rbf.py ├── mkdocs.yml ├── .gitignore ├── CLAUDE.md ├── README.md └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/lid-driven-cavity/eq-snapshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kylebeggs/POD-RBF/HEAD/examples/lid-driven-cavity/eq-snapshot.png -------------------------------------------------------------------------------- /examples/lid-driven-cavity/results-re-450.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kylebeggs/POD-RBF/HEAD/examples/lid-driven-cavity/results-re-450.png -------------------------------------------------------------------------------- /examples/lid-driven-cavity/results-re-50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kylebeggs/POD-RBF/HEAD/examples/lid-driven-cavity/results-re-50.png -------------------------------------------------------------------------------- /examples/lid-driven-cavity/generatePODdata.sim: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kylebeggs/POD-RBF/HEAD/examples/lid-driven-cavity/generatePODdata.sim -------------------------------------------------------------------------------- /examples/lid-driven-cavity/eq-snapshot-matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kylebeggs/POD-RBF/HEAD/examples/lid-driven-cavity/eq-snapshot-matrix.png -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Pytest configuration and shared fixtures.""" 2 | 3 | import jax 4 | 5 | # Ensure float64 is enabled for all tests 6 | jax.config.update("jax_enable_x64", True) 7 | -------------------------------------------------------------------------------- /.github/codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | target: auto 6 | threshold: 1% 7 | patch: 8 | default: 9 | target: auto 10 | threshold: 1% 11 | 12 | comment: 13 | layout: "reach,diff,flags,files,footer" 14 | behavior: default 15 | require_changes: false 16 | 17 | ignore: 18 | - "tests/" 19 | - "examples/" 20 | - "setup.py" 21 | -------------------------------------------------------------------------------- 
/.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | workflow_dispatch: 7 | 8 | permissions: 9 | contents: write 10 | 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version: "3.12" 20 | 21 | - name: Install dependencies 22 | run: | 23 | pip install mkdocs-material mkdocstrings[python] 24 | pip install -e . 25 | 26 | - name: Build and deploy docs 27 | run: mkdocs gh-deploy --force 28 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: ["3.10", "3.11", "3.12"] 16 | 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v4 20 | 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | cache: 'pip' 26 | cache-dependency-path: 'pyproject.toml' 27 | 28 | - name: Install JAX (CPU) 29 | run: | 30 | pip install --upgrade pip 31 | pip install "jax[cpu]>=0.4.0" 32 | 33 | - name: Install package with dev dependencies 34 | run: pip install -e ".[dev]" 35 | 36 | - name: Run tests with coverage 37 | run: pytest tests/ -v --cov=pod_rbf --cov-report=xml --cov-report=term 38 | 39 | - name: Upload coverage to Codecov 40 | uses: codecov/codecov-action@v4 41 | with: 42 | file: ./coverage.xml 43 | flags: unittests 44 | name: codecov-umbrella 45 | fail_ci_if_error: false 46 | env: 47 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 48 | -------------------------------------------------------------------------------- /docs/api/index.md: -------------------------------------------------------------------------------- 1 | # API Reference 2 | 3 | ## Core Functions 4 | 5 | ### train 6 | 7 | ::: pod_rbf.train 8 | options: 9 | show_root_heading: true 10 | show_source: false 11 | 12 | ### inference 13 | 14 | ::: pod_rbf.inference 15 | options: 16 | show_root_heading: true 17 | show_source: false 18 | 19 | ### inference_single 20 | 21 | ::: pod_rbf.inference_single 22 | options: 23 | show_root_heading: true 24 | show_source: false 25 | 26 | ## I/O Functions 27 | 28 | ### build_snapshot_matrix 29 | 30 | ::: pod_rbf.build_snapshot_matrix 31 | options: 32 | show_root_heading: true 33 | show_source: false 34 | 35 | ### save_model 36 | 37 | ::: pod_rbf.save_model 38 | options: 39 | show_root_heading: true 40 | show_source: false 41 | 42 | ### load_model 43 | 44 | ::: pod_rbf.load_model 45 | options: 46 | show_root_heading: true 47 | show_source: false 48 | 49 | ## Types 50 | 51 | ### TrainConfig 52 | 53 | ::: pod_rbf.TrainConfig 54 | options: 55 | show_root_heading: true 56 | show_source: false 57 | 58 | ### ModelState 59 | 60 | ::: pod_rbf.ModelState 61 | options: 62 | show_root_heading: true 63 | show_source: false 64 | 65 | ### TrainResult 66 | 67 | ::: pod_rbf.TrainResult 68 | options: 69 | show_root_heading: true 70 | show_source: false 71 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | 
[build-system] 2 | requires = ["setuptools>=61.0", "setuptools-scm>=8.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "pod_rbf" 7 | dynamic = ["version"] 8 | authors = [{name = "Kyle Beggs", email = "beggskw@gmail.com"}] 9 | description = "JAX-based POD-RBF for autodiff-enabled reduced order modeling." 10 | readme = "README.md" 11 | license = {text = "MIT"} 12 | requires-python = ">=3.10" 13 | classifiers = [ 14 | "Programming Language :: Python :: 3", 15 | "Programming Language :: Python :: 3.10", 16 | "Programming Language :: Python :: 3.11", 17 | "Programming Language :: Python :: 3.12", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: OS Independent", 20 | ] 21 | dependencies = [ 22 | "jax>=0.4.0", 23 | "jaxlib>=0.4.0", 24 | "numpy", 25 | "tqdm", 26 | ] 27 | 28 | [project.optional-dependencies] 29 | dev = ["pytest>=7.0", "pytest-cov"] 30 | docs = [ 31 | "mkdocs-material>=9.5", 32 | "mkdocstrings[python]>=0.24", 33 | ] 34 | 35 | [project.urls] 36 | Homepage = "https://github.com/kylebeggs/POD-RBF" 37 | Repository = "https://github.com/kylebeggs/POD-RBF" 38 | Documentation = "https://kylebeggs.github.io/POD-RBF" 39 | 40 | [tool.setuptools_scm] 41 | 42 | [tool.uv.sources] 43 | pod-rbf = { workspace = true } 44 | 45 | [dependency-groups] 46 | dev = [ 47 | "pod-rbf", 48 | ] 49 | -------------------------------------------------------------------------------- /pod_rbf/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | POD-RBF: Proper Orthogonal Decomposition - Radial Basis Function Network. 3 | 4 | A JAX-based implementation enabling autodifferentiation for: 5 | - Gradient optimization 6 | - Sensitivity analysis 7 | - Inverse problems 8 | 9 | Usage 10 | ----- 11 | >>> import pod_rbf 12 | >>> import jax.numpy as jnp 13 | >>> 14 | >>> # Train model 15 | >>> result = pod_rbf.train(snapshot, params) 16 | >>> 17 | >>> # Inference 18 | >>> pred = pod_rbf.inference_single(result.state, jnp.array(450.0)) 19 | >>> 20 | >>> # Autodiff 21 | >>> import jax 22 | >>> grad_fn = jax.grad(lambda p: jnp.sum(pod_rbf.inference_single(result.state, p)**2)) 23 | >>> gradient = grad_fn(jnp.array(450.0)) 24 | """ 25 | 26 | import jax 27 | 28 | # Enable float64 for numerical stability (SVD, condition numbers) 29 | jax.config.update("jax_enable_x64", True) 30 | 31 | from .core import inference, inference_single, train 32 | from .io import build_snapshot_matrix, load_model, save_model 33 | from .types import ModelState, TrainConfig, TrainResult 34 | 35 | try: 36 | from importlib.metadata import version 37 | 38 | __version__ = version("pod_rbf") 39 | except Exception: 40 | __version__ = "unknown" 41 | 42 | __all__ = [ 43 | # Core functions 44 | "train", 45 | "inference", 46 | "inference_single", 47 | # Types 48 | "ModelState", 49 | "TrainConfig", 50 | "TrainResult", 51 | # I/O 52 | "build_snapshot_matrix", 53 | "save_model", 54 | "load_model", 55 | ] 56 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: POD-RBF 2 | site_description: JAX-based POD-RBF for autodiff-enabled reduced order modeling 3 | site_url: https://kylebeggs.github.io/POD-RBF 4 | repo_url: https://github.com/kylebeggs/POD-RBF 5 | repo_name: kylebeggs/POD-RBF 6 | 7 | theme: 8 | name: material 9 | palette: 10 | - media: "(prefers-color-scheme: light)" 11 | scheme: default 12 | primary: indigo 13 | toggle: 14 | 
icon: material/brightness-7 15 | name: Switch to dark mode 16 | - media: "(prefers-color-scheme: dark)" 17 | scheme: slate 18 | primary: indigo 19 | toggle: 20 | icon: material/brightness-4 21 | name: Switch to light mode 22 | features: 23 | - navigation.instant 24 | - navigation.sections 25 | - navigation.top 26 | - search.highlight 27 | - content.code.copy 28 | 29 | plugins: 30 | - search 31 | - mkdocstrings: 32 | handlers: 33 | python: 34 | options: 35 | show_source: true 36 | show_root_heading: true 37 | 38 | nav: 39 | - Home: index.md 40 | - Getting Started: getting-started.md 41 | - User Guide: 42 | - Training Models: user-guide/training.md 43 | - Inference: user-guide/inference.md 44 | - Autodifferentiation: user-guide/autodiff.md 45 | - Saving & Loading: user-guide/io.md 46 | - API Reference: api/index.md 47 | - Examples: examples.md 48 | 49 | markdown_extensions: 50 | - pymdownx.highlight: 51 | anchor_linenums: true 52 | - pymdownx.superfences 53 | - admonition 54 | - pymdownx.details 55 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | jobs: 9 | build: 10 | name: Build distribution 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | with: 15 | persist-credentials: false 16 | fetch-depth: 0 # Required for setuptools-scm to get version from tags 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: "3.12" 22 | 23 | - name: Install build tools 24 | run: python -m pip install build --user 25 | 26 | - name: Build wheel and sdist 27 | run: python -m build 28 | 29 | - name: Upload artifacts 30 | uses: actions/upload-artifact@v4 31 | with: 32 | name: python-package-distributions 33 | path: dist/ 34 | 35 | publish-to-pypi: 36 | name: Publish to PyPI 37 | needs: build 38 | runs-on: ubuntu-latest 39 | environment: 40 | name: pypi 41 | url: https://pypi.org/p/pod_rbf 42 | permissions: 43 | id-token: write 44 | steps: 45 | - name: Download artifacts 46 | uses: actions/download-artifact@v4 47 | with: 48 | name: python-package-distributions 49 | path: dist/ 50 | 51 | - name: Publish to PyPI 52 | uses: pypa/gh-action-pypi-publish@release/v1 53 | 54 | create-github-release: 55 | name: Create GitHub Release 56 | needs: publish-to-pypi 57 | runs-on: ubuntu-latest 58 | permissions: 59 | contents: write 60 | steps: 61 | - name: Checkout code 62 | uses: actions/checkout@v4 63 | 64 | - name: Create Release 65 | env: 66 | GH_TOKEN: ${{ github.token }} 67 | run: | 68 | gh release create ${{ github.ref_name }} \ 69 | --title "Release ${{ github.ref_name }}" \ 70 | --generate-notes 71 | -------------------------------------------------------------------------------- /pod_rbf/types.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data structures for POD-RBF. 3 | 4 | All types are NamedTuples for JAX pytree compatibility. 
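Because NamedTuples are automatically registered as JAX pytrees, a trained ModelState can be used inside jax.grad- and jax.jit-transformed functions without any custom registration (see docs/user-guide/autodiff.md).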
5 | """ 6 | 7 | from typing import NamedTuple 8 | 9 | from jax import Array 10 | 11 | 12 | class TrainConfig(NamedTuple): 13 | """Immutable training configuration.""" 14 | 15 | energy_threshold: float = 0.99 16 | mem_limit_gb: float = 16.0 17 | cond_range: tuple[float, float] = (1e11, 1e12) 18 | max_bisection_iters: int = 50 19 | c_low_init: float = 0.011 20 | c_high_init: float = 1.0 21 | c_high_step: float = 0.01 22 | c_high_search_iters: int = 200 23 | poly_degree: int = 2 # Polynomial augmentation degree (0=none, 1=linear, 2=quadratic) 24 | kernel: str = "imq" # RBF kernel type: 'imq', 'gaussian', 'polyharmonic_spline' 25 | kernel_order: int = 3 # For polyharmonic splines only (order of r^k) 26 | 27 | 28 | class ModelState(NamedTuple): 29 | """Immutable trained model state - a valid JAX pytree.""" 30 | 31 | basis: Array # Truncated POD basis (n_samples, n_modes) 32 | weights: Array # RBF network weights (n_modes, n_train_points) 33 | shape_factor: float | None # Optimized RBF shape parameter (None for PHS) 34 | train_params: Array # Training parameters (n_params, n_train_points) 35 | params_range: Array # Parameter ranges for normalization (n_params,) 36 | truncated_energy: float # Energy retained after truncation 37 | cumul_energy: Array # Cumulative energy per mode 38 | poly_coeffs: Array | None # Polynomial coefficients (n_modes, n_poly) or None 39 | poly_degree: int # Polynomial degree used (0=none) 40 | kernel: str # Kernel type used for training 41 | kernel_order: int # PHS order (ignored for other kernels) 42 | 43 | 44 | class TrainResult(NamedTuple): 45 | """Result from training, includes diagnostics.""" 46 | 47 | state: ModelState 48 | n_modes: int 49 | used_eig_decomp: bool # True if eigendecomposition was used 50 | -------------------------------------------------------------------------------- /docs/examples.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | ## Jupyter Notebooks 4 | 5 | Explore these example notebooks to see POD-RBF in action: 6 | 7 | ### Lid-Driven Cavity 8 | 9 | A complete walkthrough using CFD data from a 2D lid-driven cavity simulation at various Reynolds numbers. 10 | 11 | [:octicons-mark-github-16: View on GitHub](https://github.com/kylebeggs/POD-RBF/tree/master/examples/lid-driven-cavity){ .md-button } 12 | 13 | **What you'll learn:** 14 | 15 | - Building a snapshot matrix from CSV files 16 | - Training a single-parameter model 17 | - Visualizing predictions vs. ground truth 18 | 19 | ### Multi-Parameter Example 20 | 21 | Training a model with two input parameters. 22 | 23 | [:octicons-mark-github-16: View on GitHub](https://github.com/kylebeggs/POD-RBF/blob/master/examples/2-parameters.ipynb){ .md-button } 24 | 25 | **What you'll learn:** 26 | 27 | - Setting up multi-parameter training data 28 | - Inference with multiple parameters 29 | - Parameter space exploration 30 | 31 | ### Heat Conduction 32 | 33 | A simple heat conduction problem on a unit square. 34 | 35 | [:octicons-mark-github-16: View on GitHub](https://github.com/kylebeggs/POD-RBF/tree/master/examples/heat-conduction){ .md-button } 36 | 37 | **What you'll learn:** 38 | 39 | - Basic POD-RBF workflow 40 | - Working with thermal simulation data 41 | 42 | ### Shape Parameter Optimization 43 | 44 | Exploring RBF shape parameter selection. 
45 | 46 | [:octicons-mark-github-16: View on GitHub](https://github.com/kylebeggs/POD-RBF/tree/master/examples/shape-optimization){ .md-button } 47 | 48 | **What you'll learn:** 49 | 50 | - How shape parameters affect interpolation 51 | - Automatic vs. manual shape parameter selection 52 | 53 | ## Running the Examples 54 | 55 | Clone the repository and install the package: 56 | 57 | ```bash 58 | git clone https://github.com/kylebeggs/POD-RBF.git 59 | cd POD-RBF 60 | pip install -e . 61 | ``` 62 | 63 | Then open the Jupyter notebooks in the `examples/` directory: 64 | 65 | ```bash 66 | jupyter notebook examples/ 67 | ``` 68 | -------------------------------------------------------------------------------- /docs/getting-started.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | 3 | ## Installation 4 | 5 | Install POD-RBF using pip: 6 | 7 | ```bash 8 | pip install pod-rbf 9 | ``` 10 | 11 | Or using uv: 12 | 13 | ```bash 14 | uv add pod-rbf 15 | ``` 16 | 17 | ## Basic Workflow 18 | 19 | POD-RBF follows a simple three-step workflow: 20 | 21 | 1. **Build a snapshot matrix** - Collect solution data at different parameter values 22 | 2. **Train the model** - Compute POD basis and RBF interpolation weights 23 | 3. **Inference** - Predict solutions at new parameter values 24 | 25 | ## Minimal Example 26 | 27 | ```python 28 | import pod_rbf 29 | import jax.numpy as jnp 30 | import numpy as np 31 | 32 | # 1. Define training parameters 33 | params = np.array([1, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]) 34 | 35 | # 2. Build snapshot matrix from CSV files 36 | # Each CSV file contains one solution snapshot 37 | snapshot = pod_rbf.build_snapshot_matrix("path/to/data/") 38 | 39 | # 3. Train the model 40 | result = pod_rbf.train(snapshot, params) 41 | 42 | # 4. Predict at a new parameter value 43 | prediction = pod_rbf.inference_single(result.state, jnp.array(450.0)) 44 | ``` 45 | 46 | ## Understanding the Snapshot Matrix 47 | 48 | The snapshot matrix `X` has shape `(n_samples, n_snapshots)`: 49 | 50 | - Each **column** is one solution snapshot at a specific parameter value 51 | - Each **row** corresponds to a spatial location or degree of freedom 52 | - `n_samples` is the number of points in your solution (e.g., mesh nodes) 53 | - `n_snapshots` is the number of parameter values you trained on 54 | 55 | For example, if you solve a problem on a 400-node mesh at 10 different parameter values, your snapshot matrix is `(400, 10)`. 56 | 57 | !!! note "Parameter Ordering" 58 | The order of columns in the snapshot matrix must match the order of your parameter array. If column 5 contains the solution at Re=500, then `params[4]` must equal 500. 59 | 60 | ## What's Next? 61 | 62 | - [Training Models](user-guide/training.md) - Learn about training configuration options 63 | - [Inference](user-guide/inference.md) - Single-point and batch predictions 64 | - [Autodifferentiation](user-guide/autodiff.md) - Use JAX for gradients and optimization 65 | -------------------------------------------------------------------------------- /docs/user-guide/inference.md: -------------------------------------------------------------------------------- 1 | # Inference 2 | 3 | After training, use the model to predict solutions at new parameter values. 
4 | 5 | ## Single-Point Inference 6 | 7 | For predicting at a single parameter value: 8 | 9 | ```python 10 | import pod_rbf 11 | import jax.numpy as jnp 12 | 13 | # Train the model 14 | result = pod_rbf.train(snapshot, params) 15 | 16 | # Predict at a new parameter 17 | prediction = pod_rbf.inference_single(result.state, jnp.array(450.0)) 18 | ``` 19 | 20 | The output shape is `(n_samples,)` - the same as one column of your snapshot matrix. 21 | 22 | ## Batch Inference 23 | 24 | For predicting at multiple parameter values simultaneously: 25 | 26 | ```python 27 | # Predict at multiple parameters 28 | new_params = jnp.array([350.0, 450.0, 550.0]) 29 | predictions = pod_rbf.inference(result.state, new_params) 30 | ``` 31 | 32 | The output shape is `(n_samples, n_points)` where `n_points` is the number of parameter values. 33 | 34 | ## Multi-Parameter Inference 35 | 36 | For models trained with multiple parameters: 37 | 38 | ```python 39 | # Single point with 2 parameters 40 | param = jnp.array([450.0, 0.15]) # [Re, Ma] 41 | prediction = pod_rbf.inference_single(result.state, param) 42 | 43 | # Batch with 2 parameters 44 | params = jnp.array([ 45 | [350.0, 0.1], 46 | [450.0, 0.15], 47 | [550.0, 0.2], 48 | ]) 49 | predictions = pod_rbf.inference(result.state, params) 50 | ``` 51 | 52 | ## Using a Saved Model 53 | 54 | Load a previously saved model and use it for inference: 55 | 56 | ```python 57 | state = pod_rbf.load_model("model.pkl") 58 | prediction = pod_rbf.inference_single(state, jnp.array(450.0)) 59 | ``` 60 | 61 | ## Performance Tips 62 | 63 | 1. **Use batch inference** when predicting at multiple parameter values - it's more efficient than calling `inference_single` in a loop. 64 | 65 | 2. **JIT compilation** - The inference functions are JAX-compatible and can be JIT-compiled for faster repeated calls: 66 | 67 | ```python 68 | import jax 69 | 70 | inference_jit = jax.jit(lambda p: pod_rbf.inference_single(state, p)) 71 | prediction = inference_jit(jnp.array(450.0)) 72 | ``` 73 | 74 | 3. **GPU acceleration** - If JAX is configured with GPU support, inference will automatically use the GPU. 
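Putting tips 1 and 2 together, here is a minimal sketch of replacing a Python loop with one batched call (assuming a trained `state` from a single-parameter model, as in the sections above):

```python
import jax.numpy as jnp
import pod_rbf

new_params = jnp.linspace(100.0, 1000.0, 50)

# Slow: one inference call per parameter value
slow = jnp.stack(
    [pod_rbf.inference_single(state, p) for p in new_params], axis=1
)

# Fast: a single batched call; output shape (n_samples, 50)
fast = pod_rbf.inference(state, new_params)
```

Both produce the same `(n_samples, 50)` array; the batched call is the efficient route from tip 1, and it can be JIT-compiled exactly as in tip 2.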
75 | -------------------------------------------------------------------------------- /docs/user-guide/io.md: -------------------------------------------------------------------------------- 1 | # Saving & Loading 2 | 3 | ## Loading Snapshot Data 4 | 5 | ### From CSV Files 6 | 7 | Load snapshots from a directory of CSV files: 8 | 9 | ```python 10 | import pod_rbf 11 | 12 | snapshot = pod_rbf.build_snapshot_matrix("path/to/data/") 13 | ``` 14 | 15 | By default, this: 16 | 17 | - Loads all CSV files from the directory in alphanumeric order 18 | - Skips the first row (header) 19 | - Uses the first column 20 | 21 | Customize with optional parameters: 22 | 23 | ```python 24 | snapshot = pod_rbf.build_snapshot_matrix( 25 | "path/to/data/", 26 | skiprows=1, # Skip first N rows (default: 1 for header) 27 | usecols=0, # Column index to use (default: 0) 28 | verbose=True, # Show progress bar (default: True) 29 | ) 30 | ``` 31 | 32 | ### From NumPy Arrays 33 | 34 | If your data is already in memory: 35 | 36 | ```python 37 | import numpy as np 38 | 39 | # Combine individual solutions into a snapshot matrix 40 | # Each column is one snapshot 41 | snapshot = np.column_stack([sol1, sol2, sol3, sol4, sol5]) 42 | ``` 43 | 44 | ## Saving Models 45 | 46 | Save a trained model to disk: 47 | 48 | ```python 49 | import pod_rbf 50 | 51 | result = pod_rbf.train(snapshot, params) 52 | pod_rbf.save_model("model.pkl", result.state) 53 | ``` 54 | 55 | The model is saved as a pickle file containing the `ModelState` NamedTuple. 56 | 57 | ## Loading Models 58 | 59 | Load a previously saved model: 60 | 61 | ```python 62 | state = pod_rbf.load_model("model.pkl") 63 | 64 | # Use for inference 65 | prediction = pod_rbf.inference_single(state, jnp.array(450.0)) 66 | ``` 67 | 68 | ## Model State Contents 69 | 70 | The saved `ModelState` contains everything needed for inference: 71 | 72 | | Field | Description | 73 | |-------|-------------| 74 | | `basis` | Truncated POD basis matrix | 75 | | `weights` | RBF interpolation weights | 76 | | `shape_factor` | Optimized RBF shape parameter | 77 | | `train_params` | Training parameter values | 78 | | `params_range` | Parameter ranges for normalization | 79 | | `truncated_energy` | Energy retained after truncation | 80 | | `cumul_energy` | Cumulative energy per mode | 81 | | `poly_coeffs` | Polynomial coefficients (if used) | 82 | | `poly_degree` | Polynomial degree used | 83 | | `kernel` | Kernel type used | 84 | | `kernel_order` | PHS order (for polyharmonic splines) | 85 | 86 | ## File Format Notes 87 | 88 | - Models are saved using Python's `pickle` module 89 | - Files are portable across machines with the same Python/JAX versions 90 | - File size depends on the number of modes and training points 91 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # POD-RBF 2 | 3 | [![Tests](https://github.com/kylebeggs/POD-RBF/actions/workflows/tests.yml/badge.svg)](https://github.com/kylebeggs/POD-RBF/actions/workflows/tests.yml) 4 | [![codecov](https://codecov.io/gh/kylebeggs/POD-RBF/branch/master/graph/badge.svg)](https://codecov.io/gh/kylebeggs/POD-RBF) 5 | [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/) 6 | 7 | A Python package for building Reduced Order Models (ROMs) from high-dimensional data using Proper Orthogonal Decomposition combined with Radial Basis Function interpolation. 
8 | 9 | ![Lid-driven cavity results](https://raw.githubusercontent.com/kylebeggs/POD-RBF/master/examples/lid-driven-cavity/results-re-450.png) 10 | 11 | ## Features 12 | 13 | - **JAX-based** - Enables autodifferentiation for gradient optimization, sensitivity analysis, and inverse problems 14 | - **Shape parameter optimization** - Automatic tuning of RBF shape parameters 15 | - **Memory-aware algorithms** - Switches between eigenvalue decomposition and SVD based on memory requirements 16 | 17 | ## Quick Install 18 | 19 | ```bash 20 | pip install pod-rbf 21 | ``` 22 | 23 | ## Quick Example 24 | 25 | ```python 26 | import pod_rbf 27 | import jax.numpy as jnp 28 | import numpy as np 29 | 30 | # Define training parameters 31 | Re = np.array([1, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]) 32 | 33 | # Build snapshot matrix from CSV files 34 | train_snapshot = pod_rbf.build_snapshot_matrix("data/train/") 35 | 36 | # Train the model 37 | result = pod_rbf.train(train_snapshot, Re) 38 | 39 | # Inference on unseen parameter 40 | sol = pod_rbf.inference_single(result.state, jnp.array(450.0)) 41 | ``` 42 | 43 | ## Next Steps 44 | 45 | - [Getting Started](getting-started.md) - Installation and first steps 46 | - [User Guide](user-guide/training.md) - Detailed usage instructions 47 | - [API Reference](api/index.md) - Complete API documentation 48 | - [Examples](examples.md) - Jupyter notebook examples 49 | 50 | ## References 51 | 52 | This implementation is based on the following papers: 53 | 54 | 1. [Solving inverse heat conduction problems using trained POD-RBF network inverse method](https://www.tandfonline.com/doi/full/10.1080/17415970701198290) - Ostrowski, Bialecki, Kassab (2008) 55 | 2. [RBF-trained POD-accelerated CFD analysis of wind loads on PV systems](https://www.emerald.com/insight/content/doi/10.1108/HFF-03-2016-0083/full/html) - Huayamave et al. (2017) 56 | 3. [Real-Time Thermomechanical Modeling of PV Cell Fabrication via a POD-Trained RBF Interpolation Network](https://www.techscience.com/CMES/v122n3/38374) - Das et al. 
(2020) 57 | -------------------------------------------------------------------------------- /examples/lid-driven-cavity/lid-driven-cavity.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": "import numpy as np\nimport matplotlib.pyplot as plt\nimport pod_rbf\n\nprint(\"using version: {}\".format(pod_rbf.__version__))\n\nRe = np.linspace(0, 1000, num=11)\nRe[0] = 1\n\ncoords_path = \"data/train/re-0001.csv\"\nx, y = np.loadtxt(\n coords_path,\n delimiter=\",\",\n skiprows=1,\n usecols=(1, 2),\n unpack=True,\n)\n\n# make snapshot matrix from csv files\ntrain_snapshot = pod_rbf.build_snapshot_matrix(\"data/train\")" 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": "import jax.numpy as jnp\n\n# load validation\nval = np.loadtxt(\n \"data/validation/re-0450.csv\",\n # \"data/validation/re-0050.csv\",\n delimiter=\",\",\n skiprows=1,\n usecols=(0),\n unpack=True,\n)\n\n# train the model\nconfig = pod_rbf.TrainConfig(energy_threshold=0.9, poly_degree=2) # poly_degree: 0=none, 1=linear, 2=quadratic\nresult = pod_rbf.train(train_snapshot, Re, config)\nstate = result.state\nprint(\"Energy kept after truncating = {}%\".format(state.truncated_energy))\n\n# plot the energy decay\nplt.plot(state.cumul_energy)\n\n\n# inference the model on an unseen parameter\nsol = pod_rbf.inference_single(state, jnp.array(450.0))\n\n# calculate and plot the difference between inference and actual\ndiff = np.nan_to_num(np.abs(sol - val))\nprint(\"Average Percent Error = {}\".format(np.mean(diff)))\n\n# plot the inferenced solution\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(22, 9))\nax1.set_title(\"POD-RBF\", fontsize=40)\ncntr1 = ax1.tricontourf(\n x, y, sol, levels=np.linspace(0, 1, num=20), cmap=\"viridis\", extend=\"both\"\n)\nax1.set_xticks([])\nax1.set_yticks([])\n# plot the actual solution\nax2.set_title(\"Target\", fontsize=40)\ncntr2 = ax2.tricontourf(\n x, y, val, levels=np.linspace(0, 1, num=20), cmap=\"viridis\", extend=\"both\"\n)\ncbar_ax = fig.add_axes([0.485, 0.15, 0.025, 0.7])\nfig.colorbar(cntr2, cax=cbar_ax)\nax2.set_xticks([])\nax2.set_yticks([])\n# fig.tight_layout()\n\nfig2, ax = plt.subplots(1, 1, figsize=(12, 9))\ncntr = ax.tricontourf(x, y, diff, cmap=\"viridis\", extend=\"both\")\nfig2.colorbar(cntr)\nplt.show()" 16 | } 17 | ], 18 | "metadata": { 19 | "interpreter": { 20 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" 21 | }, 22 | "kernelspec": { 23 | "display_name": "Python 3.8.5 64-bit", 24 | "name": "python3" 25 | }, 26 | "language_info": { 27 | "name": "python", 28 | "version": "" 29 | }, 30 | "metadata": { 31 | "interpreter": { 32 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" 33 | } 34 | }, 35 | "orig_nbformat": 2 36 | }, 37 | "nbformat": 4, 38 | "nbformat_minor": 2 39 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # VS Code 2 | *.code-workspace 3 | .vscode 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | 
sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # uv 102 | # For libraries, uv.lock should not be committed as it locks dependencies for development only. 103 | # Library users will resolve dependencies based on pyproject.toml. 104 | uv.lock 105 | 106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | # pytype static type analyzer 144 | .pytype/ 145 | 146 | # Cython debug symbols 147 | cython_debug/ 148 | -------------------------------------------------------------------------------- /examples/heat-conduction/2d-heat.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import time 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def buildSnapshotMatrix(params, num_points): 8 | """ 9 | Assemble the snapshot matrix 10 | """ 11 | print("making the snapshot matrix... 
", end="") 12 | start = time.time() 13 | 14 | # evaluate the analytical solution 15 | num_terms = 50 16 | L = 1 17 | T_L = params[0, :] 18 | 19 | n = np.arange(0, num_terms) 20 | # calculate lambdas 21 | lambs = np.pi * (2 * n + 1) / (2 * L) 22 | 23 | # define points 24 | x = np.linspace(0, L, num=num_points) 25 | X, Y = np.meshgrid(x, x, indexing="xy") 26 | 27 | snapshot = np.zeros((num_points ** 2, len(T_L))) 28 | for i in range(len(T_L)): 29 | # calculate constants 30 | C = ( 31 | 8 32 | * T_L[i] 33 | * (2 * (-1) ** n / (lambs * L) - 1) 34 | / ((lambs * L) ** 2 * np.cosh(lambs * L)) 35 | ) 36 | T = np.zeros_like(X) 37 | for j in range(0, num_terms): 38 | T = T + C[j] * np.cosh(lambs[j] * X) * np.cos(lambs[j] * Y) 39 | snapshot[:, i] = T.flatten() 40 | 41 | print("took {:3.3f} sec".format(time.time() - start)) 42 | 43 | return snapshot 44 | 45 | 46 | if __name__ == "__main__": 47 | 48 | import jax 49 | import jax.numpy as jnp 50 | import pod_rbf 51 | 52 | jax.config.update('jax_default_device', jax.devices('cpu')[0]) # Change to 'gpu' or 'tpu' for accelerators 53 | 54 | T_L = np.linspace(1, 100, num=11) 55 | T_L = np.expand_dims(T_L, axis=0) 56 | T_L_test = 55.0 57 | num_points = 41 58 | 59 | # make snapshot matrix 60 | snapshot = buildSnapshotMatrix(T_L, num_points) 61 | 62 | # calculate 'test' solution 63 | # evaluate the analytical solution 64 | num_terms = 50 65 | L = 1 66 | n = np.arange(0, num_terms) 67 | # calculate lambdas 68 | lambs = np.pi * (2 * n + 1) / (2 * L) 69 | # define points 70 | x = np.linspace(0, L, num=num_points) 71 | X, Y = np.meshgrid(x, x, indexing="xy") 72 | # calculate constants 73 | C = ( 74 | 8 75 | * T_L_test 76 | * (2 * (-1) ** n / (lambs * L) - 1) 77 | / ((lambs * L) ** 2 * np.cosh(lambs * L)) 78 | ) 79 | T_test = np.zeros_like(X) 80 | for n in range(0, num_terms): 81 | T_test = T_test + C[n] * np.cosh(lambs[n] * X) * np.cos(lambs[n] * Y) 82 | 83 | # train the POD-RBF model 84 | config = pod_rbf.TrainConfig(energy_threshold=0.5, poly_degree=2) 85 | result = pod_rbf.train(snapshot, T_L, config) 86 | state = result.state 87 | 88 | # inference the trained model 89 | sol = pod_rbf.inference_single(state, jnp.array(T_L_test)) 90 | 91 | print("Energy kept after truncating = {}%".format(state.truncated_energy)) 92 | print("Cumulative Energy = {}%".format(state.cumul_energy)) 93 | 94 | fig = plt.figure(figsize=(12, 9)) 95 | c = plt.pcolormesh(T_test, cmap="magma") 96 | fig.colorbar(c) 97 | 98 | fig = plt.figure(figsize=(12, 9)) 99 | c = plt.pcolormesh(sol.reshape((num_points, num_points)), cmap="magma") 100 | fig.colorbar(c) 101 | 102 | fig = plt.figure(figsize=(12, 9)) 103 | diff = np.abs(sol.reshape((num_points, num_points)) - T_test) / T_test * 100 104 | c = plt.pcolormesh(diff, cmap="magma") 105 | fig.colorbar(c) 106 | 107 | plt.show() 108 | -------------------------------------------------------------------------------- /docs/user-guide/autodiff.md: -------------------------------------------------------------------------------- 1 | # Autodifferentiation 2 | 3 | POD-RBF is built on JAX, enabling automatic differentiation through the inference functions. This is useful for optimization, sensitivity analysis, and inverse problems. 
4 | 
5 | ## Computing Gradients
6 | 
7 | Use `jax.grad` to compute gradients with respect to parameters:
8 | 
9 | ```python
10 | import jax
11 | import jax.numpy as jnp
12 | import pod_rbf
13 | 
14 | # Train model
15 | result = pod_rbf.train(snapshot, params)
16 | state = result.state
17 | 
18 | # Define an objective function
19 | def objective(param):
20 |     prediction = pod_rbf.inference_single(state, param)
21 |     return jnp.sum(prediction ** 2)
22 | 
23 | # Compute gradient
24 | grad_fn = jax.grad(objective)
25 | gradient = grad_fn(jnp.array(450.0))
26 | ```
27 | 
28 | ## Optimization Example
29 | 
30 | Find the parameter value that minimizes a cost function:
31 | 
32 | ```python
33 | import jax
34 | import jax.numpy as jnp
35 | from jax import grad
36 | 
37 | def cost_function(param, target):
38 |     prediction = pod_rbf.inference_single(state, param)
39 |     return jnp.mean((prediction - target) ** 2)
40 | 
41 | # Gradient descent
42 | param = jnp.array(500.0)  # Initial guess
43 | learning_rate = 10.0
44 | 
45 | for i in range(100):
46 |     grad_val = grad(cost_function)(param, target_solution)
47 |     param = param - learning_rate * grad_val
48 | 
49 | print(f"Optimal parameter: {param}")
50 | ```
51 | 
52 | ## Inverse Problems
53 | 
54 | For inverse problems where you want to find the parameter that produced an observed solution:
55 | 
56 | ```python
57 | import jax
58 | import jax.numpy as jnp
59 | from jax.scipy.optimize import minimize
60 | 
61 | def inverse_objective(param):
62 |     prediction = pod_rbf.inference_single(state, param)
63 |     return jnp.sum((prediction - observed_solution) ** 2)
64 | 
65 | # Use BFGS optimization (note: minimize requires a 1-D x0 of shape (n,))
66 | result = minimize(
67 |     inverse_objective,
68 |     x0=jnp.array([500.0]),
69 |     method="BFGS",
70 | )
71 | 
72 | recovered_param = result.x[0]  # result.x has shape (1,)
73 | ```
74 | 
75 | ## Sensitivity Analysis
76 | 
77 | Compute how sensitive the solution is to parameter changes:
78 | 
79 | ```python
80 | import jax
81 | import jax.numpy as jnp
82 | 
83 | # Jacobian: how each output point changes with the parameter
84 | jacobian_fn = jax.jacobian(
85 |     lambda p: pod_rbf.inference_single(state, p)
86 | )
87 | sensitivity = jacobian_fn(jnp.array(450.0))
88 | 
89 | # sensitivity shape: (n_samples,) for single parameter
90 | # Positive values indicate the solution increases with the parameter
91 | ```
92 | 
93 | ## Multi-Parameter Gradients
94 | 
95 | For models with multiple parameters:
96 | 
97 | ```python
98 | def objective(params):
99 |     # params: [Re, Ma]
100 |     prediction = pod_rbf.inference_single(state, params)
101 |     return jnp.sum(prediction ** 2)
102 | 
103 | # Gradient with respect to all parameters
104 | grad_fn = jax.grad(objective)
105 | gradients = grad_fn(jnp.array([450.0, 0.15]))
106 | # gradients shape: (2,) - one gradient per parameter
107 | ```
108 | 
109 | ## JIT Compilation
110 | 
111 | For performance, JIT-compile your gradient functions:
112 | 
113 | ```python
114 | @jax.jit
115 | def compute_gradient(param):
116 |     return jax.grad(objective)(param)
117 | 
118 | # First call compiles; subsequent calls are fast
119 | gradient = compute_gradient(jnp.array(450.0))
120 | ```
121 | 
122 | ## Higher-Order Derivatives
123 | 
124 | JAX supports higher-order derivatives:
125 | 
126 | ```python
127 | # Second derivative (Hessian for scalar output)
128 | hessian_fn = jax.hessian(objective)
129 | hessian = hessian_fn(jnp.array(450.0))
130 | ```
131 | 
--------------------------------------------------------------------------------
/CLAUDE.md:
--------------------------------------------------------------------------------
1 | # 
CLAUDE.md
2 | 
3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4 | 
5 | ## Project Overview
6 | 
7 | POD-RBF is a JAX-based Python library for building Reduced Order Models (ROMs) using Proper Orthogonal Decomposition combined with Radial Basis Function interpolation. It enables autodifferentiation for gradient optimization, sensitivity analysis, and inverse problems.
8 | 
9 | ## Development Commands
10 | 
11 | ```bash
12 | # Install for development
13 | pip install -e ".[dev]"
14 | 
15 | # Run tests
16 | pytest tests/ -v
17 | 
18 | # Run single test file
19 | pytest tests/test_core.py -v
20 | 
21 | # Run specific test
22 | pytest tests/test_core.py::TestGradients::test_inverse_problem -v
23 | ```
24 | 
25 | ## Architecture
26 | 
27 | ### Module Structure
28 | 
29 | ```
30 | pod_rbf/
31 |     __init__.py            # Public API, enables float64
32 |     types.py               # ModelState, TrainConfig, TrainResult (NamedTuples)
33 |     core.py                # train(), inference(), inference_single()
34 |     rbf.py                 # build_collocation_matrix(), build_inference_matrix() (kernel functions live in kernels.py)
35 |     decomposition.py       # compute_pod_basis_svd(), compute_pod_basis_eig()
36 |     shape_optimization.py  # find_optimal_shape_param() (fixed-iteration bisection)
37 |     io.py                  # build_snapshot_matrix(), save_model(), load_model()
38 | ```
39 | 
40 | ### Key Types
41 | 
42 | ```python
43 | class ModelState(NamedTuple):
44 |     basis: Array                # (n_samples, n_modes)
45 |     weights: Array              # (n_modes, n_train_points)
46 |     shape_factor: float | None  # None for polyharmonic splines
47 |     train_params: Array         # (n_params, n_train_points)
48 |     params_range: Array         # (n_params,)
49 |     truncated_energy: float
50 |     cumul_energy: Array         # plus poly_coeffs, poly_degree, kernel, kernel_order; see types.py
51 | 
52 | class TrainConfig(NamedTuple):
53 |     energy_threshold: float = 0.99
54 |     mem_limit_gb: float = 16.0
55 |     cond_range: tuple = (1e11, 1e12)
56 |     max_bisection_iters: int = 50  # plus kernel, poly_degree, and c_* options; see types.py
57 | ```
58 | 
59 | ### API
60 | 
61 | ```python
62 | import pod_rbf
63 | import jax
64 | import jax.numpy as jnp
65 | 
66 | # Train with default config (energy_threshold=0.99)
67 | result = pod_rbf.train(snapshot, params)
68 | state = result.state
69 | 
70 | # Train with custom config
71 | config = pod_rbf.TrainConfig(energy_threshold=0.9)
72 | result = pod_rbf.train(snapshot, params, config)
73 | 
74 | # Inference (single point)
75 | pred = pod_rbf.inference_single(state, jnp.array(450.0))
76 | 
77 | # Inference (batch)
78 | preds = pod_rbf.inference(state, jnp.array([400.0, 450.0, 500.0]))
79 | 
80 | # Autodiff
81 | grad_fn = jax.grad(lambda p: jnp.sum(pod_rbf.inference_single(state, p)**2))
82 | gradient = grad_fn(jnp.array(450.0))
83 | 
84 | # I/O
85 | snapshot = pod_rbf.build_snapshot_matrix("data/train/")  # load CSVs from directory
86 | pod_rbf.save_model("model.pkl", state)
87 | state = pod_rbf.load_model("model.pkl")
88 | ```
89 | 
90 | ### Data Shape Conventions
91 | 
92 | - **Snapshot matrix**: `(n_samples, n_snapshots)` - each column is one parameter's solution
93 | - **Parameters**: 1D `(n_snapshots,)` or 2D `(n_params, n_snapshots)`
94 | - **Inference output**: `(n_samples,)` for single, `(n_samples, n_points)` for batch
95 | 
96 | ### Key Algorithms
97 | 
98 | 1. **POD truncation**: Keeps modes until cumulative energy exceeds `energy_threshold`
99 | 2. **RBF kernel**: Hardy Inverse Multi-Quadrics (default): `1/√(r²/c² + 1)`; Gaussian and polyharmonic spline kernels are also supported
100 | 3. 
**Shape optimization**: Fixed-iteration bisection (50 iters) for condition number in [10^11, 10^12] 101 | 102 | ## Dependencies 103 | 104 | - `jax>=0.4.0`, `jaxlib>=0.4.0` - Autodiff and JIT compilation 105 | - `numpy` - File I/O operations 106 | - `tqdm` - Progress bars 107 | - Python ≥ 3.10 108 | -------------------------------------------------------------------------------- /docs/user-guide/training.md: -------------------------------------------------------------------------------- 1 | # Training Models 2 | 3 | ## Building the Snapshot Matrix 4 | 5 | The snapshot matrix contains your training data. Each column is a solution snapshot at a specific parameter value. 6 | 7 | ### From CSV Files 8 | 9 | If your snapshots are stored as individual CSV files in a directory: 10 | 11 | ```python 12 | import pod_rbf 13 | 14 | snapshot = pod_rbf.build_snapshot_matrix("path/to/data/") 15 | ``` 16 | 17 | Files are loaded in alphanumeric order. The function expects one value per row in each CSV file. 18 | 19 | !!! tip "File Organization" 20 | Keep all training snapshots in a dedicated directory. Files are sorted alphanumerically, so use consistent naming like `snapshot_001.csv`, `snapshot_002.csv`, etc. 21 | 22 | ### From Arrays 23 | 24 | If you already have your data in memory: 25 | 26 | ```python 27 | import numpy as np 28 | 29 | # Shape: (n_samples, n_snapshots) 30 | snapshot = np.column_stack([solution1, solution2, solution3, ...]) 31 | ``` 32 | 33 | ## Training 34 | 35 | ### Basic Training 36 | 37 | ```python 38 | import pod_rbf 39 | import numpy as np 40 | 41 | params = np.array([100, 200, 300, 400, 500]) 42 | result = pod_rbf.train(snapshot, params) 43 | ``` 44 | 45 | The `result` object contains: 46 | 47 | - `result.state` - The trained model state (use this for inference) 48 | - `result.n_modes` - Number of POD modes retained 49 | - `result.used_eig_decomp` - Whether eigendecomposition was used (vs SVD) 50 | 51 | ### Training Configuration 52 | 53 | Customize training with `TrainConfig`: 54 | 55 | ```python 56 | from pod_rbf import TrainConfig 57 | 58 | config = TrainConfig( 59 | energy_threshold=0.99, # Keep modes until 99% energy retained 60 | kernel="imq", # RBF kernel: 'imq', 'gaussian', 'polyharmonic_spline' 61 | poly_degree=2, # Polynomial augmentation: 0=none, 1=linear, 2=quadratic 62 | ) 63 | 64 | result = pod_rbf.train(snapshot, params, config) 65 | ``` 66 | 67 | ### Configuration Options 68 | 69 | | Parameter | Default | Description | 70 | |-----------|---------|-------------| 71 | | `energy_threshold` | 0.99 | POD truncation threshold (0-1) | 72 | | `kernel` | `"imq"` | RBF kernel type | 73 | | `poly_degree` | 2 | Polynomial augmentation degree | 74 | | `mem_limit_gb` | 16.0 | Memory limit for algorithm selection | 75 | | `cond_range` | (1e11, 1e12) | Target condition number range | 76 | | `max_bisection_iters` | 50 | Max iterations for shape optimization | 77 | 78 | ### Kernel Options 79 | 80 | POD-RBF supports three RBF kernels: 81 | 82 | - **`imq`** (Inverse Multi-Quadrics) - Default, good general-purpose choice 83 | - **`gaussian`** - Smoother interpolation, requires careful shape parameter tuning 84 | - **`polyharmonic_spline`** - No shape parameter needed, use with `kernel_order` 85 | 86 | ```python 87 | # Using polyharmonic splines (no shape parameter optimization) 88 | config = TrainConfig(kernel="polyharmonic_spline", kernel_order=3) 89 | result = pod_rbf.train(snapshot, params, config) 90 | ``` 91 | 92 | ## Multi-Parameter Training 93 | 94 | For problems with multiple 
parameters: 95 | 96 | ```python 97 | import numpy as np 98 | 99 | # Parameters shape: (n_params, n_snapshots) 100 | params = np.array([ 101 | [100, 200, 300, 400], # Parameter 1 (e.g., Reynolds number) 102 | [0.1, 0.1, 0.2, 0.2], # Parameter 2 (e.g., Mach number) 103 | ]) 104 | 105 | result = pod_rbf.train(snapshot, params) 106 | ``` 107 | 108 | See the [2-parameter example](https://github.com/kylebeggs/POD-RBF/blob/master/examples/2-parameters.ipynb) for a complete walkthrough. 109 | 110 | ## Manual Shape Parameter 111 | 112 | If you want to specify the RBF shape parameter instead of using automatic optimization: 113 | 114 | ```python 115 | result = pod_rbf.train(snapshot, params, shape_factor=0.5) 116 | ``` 117 | 118 | ## Understanding POD Truncation 119 | 120 | POD extracts the dominant modes from your snapshot data. The `energy_threshold` controls how many modes are kept: 121 | 122 | - `0.99` (default) - Keep modes until 99% of total energy is captured 123 | - Higher values retain more modes (more accurate, slower inference) 124 | - Lower values retain fewer modes (faster inference, may lose accuracy) 125 | 126 | After training, check how much energy was retained: 127 | 128 | ```python 129 | print(f"Modes retained: {result.n_modes}") 130 | print(f"Energy retained: {result.state.truncated_energy:.4f}") 131 | ``` 132 | -------------------------------------------------------------------------------- /pod_rbf/kernels.py: -------------------------------------------------------------------------------- 1 | """ 2 | RBF kernel functions and dispatcher. 3 | 4 | Supports multiple kernel types: 5 | - Inverse Multi-Quadrics (IMQ): phi(r) = 1 / sqrt(r²/c² + 1) 6 | - Gaussian: phi(r) = exp(-r²/c²) 7 | - Polyharmonic Splines (PHS): phi(r) = r^k or r^k*log(r) 8 | """ 9 | 10 | from enum import Enum 11 | 12 | import jax.numpy as jnp 13 | from jax import Array 14 | 15 | 16 | class KernelType(str, Enum): 17 | """RBF kernel types (internal use only).""" 18 | 19 | IMQ = "imq" 20 | GAUSSIAN = "gaussian" 21 | POLYHARMONIC_SPLINE = "polyharmonic_spline" 22 | 23 | 24 | def kernel_imq(r2: Array, shape_factor: float) -> Array: 25 | """ 26 | Inverse Multiquadrics kernel. 27 | 28 | phi(r) = 1 / sqrt(r²/c² + 1) 29 | 30 | Parameters 31 | ---------- 32 | r2 : Array 33 | Squared distances between points. 34 | shape_factor : float 35 | Shape parameter c. 36 | 37 | Returns 38 | ------- 39 | Array 40 | Kernel values, same shape as r2. 41 | """ 42 | return 1.0 / jnp.sqrt(r2 / (shape_factor**2) + 1.0) 43 | 44 | 45 | def kernel_gaussian(r2: Array, shape_factor: float) -> Array: 46 | """ 47 | Gaussian kernel. 48 | 49 | phi(r) = exp(-r²/c²) 50 | 51 | Parameters 52 | ---------- 53 | r2 : Array 54 | Squared distances between points. 55 | shape_factor : float 56 | Shape parameter c. 57 | 58 | Returns 59 | ------- 60 | Array 61 | Kernel values, same shape as r2. 62 | """ 63 | return jnp.exp(-r2 / (shape_factor**2)) 64 | 65 | 66 | def kernel_polyharmonic_spline(r2: Array, order: int) -> Array: 67 | """ 68 | Polyharmonic spline kernel. 69 | 70 | - Odd order k: phi(r) = r^k 71 | - Even order k: phi(r) = r^k * log(r) 72 | 73 | Parameters 74 | ---------- 75 | r2 : Array 76 | Squared distances between points. 77 | order : int 78 | Polynomial order (typically 1-5). 79 | Odd: r, r³, r⁵ 80 | Even: r²log(r), r⁴log(r) 81 | 82 | Returns 83 | ------- 84 | Array 85 | Kernel values, same shape as r2. 86 | 87 | Notes 88 | ----- 89 | For even orders, handles r=0 case to avoid log(0) singularity. 
90 | Uses r2-based formulation to avoid gradient singularity from sqrt at r2=0. 91 | """ 92 | if order % 2 == 1: 93 | # Odd order: r^k = (r2)^(k/2) 94 | # Using power of r2 directly avoids sqrt gradient singularity at r2=0 95 | return jnp.power(r2, order / 2.0) 96 | else: 97 | # Even order: r^k * log(r) = (r2)^(k/2) * log(sqrt(r2)) 98 | # = (r2)^(k/2) * (1/2) * log(r2) 99 | # Handle r2=0 case: set to 0 when r2 < threshold 100 | return jnp.where( 101 | r2 > 1e-30, 102 | jnp.power(r2, order / 2.0) * 0.5 * jnp.log(r2), 103 | 0.0, 104 | ) 105 | 106 | 107 | def apply_kernel( 108 | r2: Array, 109 | kernel: str, 110 | shape_factor: float | None, 111 | kernel_order: int, 112 | ) -> Array: 113 | """ 114 | Apply RBF kernel to distance matrix. 115 | 116 | Dispatcher function that selects and applies the appropriate kernel 117 | based on the kernel type string. 118 | 119 | Parameters 120 | ---------- 121 | r2 : Array 122 | Squared distances between points. 123 | kernel : str 124 | Kernel type: 'imq', 'gaussian', or 'polyharmonic_spline'. 125 | shape_factor : float | None 126 | Shape parameter for IMQ and Gaussian kernels. 127 | Ignored for polyharmonic splines. 128 | kernel_order : int 129 | Order for polyharmonic splines. 130 | Ignored for other kernels. 131 | 132 | Returns 133 | ------- 134 | Array 135 | Kernel values applied to distance matrix. 136 | 137 | Raises 138 | ------ 139 | ValueError 140 | If kernel type is not recognized. 141 | """ 142 | kernel_type = KernelType(kernel) 143 | 144 | if kernel_type == KernelType.IMQ: 145 | return kernel_imq(r2, shape_factor) 146 | elif kernel_type == KernelType.GAUSSIAN: 147 | return kernel_gaussian(r2, shape_factor) 148 | elif kernel_type == KernelType.POLYHARMONIC_SPLINE: 149 | return kernel_polyharmonic_spline(r2, kernel_order) 150 | else: 151 | raise ValueError(f"Unknown kernel type: {kernel}") 152 | 153 | 154 | # Kernel-specific defaults for shape parameter optimization 155 | KERNEL_SHAPE_DEFAULTS = { 156 | "imq": {"c_low": 0.011, "c_high": 1.0, "c_step": 0.01}, 157 | "gaussian": {"c_low": 0.1, "c_high": 10.0, "c_step": 0.1}, 158 | } 159 | -------------------------------------------------------------------------------- /pod_rbf/shape_optimization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Shape parameter optimization for RBF interpolation. 3 | 4 | Uses fixed-iteration bisection to find optimal shape parameter c such that 5 | the collocation matrix condition number falls within a target range. 6 | """ 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | from jax import Array 11 | 12 | from .kernels import KERNEL_SHAPE_DEFAULTS 13 | from .rbf import build_collocation_matrix 14 | 15 | 16 | def find_optimal_shape_param( 17 | train_params: Array, 18 | params_range: Array, 19 | kernel: str = "imq", 20 | kernel_order: int = 3, 21 | cond_range: tuple[float, float] = (1e11, 1e12), 22 | max_iters: int = 50, 23 | c_low_init: float | None = None, 24 | c_high_init: float | None = None, 25 | c_high_step: float | None = None, 26 | c_high_search_iters: int = 200, 27 | ) -> float | None: 28 | """ 29 | Find optimal RBF shape parameter via fixed-iteration bisection. 30 | 31 | Target: condition number in [cond_range[0], cond_range[1]]. 32 | 33 | Parameters 34 | ---------- 35 | train_params : Array 36 | Training parameters, shape (n_params, n_train_points). 37 | params_range : Array 38 | Range of each parameter for normalization, shape (n_params,). 
39 | kernel : str, optional 40 | Kernel type: 'imq', 'gaussian', or 'polyharmonic_spline'. 41 | Default is 'imq'. 42 | kernel_order : int, optional 43 | Order for polyharmonic splines (default 3). 44 | Ignored for other kernels. 45 | cond_range : tuple 46 | Target condition number range (lower, upper). 47 | max_iters : int 48 | Maximum bisection iterations. 49 | c_low_init : float | None, optional 50 | Initial lower bound for shape parameter. 51 | If None, uses kernel-specific default. 52 | c_high_init : float | None, optional 53 | Initial upper bound for shape parameter. 54 | If None, uses kernel-specific default. 55 | c_high_step : float | None, optional 56 | Step size for expanding upper bound search. 57 | If None, uses kernel-specific default. 58 | c_high_search_iters : int 59 | Maximum iterations for upper bound search. 60 | 61 | Returns 62 | ------- 63 | float | None 64 | Optimal shape parameter. 65 | Returns None for kernels that don't use shape parameters (e.g., PHS). 66 | """ 67 | # PHS doesn't use shape parameters 68 | if kernel == "polyharmonic_spline": 69 | return None 70 | 71 | # Use kernel-specific defaults if not provided 72 | defaults = KERNEL_SHAPE_DEFAULTS.get(kernel, KERNEL_SHAPE_DEFAULTS["imq"]) 73 | c_low_init = c_low_init or defaults["c_low"] 74 | c_high_init = c_high_init or defaults["c_high"] 75 | c_high_step = c_high_step or defaults["c_step"] 76 | 77 | cond_low, cond_high = cond_range 78 | 79 | # Step 1: Find upper bound where cond >= cond_low 80 | def search_c_high_iter(i: int, carry: tuple) -> tuple: 81 | c_high, found = carry 82 | C = build_collocation_matrix( 83 | train_params, params_range, kernel, c_high, kernel_order 84 | ) 85 | cond = jnp.linalg.cond(C) 86 | should_continue = (~found) & (cond < cond_low) 87 | new_c_high = jnp.where(should_continue, c_high + c_high_step, c_high) 88 | new_found = found | (cond >= cond_low) 89 | return (new_c_high, new_found) 90 | 91 | c_high, _ = jax.lax.fori_loop( 92 | 0, c_high_search_iters, search_c_high_iter, (c_high_init, False) 93 | ) 94 | 95 | # Step 2: Bisection to find optimal c in range 96 | def bisection_iter(i: int, carry: tuple) -> tuple: 97 | c_low_bound, c_high_bound, optim_c, found = carry 98 | 99 | mid_c = (c_low_bound + c_high_bound) / 2.0 100 | C = build_collocation_matrix( 101 | train_params, params_range, kernel, mid_c, kernel_order 102 | ) 103 | cond = jnp.linalg.cond(C) 104 | 105 | # Check if condition number is in target range 106 | in_range = (cond >= cond_low) & (cond <= cond_high) 107 | below_range = cond < cond_low 108 | 109 | # Update bounds based on condition number (only if not yet found) 110 | new_c_low = jnp.where(below_range & ~found, mid_c, c_low_bound) 111 | new_c_high = jnp.where((~below_range) & (~in_range) & ~found, mid_c, c_high_bound) 112 | new_optim_c = jnp.where(in_range & ~found, mid_c, optim_c) 113 | new_found = found | in_range 114 | 115 | return (new_c_low, new_c_high, new_optim_c, new_found) 116 | 117 | initial_guess = (c_low_init + c_high) / 2.0 118 | _, _, optim_c, _ = jax.lax.fori_loop( 119 | 0, max_iters, bisection_iter, (c_low_init, c_high, initial_guess, False) 120 | ) 121 | 122 | return optim_c 123 | -------------------------------------------------------------------------------- /pod_rbf/decomposition.py: -------------------------------------------------------------------------------- 1 | """ 2 | POD basis computation via SVD or eigendecomposition. 
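Both routes yield the same truncated basis (up to column signs): if X = U S V^T,
then (X^T X) V = V S^2, so U = X V S^-1, the method-of-snapshots identity used by
the eigendecomposition path, which only has to factorize the small
(n_snapshots x n_snapshots) covariance matrix.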
3 | """ 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | from jax import Array 8 | 9 | 10 | def compute_pod_basis_svd( 11 | snapshot: Array, 12 | energy_threshold: float, 13 | ) -> tuple[Array, Array, float]: 14 | """ 15 | Compute truncated POD basis via SVD. 16 | 17 | Use for smaller datasets (< mem_limit_gb). 18 | 19 | Parameters 20 | ---------- 21 | snapshot : Array 22 | Snapshot matrix, shape (n_samples, n_snapshots). 23 | energy_threshold : float 24 | Minimum fraction of total energy to retain (0 < threshold <= 1). 25 | 26 | Returns 27 | ------- 28 | basis : Array 29 | Truncated POD basis, shape (n_samples, n_modes). 30 | cumul_energy : Array 31 | Cumulative energy fraction per mode. 32 | truncated_energy : float 33 | Actual energy fraction retained. 34 | """ 35 | U, S, _ = jnp.linalg.svd(snapshot, full_matrices=False) 36 | 37 | cumul_energy = jnp.cumsum(S) / jnp.sum(S) 38 | 39 | # Handle energy_threshold >= 1 (keep all modes) 40 | keep_all = energy_threshold >= 1.0 41 | # Find first index where cumul_energy > threshold 42 | mask = cumul_energy > energy_threshold 43 | trunc_id = jnp.where( 44 | keep_all, 45 | len(S) - 1, 46 | jnp.where(jnp.any(mask), jnp.argmax(mask), len(S) - 1), 47 | ) 48 | 49 | truncated_energy = cumul_energy[trunc_id] 50 | 51 | # Dynamic slice to get truncated basis 52 | basis = jax.lax.dynamic_slice(U, (0, 0), (U.shape[0], trunc_id + 1)) 53 | 54 | return basis, cumul_energy, truncated_energy 55 | 56 | 57 | def compute_pod_basis_eig( 58 | snapshot: Array, 59 | energy_threshold: float, 60 | ) -> tuple[Array, Array, float]: 61 | """ 62 | Compute truncated POD basis via eigendecomposition. 63 | 64 | More memory-efficient for large datasets (>= mem_limit_gb). 65 | Computes (n_snapshots x n_snapshots) covariance instead of full SVD. 66 | 67 | Parameters 68 | ---------- 69 | snapshot : Array 70 | Snapshot matrix, shape (n_samples, n_snapshots). 71 | energy_threshold : float 72 | Minimum fraction of total energy to retain (0 < threshold <= 1). 73 | 74 | Returns 75 | ------- 76 | basis : Array 77 | Truncated POD basis, shape (n_samples, n_modes). 78 | cumul_energy : Array 79 | Cumulative energy fraction per mode. 80 | truncated_energy : float 81 | Actual energy fraction retained. 82 | """ 83 | # Covariance matrix (n_snapshots x n_snapshots) 84 | cov = snapshot.T @ snapshot 85 | eig_vals, eig_vecs = jnp.linalg.eigh(cov) 86 | 87 | # eigh returns ascending order, reverse to descending 88 | eig_vals = jnp.abs(eig_vals[::-1]) 89 | eig_vecs = eig_vecs[:, ::-1] 90 | 91 | cumul_energy = jnp.cumsum(eig_vals) / jnp.sum(eig_vals) 92 | 93 | # Handle energy_threshold >= 1 (keep all modes) 94 | keep_all = energy_threshold >= 1.0 95 | mask = cumul_energy > energy_threshold 96 | trunc_id = jnp.where( 97 | keep_all, 98 | len(eig_vals) - 1, 99 | jnp.where(jnp.any(mask), jnp.argmax(mask), len(eig_vals) - 1), 100 | ) 101 | 102 | truncated_energy = cumul_energy[trunc_id] 103 | 104 | # Truncate eigenvalues and eigenvectors 105 | eig_vals_trunc = jax.lax.dynamic_slice(eig_vals, (0,), (trunc_id + 1,)) 106 | eig_vecs_trunc = jax.lax.dynamic_slice( 107 | eig_vecs, (0, 0), (eig_vecs.shape[0], trunc_id + 1) 108 | ) 109 | 110 | # Compute POD basis from eigenvectors 111 | basis = (snapshot @ eig_vecs_trunc) / jnp.sqrt(eig_vals_trunc) 112 | 113 | return basis, cumul_energy, truncated_energy 114 | 115 | 116 | def compute_pod_basis( 117 | snapshot: Array, 118 | energy_threshold: float, 119 | use_eig: bool = False, 120 | ) -> tuple[Array, Array, float]: 121 | """ 122 | Compute truncated POD basis. 
123 | 124 | Dispatches to SVD or eigendecomposition based on use_eig flag. 125 | The flag should be determined BEFORE JIT compilation based on memory. 126 | 127 | Parameters 128 | ---------- 129 | snapshot : Array 130 | Snapshot matrix, shape (n_samples, n_snapshots). 131 | energy_threshold : float 132 | Minimum fraction of total energy to retain (0 < threshold <= 1). 133 | use_eig : bool 134 | If True, use eigendecomposition (memory efficient for large data). 135 | If False, use SVD (faster for smaller data). 136 | 137 | Returns 138 | ------- 139 | basis : Array 140 | Truncated POD basis, shape (n_samples, n_modes). 141 | cumul_energy : Array 142 | Cumulative energy fraction per mode. 143 | truncated_energy : float 144 | Actual energy fraction retained. 145 | """ 146 | if use_eig: 147 | return compute_pod_basis_eig(snapshot, energy_threshold) 148 | return compute_pod_basis_svd(snapshot, energy_threshold) 149 | -------------------------------------------------------------------------------- /pod_rbf/io.py: -------------------------------------------------------------------------------- 1 | """ 2 | File I/O utilities for POD-RBF. 3 | 4 | Uses NumPy for file operations (not differentiable). 5 | """ 6 | 7 | import os 8 | import pickle 9 | 10 | import jax.numpy as jnp 11 | import numpy as np 12 | from tqdm import tqdm 13 | 14 | from .types import ModelState 15 | 16 | 17 | def build_snapshot_matrix( 18 | dirpath: str, 19 | skiprows: int = 1, 20 | usecols: int | tuple[int, ...] = 0, 21 | verbose: bool = True, 22 | ) -> np.ndarray: 23 | """ 24 | Load snapshot matrix from CSV files in directory. 25 | 26 | Files are loaded in alphanumeric order. Ensure parameter array 27 | matches this ordering. 28 | 29 | Parameters 30 | ---------- 31 | dirpath : str 32 | Directory containing CSV files. 33 | skiprows : int 34 | Number of header rows to skip in each file. 35 | usecols : int or tuple 36 | Column(s) to read from each file. 37 | verbose : bool 38 | Show progress bar. 39 | 40 | Returns 41 | ------- 42 | np.ndarray 43 | Snapshot matrix, shape (n_samples, n_snapshots). 44 | Returns NumPy array - convert to JAX as needed. 45 | """ 46 | files = sorted( 47 | [ 48 | f 49 | for f in os.listdir(dirpath) 50 | if os.path.isfile(os.path.join(dirpath, f)) and f.endswith(".csv") 51 | ] 52 | ) 53 | 54 | if not files: 55 | raise ValueError(f"No CSV files found in {dirpath}") 56 | 57 | # Get dimensions from first file 58 | first_data = np.loadtxt( 59 | os.path.join(dirpath, files[0]), 60 | delimiter=",", 61 | skiprows=skiprows, 62 | usecols=usecols, 63 | ) 64 | n_samples = len(first_data) if first_data.ndim > 0 else 1 65 | n_snapshots = len(files) 66 | 67 | snapshot = np.zeros((n_samples, n_snapshots)) 68 | 69 | iterator = tqdm(files, desc="Loading snapshots") if verbose else files 70 | for i, f in enumerate(iterator): 71 | data = np.loadtxt( 72 | os.path.join(dirpath, f), 73 | delimiter=",", 74 | skiprows=skiprows, 75 | usecols=usecols, 76 | ) 77 | data_len = len(data) if data.ndim > 0 else 1 78 | assert data_len == n_samples, f"Inconsistent samples in {f}: got {data_len}, expected {n_samples}" 79 | snapshot[:, i] = data 80 | 81 | return snapshot 82 | 83 | 84 | def save_model(filename: str, state: ModelState) -> None: 85 | """ 86 | Save model state to file. 87 | 88 | Parameters 89 | ---------- 90 | filename : str 91 | Output filename. 92 | state : ModelState 93 | Trained model state. 
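Examples -------- >>> save_model("model.pkl", result.state)  # illustrative; "model.pkl" is a hypothetical path and ``result`` comes from a prior train() call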
94 | """ 95 | # Convert JAX arrays to NumPy for pickling 96 | state_dict = { 97 | "basis": np.asarray(state.basis), 98 | "weights": np.asarray(state.weights), 99 | "shape_factor": float(state.shape_factor) if state.shape_factor is not None else None, 100 | "train_params": np.asarray(state.train_params), 101 | "params_range": np.asarray(state.params_range), 102 | "truncated_energy": float(state.truncated_energy), 103 | "cumul_energy": np.asarray(state.cumul_energy), 104 | "poly_coeffs": np.asarray(state.poly_coeffs) if state.poly_coeffs is not None else None, 105 | "poly_degree": int(state.poly_degree), 106 | "kernel": state.kernel, 107 | "kernel_order": int(state.kernel_order), 108 | } 109 | with open(filename, "wb") as f: 110 | pickle.dump(state_dict, f) 111 | 112 | 113 | def load_model(filename: str) -> ModelState: 114 | """ 115 | Load model state from file. 116 | 117 | Parameters 118 | ---------- 119 | filename : str 120 | Input filename. 121 | 122 | Returns 123 | ------- 124 | ModelState 125 | Loaded model state with JAX arrays. 126 | """ 127 | with open(filename, "rb") as f: 128 | state_dict = pickle.load(f) 129 | 130 | # Handle backward compatibility for models saved without poly fields 131 | poly_coeffs = state_dict.get("poly_coeffs") 132 | if poly_coeffs is not None: 133 | poly_coeffs = jnp.array(poly_coeffs) 134 | poly_degree = state_dict.get("poly_degree", 0) 135 | 136 | # Handle backward compatibility for models saved without kernel fields 137 | # Legacy models used IMQ kernel only 138 | kernel = state_dict.get("kernel", "imq") 139 | kernel_order = state_dict.get("kernel_order", 3) 140 | 141 | return ModelState( 142 | basis=jnp.array(state_dict["basis"]), 143 | weights=jnp.array(state_dict["weights"]), 144 | shape_factor=state_dict["shape_factor"], 145 | train_params=jnp.array(state_dict["train_params"]), 146 | params_range=jnp.array(state_dict["params_range"]), 147 | truncated_energy=state_dict["truncated_energy"], 148 | cumul_energy=jnp.array(state_dict["cumul_energy"]), 149 | poly_coeffs=poly_coeffs, 150 | poly_degree=poly_degree, 151 | kernel=kernel, 152 | kernel_order=kernel_order, 153 | ) 154 | -------------------------------------------------------------------------------- /tests/test_io.py: -------------------------------------------------------------------------------- 1 | """Tests for I/O utilities.""" 2 | 3 | import os 4 | import tempfile 5 | 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import pytest 9 | 10 | from pod_rbf.io import build_snapshot_matrix, load_model, save_model 11 | from pod_rbf.types import ModelState 12 | 13 | 14 | class TestSaveLoadModel: 15 | """Test model serialization.""" 16 | 17 | @pytest.fixture 18 | def sample_state(self): 19 | """Create sample model state.""" 20 | return ModelState( 21 | basis=jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), 22 | weights=jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]), 23 | shape_factor=0.5, 24 | train_params=jnp.array([[1.0, 2.0, 3.0]]), 25 | params_range=jnp.array([2.0]), 26 | truncated_energy=0.99, 27 | cumul_energy=jnp.array([0.9, 0.99]), 28 | poly_coeffs=jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]), 29 | poly_degree=2, 30 | kernel="imq", 31 | kernel_order=3, 32 | ) 33 | 34 | def test_save_load_roundtrip(self, sample_state): 35 | """Saved and loaded state should match.""" 36 | with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: 37 | filename = f.name 38 | 39 | try: 40 | save_model(filename, sample_state) 41 | loaded = load_model(filename) 42 | 43 | assert 
jnp.allclose(loaded.basis, sample_state.basis) 44 | assert jnp.allclose(loaded.weights, sample_state.weights) 45 | assert loaded.shape_factor == sample_state.shape_factor 46 | assert jnp.allclose(loaded.train_params, sample_state.train_params) 47 | assert jnp.allclose(loaded.params_range, sample_state.params_range) 48 | assert loaded.truncated_energy == sample_state.truncated_energy 49 | assert jnp.allclose(loaded.cumul_energy, sample_state.cumul_energy) 50 | assert jnp.allclose(loaded.poly_coeffs, sample_state.poly_coeffs) 51 | assert loaded.poly_degree == sample_state.poly_degree 52 | finally: 53 | os.unlink(filename) 54 | 55 | def test_loaded_state_is_model_state(self, sample_state): 56 | """Loaded state should be ModelState instance.""" 57 | with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f: 58 | filename = f.name 59 | 60 | try: 61 | save_model(filename, sample_state) 62 | loaded = load_model(filename) 63 | 64 | assert isinstance(loaded, ModelState) 65 | finally: 66 | os.unlink(filename) 67 | 68 | 69 | class TestBuildSnapshotMatrix: 70 | """Test snapshot matrix loading from CSV files.""" 71 | 72 | @pytest.fixture 73 | def csv_dir(self): 74 | """Create temporary directory with CSV files.""" 75 | with tempfile.TemporaryDirectory() as tmpdir: 76 | # Create 3 CSV files with 5 data points each 77 | for i in range(3): 78 | data = np.arange(5) * (i + 1) # [0,1,2,3,4] * (i+1) 79 | filepath = os.path.join(tmpdir, f"data_{i:03d}.csv") 80 | np.savetxt(filepath, data[:, None], delimiter=",", header="value", comments="") 81 | 82 | yield tmpdir 83 | 84 | def test_loads_all_files(self, csv_dir): 85 | """Should load all CSV files in directory.""" 86 | snapshot = build_snapshot_matrix(csv_dir, verbose=False) 87 | 88 | assert snapshot.shape == (5, 3), f"Expected (5, 3), got {snapshot.shape}" 89 | 90 | def test_correct_values(self, csv_dir): 91 | """Should load correct values.""" 92 | snapshot = build_snapshot_matrix(csv_dir, verbose=False) 93 | 94 | # File data_000.csv has [0,1,2,3,4] 95 | assert np.allclose(snapshot[:, 0], [0, 1, 2, 3, 4]) 96 | # File data_001.csv has [0,2,4,6,8] 97 | assert np.allclose(snapshot[:, 1], [0, 2, 4, 6, 8]) 98 | # File data_002.csv has [0,3,6,9,12] 99 | assert np.allclose(snapshot[:, 2], [0, 3, 6, 9, 12]) 100 | 101 | def test_alphanumeric_order(self, csv_dir): 102 | """Files should be loaded in alphanumeric order.""" 103 | snapshot = build_snapshot_matrix(csv_dir, verbose=False) 104 | 105 | # First file should be data_000.csv 106 | # Last file should be data_002.csv 107 | assert snapshot[1, 0] == 1 # data_000: index 1 = 1 108 | assert snapshot[1, 2] == 3 # data_002: index 1 = 3 109 | 110 | def test_empty_dir_raises(self): 111 | """Should raise ValueError for empty directory.""" 112 | with tempfile.TemporaryDirectory() as tmpdir: 113 | with pytest.raises(ValueError, match="No CSV files"): 114 | build_snapshot_matrix(tmpdir, verbose=False) 115 | 116 | def test_inconsistent_samples_raises(self): 117 | """Should raise AssertionError for inconsistent sample counts.""" 118 | with tempfile.TemporaryDirectory() as tmpdir: 119 | # Create files with different row counts 120 | np.savetxt(os.path.join(tmpdir, "a.csv"), [1, 2, 3], delimiter=",", header="x", comments="") 121 | np.savetxt(os.path.join(tmpdir, "b.csv"), [1, 2], delimiter=",", header="x", comments="") 122 | 123 | with pytest.raises(AssertionError, match="Inconsistent"): 124 | build_snapshot_matrix(tmpdir, verbose=False) 125 | -------------------------------------------------------------------------------- 
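For reference, a minimal end-to-end I/O round trip using the utilities exercised above — a hedged sketch, assuming a directory of CSV snapshots at a hypothetical path `snapshots/` and evenly spaced parameters:

```python
import jax.numpy as jnp
import numpy as np
import pod_rbf

# Load every CSV in the directory (alphanumeric order) into a snapshot matrix.
snapshot = pod_rbf.build_snapshot_matrix("snapshots/", skiprows=1, usecols=0)

# Parameters must line up with the alphanumeric file order.
params = np.linspace(1.0, 1000.0, snapshot.shape[1])

# Train, persist, reload, and run inference on the reloaded state.
result = pod_rbf.train(snapshot, params)
pod_rbf.save_model("model.pkl", result.state)
state = pod_rbf.load_model("model.pkl")
sol = pod_rbf.inference_single(state, jnp.array(450.0))
```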
/tests/test_decomposition.py: -------------------------------------------------------------------------------- 1 | """Tests for POD basis decomposition.""" 2 | 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import pytest 6 | 7 | from pod_rbf.decomposition import ( 8 | compute_pod_basis, 9 | compute_pod_basis_eig, 10 | compute_pod_basis_svd, 11 | ) 12 | 13 | 14 | class TestComputePodBasisSvd: 15 | """Test SVD-based POD basis computation.""" 16 | 17 | def test_orthonormality(self): 18 | """SVD-based basis should be orthonormal.""" 19 | np.random.seed(42) 20 | snapshot = jnp.array(np.random.randn(100, 10)) 21 | 22 | basis, _, _ = compute_pod_basis_svd(snapshot, 0.99) 23 | 24 | gram = basis.T @ basis 25 | identity = jnp.eye(gram.shape[0]) 26 | assert jnp.allclose(gram, identity, atol=1e-10), "Basis should be orthonormal" 27 | 28 | def test_energy_threshold(self): 29 | """Retained energy should meet threshold.""" 30 | np.random.seed(42) 31 | snapshot = jnp.array(np.random.randn(50, 10)) 32 | 33 | basis, cumul_energy, truncated_energy = compute_pod_basis_svd(snapshot, 0.95) 34 | 35 | assert truncated_energy >= 0.95, f"Truncated energy {truncated_energy} < 0.95" 36 | 37 | def test_basis_shape(self): 38 | """Basis should have correct shape.""" 39 | np.random.seed(42) 40 | snapshot = jnp.array(np.random.randn(100, 10)) 41 | 42 | basis, _, _ = compute_pod_basis_svd(snapshot, 0.99) 43 | 44 | assert basis.shape[0] == 100, "Basis should have n_samples rows" 45 | assert basis.shape[1] <= 10, "Basis should have at most n_snapshots columns" 46 | assert basis.shape[1] >= 1, "Basis should have at least 1 column" 47 | 48 | def test_cumul_energy_monotonic(self): 49 | """Cumulative energy should be monotonically increasing.""" 50 | np.random.seed(42) 51 | snapshot = jnp.array(np.random.randn(50, 15)) 52 | 53 | _, cumul_energy, _ = compute_pod_basis_svd(snapshot, 0.99) 54 | 55 | diffs = jnp.diff(cumul_energy) 56 | assert jnp.all(diffs >= 0), "Cumulative energy should be monotonically increasing" 57 | 58 | def test_keep_all_energy(self): 59 | """With threshold >= 1, should keep all modes.""" 60 | np.random.seed(42) 61 | snapshot = jnp.array(np.random.randn(50, 10)) 62 | 63 | basis, _, truncated_energy = compute_pod_basis_svd(snapshot, 1.0) 64 | 65 | assert basis.shape[1] == 10, "Should keep all modes when threshold=1.0" 66 | assert jnp.isclose(truncated_energy, 1.0, atol=1e-10) 67 | 68 | 69 | class TestComputePodBasisEig: 70 | """Test eigendecomposition-based POD basis computation.""" 71 | 72 | def test_orthonormality(self): 73 | """Eig-based basis should be orthonormal.""" 74 | np.random.seed(42) 75 | snapshot = jnp.array(np.random.randn(100, 10)) 76 | 77 | basis, _, _ = compute_pod_basis_eig(snapshot, 0.99) 78 | 79 | gram = basis.T @ basis 80 | identity = jnp.eye(gram.shape[0]) 81 | assert jnp.allclose(gram, identity, atol=1e-8), "Basis should be orthonormal" 82 | 83 | def test_energy_threshold(self): 84 | """Retained energy should meet threshold.""" 85 | np.random.seed(42) 86 | snapshot = jnp.array(np.random.randn(50, 10)) 87 | 88 | basis, cumul_energy, truncated_energy = compute_pod_basis_eig(snapshot, 0.95) 89 | 90 | assert truncated_energy >= 0.95, f"Truncated energy {truncated_energy} < 0.95" 91 | 92 | def test_basis_shape(self): 93 | """Basis should have correct shape.""" 94 | np.random.seed(42) 95 | snapshot = jnp.array(np.random.randn(100, 10)) 96 | 97 | basis, _, _ = compute_pod_basis_eig(snapshot, 0.99) 98 | 99 | assert basis.shape[0] == 100, "Basis should have n_samples rows" 100 | assert 
basis.shape[1] <= 10, "Basis should have at most n_snapshots columns" 101 | assert basis.shape[1] >= 1, "Basis should have at least 1 column" 102 | 103 | 104 | class TestSvdEigEquivalence: 105 | """Test that SVD and eigendecomposition produce equivalent results.""" 106 | 107 | def test_span_equivalence(self): 108 | """SVD and eig bases should span the same subspace.""" 109 | np.random.seed(42) 110 | snapshot = jnp.array(np.random.randn(100, 10)) 111 | 112 | basis_svd, _, _ = compute_pod_basis_svd(snapshot, 0.99) 113 | basis_eig, _, _ = compute_pod_basis_eig(snapshot, 0.99) 114 | 115 | # Same number of modes (may differ by 1 due to numerical differences) 116 | n_modes_svd = basis_svd.shape[1] 117 | n_modes_eig = basis_eig.shape[1] 118 | assert abs(n_modes_svd - n_modes_eig) <= 1, "Mode counts should be similar" 119 | 120 | # Project onto common subspace - projection matrices should be similar 121 | proj_svd = basis_svd @ basis_svd.T 122 | proj_eig = basis_eig @ basis_eig.T 123 | 124 | # They span similar subspaces if the projections are close 125 | # (accounting for potentially different number of modes) 126 | min_modes = min(n_modes_svd, n_modes_eig) 127 | basis_svd_trunc = basis_svd[:, :min_modes] 128 | basis_eig_trunc = basis_eig[:, :min_modes] 129 | proj_svd_trunc = basis_svd_trunc @ basis_svd_trunc.T 130 | proj_eig_trunc = basis_eig_trunc @ basis_eig_trunc.T 131 | 132 | assert jnp.allclose(proj_svd_trunc, proj_eig_trunc, atol=1e-6), "Projections should be similar" 133 | 134 | def test_energy_equivalence(self): 135 | """SVD and eig should both meet energy threshold.""" 136 | np.random.seed(42) 137 | snapshot = jnp.array(np.random.randn(100, 10)) 138 | 139 | _, _, energy_svd = compute_pod_basis_svd(snapshot, 0.95) 140 | _, _, energy_eig = compute_pod_basis_eig(snapshot, 0.95) 141 | 142 | # Both should meet threshold (may differ due to discrete truncation) 143 | assert energy_svd >= 0.95, f"SVD energy {energy_svd} below threshold" 144 | assert energy_eig >= 0.95, f"Eig energy {energy_eig} below threshold" 145 | 146 | 147 | class TestComputePodBasis: 148 | """Test dispatch function.""" 149 | 150 | def test_dispatch_svd(self): 151 | """use_eig=False should use SVD.""" 152 | np.random.seed(42) 153 | snapshot = jnp.array(np.random.randn(50, 10)) 154 | 155 | basis, _, _ = compute_pod_basis(snapshot, 0.99, use_eig=False) 156 | basis_svd, _, _ = compute_pod_basis_svd(snapshot, 0.99) 157 | 158 | assert jnp.allclose(basis, basis_svd), "use_eig=False should match SVD" 159 | 160 | def test_dispatch_eig(self): 161 | """use_eig=True should use eigendecomposition.""" 162 | np.random.seed(42) 163 | snapshot = jnp.array(np.random.randn(50, 10)) 164 | 165 | basis, _, _ = compute_pod_basis(snapshot, 0.99, use_eig=True) 166 | basis_eig, _, _ = compute_pod_basis_eig(snapshot, 0.99) 167 | 168 | assert jnp.allclose(basis, basis_eig), "use_eig=True should match eig" 169 | -------------------------------------------------------------------------------- /pod_rbf/core.py: -------------------------------------------------------------------------------- 1 | """ 2 | Core POD-RBF training and inference functions. 3 | 4 | Pure functional interface for JAX autodiff compatibility. 
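Typical flow, shown only as an illustration: result = train(snapshot, params), then sol = inference_single(result.state, p).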
5 | """ 6 | 7 | import jax.numpy as jnp 8 | from jax import Array 9 | 10 | from .decomposition import compute_pod_basis 11 | from .rbf import ( 12 | build_collocation_matrix, 13 | build_inference_matrix, 14 | build_polynomial_basis, 15 | solve_augmented_system_direct, 16 | solve_augmented_system_schur, 17 | ) 18 | from .shape_optimization import find_optimal_shape_param 19 | from .types import ModelState, TrainConfig, TrainResult 20 | 21 | 22 | def _normalize_params(params: Array) -> Array: 23 | """Ensure params is 2D: (n_params, n_points).""" 24 | if params.ndim == 1: 25 | return params[None, :] 26 | return params 27 | 28 | 29 | def train( 30 | snapshot: Array, 31 | train_params: Array, 32 | config: TrainConfig = TrainConfig(), 33 | shape_factor: float | None = None, 34 | ) -> TrainResult: 35 | """ 36 | Train POD-RBF model. 37 | 38 | Parameters 39 | ---------- 40 | snapshot : Array 41 | Solution snapshots, shape (n_samples, n_snapshots). 42 | Each column is a snapshot at a different parameter value. 43 | train_params : Array 44 | Parameter values, shape (n_snapshots,) or (n_params, n_snapshots). 45 | config : TrainConfig 46 | Training configuration. 47 | shape_factor : float, optional 48 | RBF shape parameter. If None, automatically optimized. 49 | 50 | Returns 51 | ------- 52 | TrainResult 53 | Training result containing model state and diagnostics. 54 | """ 55 | train_params = _normalize_params(jnp.asarray(train_params)) 56 | snapshot = jnp.asarray(snapshot) 57 | n_params, n_snapshots = train_params.shape 58 | 59 | assert snapshot.shape[1] == n_snapshots, ( 60 | f"Mismatch: {snapshot.shape[1]} snapshots vs {n_snapshots} params" 61 | ) 62 | 63 | # Compute parameter ranges for normalization 64 | params_range = jnp.ptp(train_params, axis=1) 65 | 66 | # Determine decomposition method based on memory 67 | memory_gb = snapshot.nbytes / 1e9 68 | use_eig = memory_gb >= config.mem_limit_gb 69 | 70 | # Find optimal shape factor if not provided 71 | if shape_factor is None: 72 | shape_factor = find_optimal_shape_param( 73 | train_params, 74 | params_range, 75 | kernel=config.kernel, 76 | kernel_order=config.kernel_order, 77 | cond_range=config.cond_range, 78 | max_iters=config.max_bisection_iters, 79 | c_low_init=config.c_low_init, 80 | c_high_init=config.c_high_init, 81 | c_high_step=config.c_high_step, 82 | c_high_search_iters=config.c_high_search_iters, 83 | ) 84 | 85 | # Compute truncated POD basis 86 | basis, cumul_energy, truncated_energy = compute_pod_basis( 87 | snapshot, config.energy_threshold, use_eig=use_eig 88 | ) 89 | 90 | # Build collocation matrix 91 | F = build_collocation_matrix( 92 | train_params, 93 | params_range, 94 | kernel=config.kernel, 95 | shape_factor=shape_factor, 96 | kernel_order=config.kernel_order, 97 | ) 98 | A = basis.T @ snapshot # (n_modes, n_train) 99 | 100 | # Compute weights using Schur complement solver or fallback to pinv 101 | poly_degree = config.poly_degree 102 | if poly_degree > 0: 103 | P = build_polynomial_basis(train_params, params_range, poly_degree) 104 | # Use direct solver for PHS (F is not SPD), Schur for others (F is SPD) 105 | if config.kernel == "polyharmonic_spline": 106 | weights, poly_coeffs = solve_augmented_system_direct(F, P, A) 107 | else: 108 | weights, poly_coeffs = solve_augmented_system_schur(F, P, A) 109 | else: 110 | weights = A @ jnp.linalg.pinv(F.T) 111 | poly_coeffs = None 112 | 113 | state = ModelState( 114 | basis=basis, 115 | weights=weights, 116 | shape_factor=float(shape_factor) if shape_factor is not None else None, 117 
| train_params=train_params, 118 | params_range=params_range, 119 | truncated_energy=float(truncated_energy), 120 | cumul_energy=cumul_energy, 121 | poly_coeffs=poly_coeffs, 122 | poly_degree=poly_degree, 123 | kernel=config.kernel, 124 | kernel_order=config.kernel_order, 125 | ) 126 | 127 | return TrainResult( 128 | state=state, 129 | n_modes=basis.shape[1], 130 | used_eig_decomp=use_eig, 131 | ) 132 | 133 | 134 | def _inference_impl( 135 | basis: Array, 136 | weights: Array, 137 | train_params: Array, 138 | params_range: Array, 139 | shape_factor: float | None, 140 | poly_coeffs: Array | None, 141 | poly_degree: int, 142 | inf_params: Array, 143 | kernel: str, 144 | kernel_order: int, 145 | ) -> Array: 146 | """Core inference implementation for JIT compilation.""" 147 | F = build_inference_matrix( 148 | train_params, 149 | inf_params, 150 | params_range, 151 | kernel=kernel, 152 | shape_factor=shape_factor, 153 | kernel_order=kernel_order, 154 | ) 155 | 156 | # RBF contribution 157 | A = weights @ F.T # (n_modes, n_inf) 158 | 159 | # Add polynomial contribution if used 160 | if poly_coeffs is not None: 161 | P_inf = build_polynomial_basis(inf_params, params_range, poly_degree) 162 | A = A + poly_coeffs @ P_inf.T # (n_modes, n_inf) 163 | 164 | return basis @ A 165 | 166 | 167 | def inference(state: ModelState, inf_params: Array) -> Array: 168 | """ 169 | Run inference with the trained model at multiple parameter points. 170 | 171 | Parameters 172 | ---------- 173 | state : ModelState 174 | Trained model state from train(). 175 | inf_params : Array 176 | Inference parameters, shape (n_params, n_points) or (n_points,). 177 | 178 | Returns 179 | ------- 180 | Array 181 | Predicted solutions, shape (n_samples, n_points). 182 | """ 183 | inf_params = _normalize_params(jnp.asarray(inf_params)) 184 | 185 | # Extract poly_degree as Python int before JIT tracing 186 | poly_degree = int(state.poly_degree) if state.poly_coeffs is not None else 0 187 | 188 | return _inference_impl( 189 | state.basis, 190 | state.weights, 191 | state.train_params, 192 | state.params_range, 193 | state.shape_factor, 194 | state.poly_coeffs, 195 | poly_degree, 196 | inf_params, 197 | kernel=state.kernel, 198 | kernel_order=state.kernel_order, 199 | ) 200 | 201 | 202 | def inference_single(state: ModelState, inf_param: Array) -> Array: 203 | """ 204 | Run inference with the trained model at a single parameter point. 205 | 206 | More convenient for gradient computation. 207 | 208 | Parameters 209 | ---------- 210 | state : ModelState 211 | Trained model state from train(). 212 | inf_param : Array 213 | Single inference parameter, scalar or shape (n_params,). 214 | 215 | Returns 216 | ------- 217 | Array 218 | Predicted solution, shape (n_samples,).
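Examples -------- >>> sol = inference_single(state, jnp.array(450.0))  # illustrative; assumes ``state`` from a prior train() call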
219 | """ 220 | inf_param = jnp.asarray(inf_param) 221 | 222 | # Handle scalar input 223 | if inf_param.ndim == 0: 224 | inf_param = inf_param[None] 225 | 226 | # Shape to (n_params, 1) for inference 227 | inf_params = inf_param[:, None] 228 | 229 | result = inference(state, inf_params) 230 | return result[:, 0] 231 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # POD-RBF 2 | 3 | [![Tests](https://github.com/kylebeggs/POD-RBF/actions/workflows/tests.yml/badge.svg)](https://github.com/kylebeggs/POD-RBF/actions/workflows/tests.yml) 4 | [![codecov](https://codecov.io/gh/kylebeggs/POD-RBF/branch/master/graph/badge.svg)](https://codecov.io/gh/kylebeggs/POD-RBF) 5 | [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/) 6 | 7 | ![Re-450](examples/lid-driven-cavity/results-re-450.png) 8 | 9 | A Python package for building a Reduced Order Model (ROM) from high-dimensional data using a Proper 10 | Orthogonal Decomposition - Radial Basis Function (POD-RBF) Network. 11 | 12 | Given a 'snapshot' matrix of the data points with varying parameters, this code contains functions 13 | to find the truncated POD basis and interpolate using a RBF network for new parameters. 14 | 15 | This code is based on the following papers implementing the method: 16 | 17 | 1. [Solving inverse heat conduction problems using trained POD-RBF network inverse method - Ostrowski, Bialecki, Kassab (2008)](https://www.tandfonline.com/doi/full/10.1080/17415970701198290) 18 | 2. [RBF-trained POD-accelerated CFD analysis of wind loads on PV systems - Huayamave, Ceballos, Barriento, Seigneur, Barkaszi, Divo, and Kassab (2017)](https://www.emerald.com/insight/content/doi/10.1108/HFF-03-2016-0083/full/html) 19 | 3. [Real-Time Thermomechanical Modeling of PV Cell Fabrication via a POD-Trained RBF Interpolation Network - Das, Khoury, Divo, Huayamave, Ceballos, Eaglin, Kassab, Payne, Yelundur, and Seigneur (2020)](https://www.techscience.com/CMES/v122n3/38374) 20 | 21 | Features: 22 | 23 | * **JAX-based** - enables autodifferentiation for gradient optimization, sensitivity analysis, and inverse problems 24 | * Shape parameter optimization for the Radial Basis Functions (RBFs) 25 | * Algorithm switching based on memory requirements (eigenvalue decomposition vs. SVD) 26 | 27 | ## Installation 28 | 29 | ```bash 30 | pip install pod-rbf 31 | # or 32 | uv add pod-rbf 33 | ``` 34 | 35 | ## Example 36 | 37 | In the [example](https://github.com/kylebeggs/POD-RBF/tree/master/examples) folder you can find two 38 | examples of that demonstrates how to use the package. The first is a simple heat conduction problem 39 | on a unit square. 40 | 41 | The other example will be demonstrated step-by-step here. We seek to build a ROM of the 2D 42 | lid-driven cavity problem. For the impatient, here is the full code to run this example. I will 43 | break down each line in the sections below. 44 | 45 | If you wish to build a ROM with multiple parameters, see this basic [2-parameter example](https://github.com/kylebeggs/POD-RBF/tree/master/examples/2-parameters.ipynb). 
46 | 47 | ```python 48 | import pod_rbf 49 | import jax.numpy as jnp 50 | import numpy as np 51 | 52 | Re = np.linspace(0, 1000, num=11) 53 | Re[0] = 1 54 | 55 | # make snapshot matrix from csv files 56 | train_snapshot = pod_rbf.build_snapshot_matrix("examples/lid-driven-cavity/data/train/") 57 | 58 | # train the model (keeps 99% energy in POD modes by default) 59 | result = pod_rbf.train(train_snapshot, Re) 60 | 61 | # inference on an unseen parameter 62 | sol = pod_rbf.inference_single(result.state, jnp.array(450.0)) 63 | ``` 64 | 65 | ### Building the snapshot matrix 66 | 67 | First, we need to build the snapshot matrix, X, which contains the data we are training on. Each column is the k-th 'snapshot' of the solution field for some 68 | parameter p_k, with n samples in the snapshot at locations x_n. A single snapshot is below 69 | 70 | ![snapshot equation](examples/lid-driven-cavity/eq-snapshot.png) 71 | 72 | and the snapshot matrix would then look like 73 | 74 | ![snapshot equation](examples/lid-driven-cavity/eq-snapshot-matrix.png) 75 | 76 | where m is the total number of snapshots. 77 | 78 | For example, suppose our lid-driven cavity was solved on a mesh with 400 cells and we varied the 79 | parameter of interest (Re number in this case) 10 times. We would have a matrix of size (n,m) = 80 | (400,10). 81 | 82 | For our example, solutions were generated using STAR-CCM+ for Reynolds numbers of 1-1000 in 83 | increments of 100. The Re number will serve as our single parameter in this case. Snapshots were 84 | generated as a separate .csv file for each. To make it easier to combine them all into the snapshot 85 | matrix, there is a function which takes the directory path and loads every .csv file it finds, in 86 | alphanumeric order. The files for this example are named following the pattern re-%04d.csv (e.g. 87 | re-0001.csv), so alphanumeric order is also Reynolds-number order, and we would issue a 88 | command as 89 | 90 | ```python 91 | >>> import pod_rbf 92 | >>> train_snapshot = pod_rbf.build_snapshot_matrix("examples/lid-driven-cavity/data/train/") 93 | ``` 94 | 95 | --- 96 | Note: if you are using this approach where each snapshot is contained in a different csv file, 97 | please group all of them into a directory of their own. 98 | 99 | --- 100 | 101 | Notice that these files are contained in the train folder, as I also generated some more 102 | snapshots for validation (which, as you probably guessed, are in the /data/validation folder). Now we 103 | need to generate the array of input parameters that correspond to each snapshot. 104 | 105 | ```python 106 | >>> Re = np.linspace(0, 1000, num=11) 107 | >>> Re[0] = 1 108 | ``` 109 | 110 | --- 111 | Note: it is extremely important that each input parameter maps to the corresponding column of the 112 | snapshot matrix. For example, if a snapshot is stored in the 5th column (index 4), then the input parameter used to generate 113 | it should be what you find in the 5th element (index 4) of the array, e.g. 114 | ```train_snapshot[:,4] -> Re[4]```. The csv files are loaded in alpha-numeric order, which is why 115 | the input parameter array goes from 1 -> 1000. 116 | 117 | --- 118 | 119 | where ```Re``` is an array of input parameters that we are training the model on. Next, we train 120 | the model with a single function call. We choose to keep 99% of the energy in POD modes (this is 121 | the default, so you don't have to set that).
122 | 123 | ```python 124 | >>> result = pod_rbf.train(train_snapshot, Re) 125 | >>> # Or with custom config: 126 | >>> config = pod_rbf.TrainConfig(energy_threshold=0.99) 127 | >>> result = pod_rbf.train(train_snapshot, Re, config) 128 | ``` 129 | 130 | Now that the weights and truncated POD basis have been calculated and stored in `result.state`, we 131 | can run inference on the model using any input parameter. 132 | 133 | ```python 134 | >>> import jax.numpy as jnp 135 | >>> sol = pod_rbf.inference_single(result.state, jnp.array(450.0)) 136 | ``` 137 | 138 | and we can plot the results below, comparing the inference against the target 139 | 140 | ![Re-450](examples/lid-driven-cavity/results-re-450.png) 141 | 142 | and for a Reynolds number of 50: 143 | 144 | ![Re-50](examples/lid-driven-cavity/results-re-50.png) 145 | 146 | 147 | ### Saving and loading models 148 | You can save and load the trained model state: 149 | 150 | ```python 151 | >>> pod_rbf.save_model("model.pkl", result.state) 152 | >>> state = pod_rbf.load_model("model.pkl") 153 | >>> sol = pod_rbf.inference_single(state, jnp.array(450.0)) 154 | ``` 155 | 156 | ### Autodifferentiation 157 | 158 | Since POD-RBF is built on JAX, you can compute gradients for optimization and inverse problems: 159 | 160 | ```python 161 | >>> import jax 162 | >>> grad_fn = jax.grad(lambda p: jnp.sum(pod_rbf.inference_single(result.state, p)**2)) 163 | >>> gradient = grad_fn(jnp.array(450.0)) 164 | ``` 165 | -------------------------------------------------------------------------------- /examples/heat-conduction/design-optimization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Design Optimization with Autodiff 3 | 4 | Demonstrates using JAX autodifferentiation through POD-RBF for design optimization. 5 | 6 | Problem: 1D nonlinear heat conduction with temperature-dependent thermal conductivity 7 | k(T) = k₀(1 + βT) 8 | 9 | The nonlinear dependence on boundary temperature T_L creates a non-trivial optimization 10 | landscape requiring multiple POD modes to capture. 11 | 12 | Objective: Find boundary temperature T_L that achieves a target average temperature. 13 | """ 14 | 15 | import time 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | import pod_rbf 22 | 23 | 24 | jax.config.update("jax_default_device", jax.devices("cpu")[0]) 25 | 26 | # Physical parameters 27 | T_0 = 300.0 # Left boundary temperature (K) 28 | BETA = 0.002 # Temperature coefficient for conductivity 29 | L = 1.0 # Domain length 30 | 31 | 32 | def analytical_solution(x, T_L): 33 | """ 34 | Analytical solution for 1D steady heat conduction with k(T) = k₀(1 + βT). 35 | 36 | Uses Kirchhoff transformation: θ = T + (β/2)T² 37 | The transformed variable θ satisfies linear diffusion, giving θ(x) linear in x. 38 | Inverting the quadratic yields T(x). 39 | """ 40 | theta_0 = T_0 + (BETA / 2) * T_0**2 41 | theta_L = T_L + (BETA / 2) * T_L**2 42 | theta = theta_0 + (theta_L - theta_0) * (x / L) 43 | return (-1 + np.sqrt(1 + 2 * BETA * theta)) / BETA 44 | 45 | 46 | def build_snapshot_matrix(T_L_values, x): 47 | """Build snapshot matrix from analytical solutions at different T_L values.""" 48 | print("Building snapshot matrix... 
", end="") 49 | start = time.time() 50 | 51 | n_points = len(x) 52 | n_snapshots = len(T_L_values) 53 | snapshot = np.zeros((n_points, n_snapshots)) 54 | 55 | for i, T_L in enumerate(T_L_values): 56 | snapshot[:, i] = analytical_solution(x, T_L) 57 | 58 | print(f"took {time.time() - start:.3f} sec") 59 | return snapshot 60 | 61 | 62 | def run_optimization(): 63 | # Spatial discretization 64 | n_points = 100 65 | x = np.linspace(0, L, n_points) 66 | 67 | # Training: sample T_L over range [350, 600] K 68 | T_L_train = np.linspace(350, 600, num=20) 69 | snapshot = build_snapshot_matrix(T_L_train, x) 70 | 71 | # Train ROM 72 | config = pod_rbf.TrainConfig(energy_threshold=0.9999, poly_degree=2) 73 | result = pod_rbf.train(snapshot, T_L_train, config) 74 | state = result.state 75 | 76 | print(f"Trained with {result.n_modes} modes, energy retained: {state.truncated_energy:.6f}") 77 | print(f"Cumulative energy per mode: {state.cumul_energy}") 78 | 79 | # Target: achieve average temperature of 400 K 80 | T_target = 400.0 81 | 82 | def objective(T_L): 83 | """Squared error between predicted average temp and target.""" 84 | pred = pod_rbf.inference_single(state, T_L) 85 | avg_temp = jnp.mean(pred) 86 | return (avg_temp - T_target) ** 2 87 | 88 | # JIT-compile functions 89 | grad_fn = jax.jit(jax.grad(objective)) 90 | obj_fn = jax.jit(objective) 91 | 92 | # Also track average temperature 93 | @jax.jit 94 | def avg_temp_fn(T_L): 95 | return jnp.mean(pod_rbf.inference_single(state, T_L)) 96 | 97 | # Optimization settings 98 | T_L_init = jnp.array(550.0) # Start far from optimal 99 | T_L = T_L_init 100 | lr = 1.0 # Learning rate 101 | n_iters = 30 102 | T_L_min, T_L_max = 350.0, 600.0 # Valid parameter range 103 | 104 | # Track history 105 | history = { 106 | "T_L": [float(T_L)], 107 | "objective": [float(obj_fn(T_L))], 108 | "avg_temp": [float(avg_temp_fn(T_L))], 109 | "grad": [], 110 | } 111 | 112 | print(f"\nOptimizing: find T_L to achieve average temperature = {T_target} K") 113 | print(f"Initial: T_L = {T_L_init:.1f} K, avg_temp = {history['avg_temp'][0]:.2f} K") 114 | 115 | # Gradient descent 116 | for i in range(n_iters): 117 | grad = grad_fn(T_L) 118 | history["grad"].append(float(grad)) 119 | 120 | T_L = T_L - lr * grad 121 | T_L = jnp.clip(T_L, T_L_min, T_L_max) 122 | 123 | obj_val = obj_fn(T_L) 124 | avg_temp = avg_temp_fn(T_L) 125 | history["T_L"].append(float(T_L)) 126 | history["objective"].append(float(obj_val)) 127 | history["avg_temp"].append(float(avg_temp)) 128 | 129 | if (i + 1) % 5 == 0: 130 | print( 131 | f" Iter {i+1:3d}: T_L = {T_L:.2f} K, " 132 | f"avg_temp = {avg_temp:.2f} K, loss = {obj_val:.4e}" 133 | ) 134 | 135 | T_L_opt = float(T_L) 136 | print(f"\nOptimal T_L = {T_L_opt:.2f} K (avg_temp = {history['avg_temp'][-1]:.2f} K)") 137 | 138 | # Get solutions for visualization 139 | pred_init = pod_rbf.inference_single(state, T_L_init) 140 | pred_opt = pod_rbf.inference_single(state, jnp.array(T_L_opt)) 141 | 142 | # Analytical solutions for comparison 143 | T_analytical_init = analytical_solution(x, float(T_L_init)) 144 | T_analytical_opt = analytical_solution(x, T_L_opt) 145 | 146 | return ( 147 | history, 148 | x, 149 | pred_init, 150 | pred_opt, 151 | T_analytical_init, 152 | T_analytical_opt, 153 | T_L_init, 154 | T_L_opt, 155 | T_target, 156 | ) 157 | 158 | 159 | def plot_results( 160 | history, x, pred_init, pred_opt, T_anal_init, T_anal_opt, T_L_init, T_L_opt, T_target 161 | ): 162 | fig, axes = plt.subplots(2, 2, figsize=(12, 10)) 163 | 164 | # Objective convergence 165 | 
ax = axes[0, 0] 166 | ax.semilogy(history["objective"], "b-", linewidth=2) 167 | ax.set_xlabel("Iteration") 168 | ax.set_ylabel("Objective (squared error)") 169 | ax.set_title("Optimization Convergence") 170 | ax.grid(True, alpha=0.3) 171 | 172 | # Average temperature convergence 173 | ax = axes[0, 1] 174 | ax.plot(history["avg_temp"], "r-", linewidth=2, label="Predicted avg temp") 175 | ax.axhline(y=T_target, color="g", linestyle="--", linewidth=2, label=f"Target: {T_target} K") 176 | ax.set_xlabel("Iteration") 177 | ax.set_ylabel("Average Temperature (K)") 178 | ax.set_title("Average Temperature vs Target") 179 | ax.legend() 180 | ax.grid(True, alpha=0.3) 181 | 182 | # Initial temperature profile 183 | ax = axes[1, 0] 184 | ax.plot(x, pred_init, "b-", linewidth=2, label="ROM prediction") 185 | ax.plot(x, T_anal_init, "r--", linewidth=2, label="Analytical") 186 | ax.axhline( 187 | y=float(jnp.mean(pred_init)), 188 | color="g", 189 | linestyle=":", 190 | label=f"Avg: {float(jnp.mean(pred_init)):.1f} K", 191 | ) 192 | ax.set_xlabel("x") 193 | ax.set_ylabel("Temperature (K)") 194 | ax.set_title(f"Initial: T_L = {float(T_L_init):.1f} K") 195 | ax.legend() 196 | ax.grid(True, alpha=0.3) 197 | 198 | # Optimized temperature profile 199 | ax = axes[1, 1] 200 | ax.plot(x, pred_opt, "b-", linewidth=2, label="ROM prediction") 201 | ax.plot(x, T_anal_opt, "r--", linewidth=2, label="Analytical") 202 | ax.axhline( 203 | y=float(jnp.mean(pred_opt)), 204 | color="g", 205 | linestyle=":", 206 | label=f"Avg: {float(jnp.mean(pred_opt)):.1f} K", 207 | ) 208 | ax.axhline(y=T_target, color="orange", linestyle="--", alpha=0.7, label=f"Target: {T_target} K") 209 | ax.set_xlabel("x") 210 | ax.set_ylabel("Temperature (K)") 211 | ax.set_title(f"Optimized: T_L = {T_L_opt:.1f} K") 212 | ax.legend() 213 | ax.grid(True, alpha=0.3) 214 | 215 | plt.tight_layout() 216 | plt.show() 217 | 218 | 219 | if __name__ == "__main__": 220 | results = run_optimization() 221 | plot_results(*results) 222 | -------------------------------------------------------------------------------- /tests/test_kernels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for RBF kernel functions. 
3 | """ 4 | 5 | import jax.numpy as jnp 6 | import pytest 7 | 8 | from pod_rbf.kernels import ( 9 | KernelType, 10 | apply_kernel, 11 | kernel_gaussian, 12 | kernel_imq, 13 | kernel_polyharmonic_spline, 14 | ) 15 | 16 | 17 | class TestKernelIMQ: 18 | """Test Inverse Multiquadrics kernel.""" 19 | 20 | def test_imq_at_zero(self): 21 | """IMQ kernel should be 1 at r=0.""" 22 | r2 = jnp.array([0.0]) 23 | result = kernel_imq(r2, shape_factor=1.0) 24 | assert jnp.allclose(result, 1.0) 25 | 26 | def test_imq_bounded(self): 27 | """IMQ kernel values should be in (0, 1].""" 28 | r2 = jnp.array([0.0, 0.1, 1.0, 10.0, 100.0]) 29 | result = kernel_imq(r2, shape_factor=1.0) 30 | assert jnp.all(result > 0) 31 | assert jnp.all(result <= 1.0) 32 | 33 | def test_imq_decay(self): 34 | """IMQ should decay monotonically with distance.""" 35 | r2 = jnp.array([0.0, 1.0, 4.0, 9.0]) 36 | result = kernel_imq(r2, shape_factor=1.0) 37 | # Check monotonic decrease 38 | assert jnp.all(result[:-1] >= result[1:]) 39 | 40 | def test_imq_shape_factor_effect(self): 41 | """Larger shape factor should give slower decay.""" 42 | r2 = jnp.array([4.0]) 43 | val_small_c = kernel_imq(r2, shape_factor=0.5) 44 | val_large_c = kernel_imq(r2, shape_factor=2.0) 45 | # Larger c means slower decay, so value should be higher 46 | assert val_large_c > val_small_c 47 | 48 | 49 | class TestKernelGaussian: 50 | """Test Gaussian kernel.""" 51 | 52 | def test_gaussian_at_zero(self): 53 | """Gaussian kernel should be 1 at r=0.""" 54 | r2 = jnp.array([0.0]) 55 | result = kernel_gaussian(r2, shape_factor=1.0) 56 | assert jnp.allclose(result, 1.0) 57 | 58 | def test_gaussian_bounded(self): 59 | """Gaussian kernel values should be in (0, 1].""" 60 | r2 = jnp.array([0.0, 0.1, 1.0, 10.0, 100.0]) 61 | result = kernel_gaussian(r2, shape_factor=1.0) 62 | assert jnp.all(result > 0) 63 | assert jnp.all(result <= 1.0) 64 | 65 | def test_gaussian_decay(self): 66 | """Gaussian should decay monotonically with distance.""" 67 | r2 = jnp.array([0.0, 1.0, 4.0, 9.0]) 68 | result = kernel_gaussian(r2, shape_factor=1.0) 69 | # Check monotonic decrease 70 | assert jnp.all(result[:-1] >= result[1:]) 71 | 72 | def test_gaussian_exponential_decay(self): 73 | """Gaussian should decay exponentially.""" 74 | r2 = jnp.array([0.0, 1.0]) 75 | c = 1.0 76 | result = kernel_gaussian(r2, shape_factor=c) 77 | # At r²=1, should be exp(-1/c²) = exp(-1) 78 | assert jnp.allclose(result[1], jnp.exp(-1.0)) 79 | 80 | def test_gaussian_shape_factor_effect(self): 81 | """Larger shape factor should give slower decay.""" 82 | r2 = jnp.array([4.0]) 83 | val_small_c = kernel_gaussian(r2, shape_factor=0.5) 84 | val_large_c = kernel_gaussian(r2, shape_factor=2.0) 85 | # Larger c means slower decay, so value should be higher 86 | assert val_large_c > val_small_c 87 | 88 | 89 | class TestKernelPolyharmonicSpline: 90 | """Test Polyharmonic Spline kernels.""" 91 | 92 | def test_phs_order_1(self): 93 | """PHS order 1: phi(r) = r.""" 94 | r2 = jnp.array([0.0, 1.0, 4.0, 9.0]) 95 | result = kernel_polyharmonic_spline(r2, order=1) 96 | expected = jnp.sqrt(r2) 97 | assert jnp.allclose(result, expected) 98 | 99 | def test_phs_order_3(self): 100 | """PHS order 3: phi(r) = r³.""" 101 | r2 = jnp.array([0.0, 1.0, 4.0, 9.0]) 102 | result = kernel_polyharmonic_spline(r2, order=3) 103 | r = jnp.sqrt(r2) 104 | expected = r**3 105 | assert jnp.allclose(result, expected) 106 | 107 | def test_phs_order_5(self): 108 | """PHS order 5: phi(r) = r⁵.""" 109 | r2 = jnp.array([0.0, 1.0, 4.0, 9.0]) 110 | result = 
kernel_polyharmonic_spline(r2, order=5) 111 | r = jnp.sqrt(r2) 112 | expected = r**5 113 | assert jnp.allclose(result, expected) 114 | 115 | def test_phs_order_2(self): 116 | """PHS order 2: phi(r) = r²*log(r).""" 117 | # Avoid r=0 for log 118 | r2 = jnp.array([1.0, 4.0, 9.0]) 119 | result = kernel_polyharmonic_spline(r2, order=2) 120 | r = jnp.sqrt(r2) 121 | expected = (r**2) * jnp.log(r) 122 | assert jnp.allclose(result, expected) 123 | 124 | def test_phs_order_4(self): 125 | """PHS order 4: phi(r) = r⁴*log(r).""" 126 | # Avoid r=0 for log 127 | r2 = jnp.array([1.0, 4.0, 9.0]) 128 | result = kernel_polyharmonic_spline(r2, order=4) 129 | r = jnp.sqrt(r2) 130 | expected = (r**4) * jnp.log(r) 131 | assert jnp.allclose(result, expected) 132 | 133 | def test_phs_at_zero_odd(self): 134 | """PHS odd orders should be 0 at r=0.""" 135 | r2 = jnp.array([0.0]) 136 | for order in [1, 3, 5]: 137 | result = kernel_polyharmonic_spline(r2, order=order) 138 | assert jnp.allclose(result, 0.0) 139 | 140 | def test_phs_at_zero_even(self): 141 | """PHS even orders should be 0 at r=0 (handled by jnp.where).""" 142 | r2 = jnp.array([0.0]) 143 | for order in [2, 4]: 144 | result = kernel_polyharmonic_spline(r2, order=order) 145 | # Should be set to 0 by the where clause to avoid log(0) 146 | assert jnp.isfinite(result[0]) 147 | assert jnp.allclose(result, 0.0) 148 | 149 | def test_phs_growth(self): 150 | """PHS should grow with distance (unlike IMQ/Gaussian).""" 151 | r2 = jnp.array([0.0, 1.0, 4.0]) 152 | result = kernel_polyharmonic_spline(r2, order=3) 153 | # Should increase (not decay) 154 | assert result[0] <= result[1] <= result[2] 155 | 156 | 157 | class TestApplyKernel: 158 | """Test kernel dispatcher.""" 159 | 160 | def test_apply_imq(self): 161 | """Dispatcher should correctly apply IMQ kernel.""" 162 | r2 = jnp.array([0.0, 1.0, 4.0]) 163 | result = apply_kernel(r2, "imq", shape_factor=1.0, kernel_order=3) 164 | expected = kernel_imq(r2, shape_factor=1.0) 165 | assert jnp.allclose(result, expected) 166 | 167 | def test_apply_gaussian(self): 168 | """Dispatcher should correctly apply Gaussian kernel.""" 169 | r2 = jnp.array([0.0, 1.0, 4.0]) 170 | result = apply_kernel(r2, "gaussian", shape_factor=1.0, kernel_order=3) 171 | expected = kernel_gaussian(r2, shape_factor=1.0) 172 | assert jnp.allclose(result, expected) 173 | 174 | def test_apply_phs(self): 175 | """Dispatcher should correctly apply PHS kernel.""" 176 | r2 = jnp.array([0.0, 1.0, 4.0]) 177 | result = apply_kernel(r2, "polyharmonic_spline", shape_factor=None, kernel_order=3) 178 | expected = kernel_polyharmonic_spline(r2, order=3) 179 | assert jnp.allclose(result, expected) 180 | 181 | def test_apply_invalid_kernel(self): 182 | """Dispatcher should raise error for invalid kernel.""" 183 | r2 = jnp.array([1.0]) 184 | with pytest.raises(ValueError, match="is not a valid KernelType"): 185 | apply_kernel(r2, "invalid_kernel", shape_factor=1.0, kernel_order=3) 186 | 187 | 188 | class TestKernelType: 189 | """Test KernelType enum.""" 190 | 191 | def test_enum_values(self): 192 | """Test that enum values are correct.""" 193 | assert KernelType.IMQ == "imq" 194 | assert KernelType.GAUSSIAN == "gaussian" 195 | assert KernelType.POLYHARMONIC_SPLINE == "polyharmonic_spline" 196 | 197 | def test_enum_from_string(self): 198 | """Test creating enum from string.""" 199 | assert KernelType("imq") == KernelType.IMQ 200 | assert KernelType("gaussian") == KernelType.GAUSSIAN 201 | assert KernelType("polyharmonic_spline") == KernelType.POLYHARMONIC_SPLINE 202 | 
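# Illustrative usage of the dispatcher exercised above — a hedged sketch appended
# for clarity, not part of the original suite (values chosen arbitrarily):
#
#     r2 = jnp.array([0.0, 1.0, 4.0])
#     apply_kernel(r2, "imq", shape_factor=1.0, kernel_order=3)                   # decays from 1
#     apply_kernel(r2, "polyharmonic_spline", shape_factor=None, kernel_order=3)  # grows as r^3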
-------------------------------------------------------------------------------- /pod_rbf/rbf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Radial Basis Function (RBF) kernel and matrix construction. 3 | 4 | Supports multiple kernel types: 5 | - Inverse Multi-Quadrics (IMQ): phi(r) = 1 / sqrt(r²/c² + 1) 6 | - Gaussian: phi(r) = exp(-r²/c²) 7 | - Polyharmonic Splines (PHS): phi(r) = r^k or r^k*log(r) 8 | """ 9 | 10 | import jax 11 | import jax.numpy as jnp 12 | import jax.scipy.linalg as jla 13 | from jax import Array 14 | 15 | from .kernels import apply_kernel 16 | 17 | 18 | def build_collocation_matrix( 19 | train_params: Array, 20 | params_range: Array, 21 | kernel: str = "imq", 22 | shape_factor: float | None = None, 23 | kernel_order: int = 3, 24 | ) -> Array: 25 | """ 26 | Build RBF collocation matrix for training. 27 | 28 | Parameters 29 | ---------- 30 | train_params : Array 31 | Training parameters, shape (n_params, n_train_points). 32 | params_range : Array 33 | Range of each parameter for normalization, shape (n_params,). 34 | kernel : str, optional 35 | Kernel type: 'imq', 'gaussian', or 'polyharmonic_spline'. 36 | Default is 'imq'. 37 | shape_factor : float | None, optional 38 | RBF shape parameter c. Required for IMQ and Gaussian kernels. 39 | Ignored for polyharmonic splines. 40 | kernel_order : int, optional 41 | Order for polyharmonic splines (default 3). 42 | Ignored for other kernels. 43 | 44 | Returns 45 | ------- 46 | Array 47 | Collocation matrix, shape (n_train_points, n_train_points). 48 | """ 49 | n_params, n_train = train_params.shape 50 | 51 | def accumulate_r2(i: int, r2: Array) -> Array: 52 | param_row = train_params[i, :] 53 | diff = param_row[:, None] - param_row[None, :] # (n_train, n_train) 54 | return r2 + (diff / params_range[i]) ** 2 55 | 56 | r2 = jax.lax.fori_loop(0, n_params, accumulate_r2, jnp.zeros((n_train, n_train))) 57 | 58 | return apply_kernel(r2, kernel, shape_factor, kernel_order) 59 | 60 | 61 | def build_inference_matrix( 62 | train_params: Array, 63 | inf_params: Array, 64 | params_range: Array, 65 | kernel: str = "imq", 66 | shape_factor: float | None = None, 67 | kernel_order: int = 3, 68 | ) -> Array: 69 | """ 70 | Build RBF inference matrix for prediction at new parameters. 71 | 72 | Parameters 73 | ---------- 74 | train_params : Array 75 | Training parameters, shape (n_params, n_train_points). 76 | inf_params : Array 77 | Inference parameters, shape (n_params, n_inf_points). 78 | params_range : Array 79 | Range of each parameter for normalization, shape (n_params,). 80 | kernel : str, optional 81 | Kernel type: 'imq', 'gaussian', or 'polyharmonic_spline'. 82 | Default is 'imq'. 83 | shape_factor : float | None, optional 84 | RBF shape parameter c. Required for IMQ and Gaussian kernels. 85 | Ignored for polyharmonic splines. 86 | kernel_order : int, optional 87 | Order for polyharmonic splines (default 3). 88 | Ignored for other kernels. 89 | 90 | Returns 91 | ------- 92 | Array 93 | Inference matrix, shape (n_inf_points, n_train_points). 
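Examples -------- >>> F = build_inference_matrix(train_params, inf_params, params_range, kernel="imq", shape_factor=0.5)  # illustrative values; one row per inference point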
94 | """ 95 | n_params = train_params.shape[0] 96 | n_inf = inf_params.shape[1] 97 | n_train = train_params.shape[1] 98 | 99 | def accumulate_r2(i: int, r2: Array) -> Array: 100 | diff = inf_params[i, :, None] - train_params[i, None, :] # (n_inf, n_train) 101 | return r2 + (diff / params_range[i]) ** 2 102 | 103 | r2 = jax.lax.fori_loop(0, n_params, accumulate_r2, jnp.zeros((n_inf, n_train))) 104 | 105 | return apply_kernel(r2, kernel, shape_factor, kernel_order) 106 | 107 | 108 | def build_polynomial_basis( 109 | params: Array, 110 | params_range: Array, 111 | degree: int = 2, 112 | ) -> Array: 113 | """ 114 | Build polynomial basis matrix for RBF augmentation. 115 | 116 | Parameters 117 | ---------- 118 | params : Array 119 | Parameters, shape (n_params, n_points). 120 | params_range : Array 121 | Range of each parameter for normalization, shape (n_params,). 122 | degree : int 123 | Polynomial degree (0=constant, 1=linear, 2=quadratic). 124 | Note: This should be a Python int, not a traced value. 125 | 126 | Returns 127 | ------- 128 | Array 129 | Polynomial basis matrix, shape (n_points, n_poly). 130 | - degree 0: [1] -> 1 column 131 | - degree 1: [1, p1, p2, ...] -> n_params + 1 columns 132 | - degree 2: [1, p1, ..., pn, p1², ..., pn², p1*p2, ...] -> (n+1)(n+2)/2 cols 133 | """ 134 | n_params, n_points = params.shape 135 | 136 | # Normalize parameters by dividing by range (consistent between train/inference) 137 | p_norm = params / params_range[:, None] 138 | 139 | # Build polynomial terms - degree must be a Python int for JIT compatibility 140 | terms = [jnp.ones((n_points,))] # constant term 141 | 142 | if degree >= 1: 143 | # Linear terms 144 | for i in range(n_params): 145 | terms.append(p_norm[i, :]) 146 | 147 | if degree >= 2: 148 | # Squared terms 149 | for i in range(n_params): 150 | terms.append(p_norm[i, :] ** 2) 151 | # Cross terms 152 | for i in range(n_params): 153 | for j in range(i + 1, n_params): 154 | terms.append(p_norm[i, :] * p_norm[j, :]) 155 | 156 | return jnp.stack(terms, axis=1) 157 | 158 | 159 | def solve_augmented_system_schur( 160 | F: Array, 161 | P: Array, 162 | rhs: Array, 163 | ) -> tuple[Array, Array]: 164 | """ 165 | Solve augmented RBF system via Schur complement. 166 | 167 | Solves the saddle-point system: 168 | [F P] [λ] [rhs] 169 | [P.T 0] [c] = [0] 170 | 171 | Using Schur complement: S = P.T @ F^{-1} @ P 172 | 173 | Parameters 174 | ---------- 175 | F : Array 176 | RBF collocation matrix, shape (n_train, n_train). Symmetric positive definite. 177 | P : Array 178 | Polynomial basis matrix, shape (n_train, n_poly). 179 | rhs : Array 180 | Right-hand side, shape (n_rhs, n_train). Each row is a separate RHS. 
181 | 182 | Returns 183 | ------- 184 | tuple[Array, Array] 185 | rbf_weights : shape (n_rhs, n_train) 186 | poly_coeffs : shape (n_rhs, n_poly) 187 | """ 188 | # Cholesky factorization of F (symmetric positive definite) 189 | cho_F = jla.cho_factor(F) 190 | 191 | # Solve F @ X = P for X = F^{-1} @ P 192 | F_inv_P = jla.cho_solve(cho_F, P) # (n_train, n_poly) 193 | 194 | # Schur complement: S = P.T @ F^{-1} @ P 195 | S = P.T @ F_inv_P # (n_poly, n_poly) 196 | 197 | # Solve F @ Y = rhs.T for Y = F^{-1} @ rhs.T 198 | F_inv_rhs = jla.cho_solve(cho_F, rhs.T) # (n_train, n_rhs) 199 | 200 | # Solve S @ c.T = P.T @ F^{-1} @ rhs.T for polynomial coefficients 201 | schur_rhs = P.T @ F_inv_rhs # (n_poly, n_rhs) 202 | poly_coeffs = jnp.linalg.solve(S, schur_rhs).T # (n_rhs, n_poly) 203 | 204 | # Back-substitute: λ = F^{-1} @ (rhs - P @ c) 205 | rbf_weights = (F_inv_rhs - F_inv_P @ poly_coeffs.T).T # (n_rhs, n_train) 206 | 207 | return rbf_weights, poly_coeffs 208 | 209 | 210 | def solve_augmented_system_direct( 211 | F: Array, 212 | P: Array, 213 | rhs: Array, 214 | ) -> tuple[Array, Array]: 215 | """ 216 | Solve augmented RBF system by direct assembly and solve. 217 | 218 | Solves the saddle-point system: 219 | [F P] [λ] [rhs] 220 | [P.T 0] [c] = [0] 221 | 222 | This method assembles and solves the full system directly, which works 223 | for kernels where F is not positive definite (e.g., polyharmonic splines). 224 | 225 | Parameters 226 | ---------- 227 | F : Array 228 | RBF collocation matrix, shape (n_train, n_train). 229 | P : Array 230 | Polynomial basis matrix, shape (n_train, n_poly). 231 | rhs : Array 232 | Right-hand side, shape (n_rhs, n_train). Each row is a separate RHS. 233 | 234 | Returns 235 | ------- 236 | tuple[Array, Array] 237 | rbf_weights : shape (n_rhs, n_train) 238 | poly_coeffs : shape (n_rhs, n_poly) 239 | """ 240 | n_train = F.shape[0] 241 | n_poly = P.shape[1] 242 | n_rhs = rhs.shape[0] 243 | 244 | # Assemble full augmented system matrix 245 | # [F P ] 246 | # [P' 0 ] 247 | top = jnp.hstack([F, P]) 248 | bottom = jnp.hstack([P.T, jnp.zeros((n_poly, n_poly))]) 249 | A_aug = jnp.vstack([top, bottom]) 250 | 251 | # Assemble augmented RHS 252 | # [rhs] 253 | # [0 ] 254 | rhs_aug = jnp.hstack([rhs, jnp.zeros((n_rhs, n_poly))]) 255 | 256 | # Solve the full system 257 | solution = jnp.linalg.solve(A_aug, rhs_aug.T).T # (n_rhs, n_train + n_poly) 258 | 259 | # Extract weights and polynomial coefficients 260 | rbf_weights = solution[:, :n_train] 261 | poly_coeffs = solution[:, n_train:] 262 | 263 | return rbf_weights, poly_coeffs 264 | -------------------------------------------------------------------------------- /tests/test_rbf.py: -------------------------------------------------------------------------------- 1 | """Tests for RBF matrix construction.""" 2 | 3 | import jax.numpy as jnp 4 | import pytest 5 | 6 | from pod_rbf.rbf import ( 7 | build_collocation_matrix, 8 | build_inference_matrix, 9 | build_polynomial_basis, 10 | solve_augmented_system_schur, 11 | ) 12 | 13 | 14 | class TestBuildCollocationMatrix: 15 | """Test RBF collocation matrix construction.""" 16 | 17 | def test_symmetry(self): 18 | """Collocation matrix should be symmetric.""" 19 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]]) 20 | params_range = jnp.array([4.0]) 21 | C = build_collocation_matrix(params, params_range, shape_factor=1.0) 22 | 23 | assert jnp.allclose(C, C.T), "Collocation matrix should be symmetric" 24 | 25 | def test_diagonal_ones(self): 26 | """Diagonal should be 1 (r=0 -> phi=1).""" 27 | 
params = jnp.array([[1.0, 2.0, 3.0]]) 28 | params_range = jnp.array([2.0]) 29 | C = build_collocation_matrix(params, params_range, shape_factor=1.0) 30 | 31 | assert jnp.allclose(jnp.diag(C), 1.0), "Diagonal elements should be 1" 32 | 33 | def test_shape(self): 34 | """Output shape should be (n_train, n_train).""" 35 | n_train = 7 36 | params = jnp.linspace(0, 10, n_train)[None, :] 37 | params_range = jnp.array([10.0]) 38 | C = build_collocation_matrix(params, params_range, shape_factor=0.5) 39 | 40 | assert C.shape == (n_train, n_train), f"Expected ({n_train}, {n_train}), got {C.shape}" 41 | 42 | def test_multi_param(self): 43 | """Should work with multiple parameters.""" 44 | n_train = 5 45 | n_params = 3 46 | params = jnp.array([ 47 | [1.0, 2.0, 3.0, 4.0, 5.0], 48 | [0.1, 0.2, 0.3, 0.4, 0.5], 49 | [10.0, 20.0, 30.0, 40.0, 50.0], 50 | ]) 51 | params_range = jnp.array([4.0, 0.4, 40.0]) 52 | C = build_collocation_matrix(params, params_range, shape_factor=1.0) 53 | 54 | assert C.shape == (n_train, n_train) 55 | assert jnp.allclose(C, C.T), "Multi-param collocation should be symmetric" 56 | assert jnp.allclose(jnp.diag(C), 1.0), "Diagonal should still be 1" 57 | 58 | def test_values_positive(self): 59 | """All values should be positive (IMQ kernel is always positive).""" 60 | params = jnp.array([[1.0, 5.0, 10.0, 20.0]]) 61 | params_range = jnp.array([19.0]) 62 | C = build_collocation_matrix(params, params_range, shape_factor=0.5) 63 | 64 | assert jnp.all(C > 0), "All values should be positive" 65 | 66 | def test_values_bounded(self): 67 | """Values should be in (0, 1] for IMQ kernel.""" 68 | params = jnp.array([[1.0, 5.0, 10.0, 20.0]]) 69 | params_range = jnp.array([19.0]) 70 | C = build_collocation_matrix(params, params_range, shape_factor=0.5) 71 | 72 | assert jnp.all(C > 0), "All values should be > 0" 73 | assert jnp.all(C <= 1.0), "All values should be <= 1" 74 | 75 | 76 | class TestBuildInferenceMatrix: 77 | """Test RBF inference matrix construction.""" 78 | 79 | def test_shape(self): 80 | """Output shape should be (n_inf, n_train).""" 81 | n_train = 5 82 | n_inf = 3 83 | train_params = jnp.linspace(0, 10, n_train)[None, :] 84 | inf_params = jnp.array([[2.5, 5.0, 7.5]]) 85 | params_range = jnp.array([10.0]) 86 | 87 | F = build_inference_matrix(train_params, inf_params, params_range, shape_factor=1.0) 88 | 89 | assert F.shape == (n_inf, n_train), f"Expected ({n_inf}, {n_train}), got {F.shape}" 90 | 91 | def test_at_training_points(self): 92 | """Inference at training points should have max value 1 at that column.""" 93 | train_params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]]) 94 | inf_params = jnp.array([[3.0]]) # Exactly at training point 95 | params_range = jnp.array([4.0]) 96 | 97 | F = build_inference_matrix(train_params, inf_params, params_range, shape_factor=1.0) 98 | 99 | # At training point index 2, value should be 1 100 | assert jnp.isclose(F[0, 2], 1.0), f"Expected 1.0 at training point, got {F[0, 2]}" 101 | 102 | def test_values_positive(self): 103 | """All values should be positive.""" 104 | train_params = jnp.array([[1.0, 5.0, 10.0]]) 105 | inf_params = jnp.array([[2.0, 7.0]]) 106 | params_range = jnp.array([9.0]) 107 | 108 | F = build_inference_matrix(train_params, inf_params, params_range, shape_factor=0.5) 109 | 110 | assert jnp.all(F > 0), "All values should be positive" 111 | 112 | def test_multi_param(self): 113 | """Should work with multiple parameters.""" 114 | train_params = jnp.array([ 115 | [1.0, 2.0, 3.0], 116 | [0.1, 0.2, 0.3], 117 | ]) 118 | inf_params = 
jnp.array([ 119 | [1.5, 2.5], 120 | [0.15, 0.25], 121 | ]) 122 | params_range = jnp.array([2.0, 0.2]) 123 | 124 | F = build_inference_matrix(train_params, inf_params, params_range, shape_factor=1.0) 125 | 126 | assert F.shape == (2, 3) 127 | assert jnp.all(F > 0) 128 | 129 | 130 | class TestBuildPolynomialBasis: 131 | """Test polynomial basis matrix construction.""" 132 | 133 | def test_degree_0_shape(self): 134 | """Degree 0 should return constant column only.""" 135 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]]) 136 | params_range = jnp.array([4.0]) 137 | P = build_polynomial_basis(params, params_range, degree=0) 138 | 139 | assert P.shape == (5, 1), f"Expected (5, 1), got {P.shape}" 140 | assert jnp.allclose(P[:, 0], 1.0), "Constant column should be all ones" 141 | 142 | def test_degree_1_shape(self): 143 | """Degree 1 should return constant + linear terms.""" 144 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]]) 145 | params_range = jnp.array([4.0]) 146 | P = build_polynomial_basis(params, params_range, degree=1) 147 | 148 | # 1D: [1, p] -> 2 columns 149 | assert P.shape == (5, 2), f"Expected (5, 2), got {P.shape}" 150 | assert jnp.allclose(P[:, 0], 1.0), "First column should be all ones" 151 | 152 | def test_degree_2_shape_1d(self): 153 | """Degree 2 with 1 param should return 3 columns.""" 154 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]]) 155 | params_range = jnp.array([4.0]) 156 | P = build_polynomial_basis(params, params_range, degree=2) 157 | 158 | # 1D: [1, p, p²] -> 3 columns 159 | assert P.shape == (5, 3), f"Expected (5, 3), got {P.shape}" 160 | 161 | def test_degree_2_shape_2d(self): 162 | """Degree 2 with 2 params should return 6 columns.""" 163 | params = jnp.array([ 164 | [1.0, 2.0, 3.0, 4.0, 5.0], 165 | [0.1, 0.2, 0.3, 0.4, 0.5], 166 | ]) 167 | params_range = jnp.array([4.0, 0.4]) 168 | P = build_polynomial_basis(params, params_range, degree=2) 169 | 170 | # 2D: [1, p1, p2, p1², p2², p1*p2] -> 6 columns 171 | assert P.shape == (5, 6), f"Expected (5, 6), got {P.shape}" 172 | 173 | def test_normalized_values(self): 174 | """Polynomial values should scale appropriately with params_range.""" 175 | params = jnp.array([[10.0, 20.0, 30.0, 40.0, 50.0]]) 176 | params_range = jnp.array([40.0]) 177 | P = build_polynomial_basis(params, params_range, degree=2) 178 | 179 | # Normalized values are params/range, so for [10, 20, 30, 40, 50]/40 = [0.25, 0.5, 0.75, 1.0, 1.25] 180 | # Constant term should be 1 181 | assert jnp.allclose(P[:, 0], 1.0), "Constant column should be all ones" 182 | # Linear terms should be params/range 183 | expected_linear = jnp.array([0.25, 0.5, 0.75, 1.0, 1.25]) 184 | assert jnp.allclose(P[:, 1], expected_linear), f"Linear term mismatch: {P[:, 1]} vs {expected_linear}" 185 | 186 | def test_multi_param_cross_terms(self): 187 | """Cross terms should be computed correctly for multi-param case.""" 188 | params = jnp.array([ 189 | [0.0, 2.0, 4.0], 190 | [0.0, 2.0, 4.0], 191 | ]) 192 | params_range = jnp.array([4.0, 4.0]) 193 | P = build_polynomial_basis(params, params_range, degree=2) 194 | 195 | # Columns: [1, p1, p2, p1², p2², p1*p2] 196 | # Normalized: [0, 0.5, 1] for both params 197 | assert P.shape == (3, 6) 198 | # Last column is cross term p1*p2 199 | # At normalized values [0, 0.5, 1], cross terms are [0, 0.25, 1] 200 | expected_cross = jnp.array([0.0, 0.25, 1.0]) 201 | assert jnp.allclose(P[:, 5], expected_cross), f"Cross term mismatch: {P[:, 5]} vs {expected_cross}" 202 | 203 | 204 | class TestSolveAugmentedSystemSchur: 205 | """Test Schur complement 
solver.""" 206 | 207 | def test_simple_system(self): 208 | """Solve a simple augmented system and verify solution.""" 209 | # Simple 3x3 SPD matrix 210 | F = jnp.array([ 211 | [4.0, 1.0, 0.5], 212 | [1.0, 3.0, 0.5], 213 | [0.5, 0.5, 2.0], 214 | ]) 215 | # Linear polynomial basis 216 | P = jnp.array([ 217 | [1.0, 0.0], 218 | [1.0, 0.5], 219 | [1.0, 1.0], 220 | ]) 221 | # Single RHS 222 | rhs = jnp.array([[1.0, 2.0, 3.0]]) 223 | 224 | rbf_weights, poly_coeffs = solve_augmented_system_schur(F, P, rhs) 225 | 226 | assert rbf_weights.shape == (1, 3), f"RBF weights shape: {rbf_weights.shape}" 227 | assert poly_coeffs.shape == (1, 2), f"Poly coeffs shape: {poly_coeffs.shape}" 228 | 229 | # Verify solution satisfies augmented system 230 | # F @ λ + P @ c = rhs 231 | lhs1 = F @ rbf_weights.T + P @ poly_coeffs.T 232 | assert jnp.allclose(lhs1.T, rhs, rtol=1e-5), f"First equation not satisfied: {lhs1.T} vs {rhs}" 233 | 234 | # P.T @ λ = 0 (orthogonality constraint) 235 | lhs2 = P.T @ rbf_weights.T 236 | assert jnp.allclose(lhs2, 0, atol=1e-10), f"Orthogonality constraint not satisfied: {lhs2}" 237 | 238 | def test_multiple_rhs(self): 239 | """Solve system with multiple right-hand sides.""" 240 | F = jnp.array([ 241 | [4.0, 1.0, 0.5], 242 | [1.0, 3.0, 0.5], 243 | [0.5, 0.5, 2.0], 244 | ]) 245 | P = jnp.array([ 246 | [1.0, 0.0], 247 | [1.0, 0.5], 248 | [1.0, 1.0], 249 | ]) 250 | # Two RHS 251 | rhs = jnp.array([ 252 | [1.0, 2.0, 3.0], 253 | [0.5, 1.0, 1.5], 254 | ]) 255 | 256 | rbf_weights, poly_coeffs = solve_augmented_system_schur(F, P, rhs) 257 | 258 | assert rbf_weights.shape == (2, 3) 259 | assert poly_coeffs.shape == (2, 2) 260 | 261 | # Verify both solutions satisfy the system 262 | for i in range(2): 263 | lhs1 = F @ rbf_weights[i, :] + P @ poly_coeffs[i, :] 264 | assert jnp.allclose(lhs1, rhs[i, :], rtol=1e-5), f"RHS {i}: First equation not satisfied" 265 | 266 | lhs2 = P.T @ rbf_weights[i, :] 267 | assert jnp.allclose(lhs2, 0, atol=1e-10), f"RHS {i}: Orthogonality constraint not satisfied" 268 | 269 | def test_polynomial_reproduction(self): 270 | """Schur complement should reproduce polynomials exactly.""" 271 | # If RHS is in the span of P, polynomial coeffs should capture it fully 272 | n_points = 5 273 | params = jnp.linspace(0, 1, n_points)[None, :] 274 | params_range = jnp.array([1.0]) 275 | 276 | F = build_collocation_matrix(params, params_range, shape_factor=0.5) 277 | P = build_polynomial_basis(params, params_range, degree=1) # [1, p] 278 | 279 | # RHS that's a linear function: f = 2 + 3*p 280 | rhs = 2.0 + 3.0 * params # shape (1, n_points) 281 | 282 | rbf_weights, poly_coeffs = solve_augmented_system_schur(F, P, rhs) 283 | 284 | # RBF weights should be near zero (polynomial captures everything) 285 | assert jnp.allclose(rbf_weights, 0, atol=1e-8), f"RBF weights should be ~0 for polynomial RHS: {rbf_weights}" 286 | 287 | # Poly coeffs should be [2, 3] 288 | assert jnp.allclose(poly_coeffs[0, 0], 2.0, rtol=1e-5), f"Constant coeff: {poly_coeffs[0, 0]}" 289 | assert jnp.allclose(poly_coeffs[0, 1], 3.0, rtol=1e-5), f"Linear coeff: {poly_coeffs[0, 1]}" 290 | 291 | 292 | class TestCollocationMatrixKernels: 293 | """Test collocation matrix with different kernel types.""" 294 | 295 | def test_imq_kernel(self): 296 | """Test collocation matrix with IMQ kernel.""" 297 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]]) 298 | params_range = jnp.array([4.0]) 299 | C = build_collocation_matrix( 300 | params, params_range, kernel="imq", shape_factor=1.0, kernel_order=3 301 | ) 302 | 303 | assert 
C.shape == (5, 5) 304 | assert jnp.allclose(C, C.T) 305 | assert jnp.allclose(jnp.diag(C), 1.0) 306 | assert jnp.all(C > 0) and jnp.all(C <= 1.0) 307 | 308 | def test_gaussian_kernel(self): 309 | """Test collocation matrix with Gaussian kernel.""" 310 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]]) 311 | params_range = jnp.array([4.0]) 312 | C = build_collocation_matrix( 313 | params, params_range, kernel="gaussian", shape_factor=1.0, kernel_order=3 314 | ) 315 | 316 | assert C.shape == (5, 5) 317 | assert jnp.allclose(C, C.T) 318 | assert jnp.allclose(jnp.diag(C), 1.0) 319 | assert jnp.all(C > 0) and jnp.all(C <= 1.0) 320 | 321 | def test_phs_kernel(self): 322 | """Test collocation matrix with polyharmonic spline kernel.""" 323 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]]) 324 | params_range = jnp.array([4.0]) 325 | C = build_collocation_matrix( 326 | params, 327 | params_range, 328 | kernel="polyharmonic_spline", 329 | shape_factor=None, 330 | kernel_order=3, 331 | ) 332 | 333 | assert C.shape == (5, 5) 334 | assert jnp.allclose(C, C.T) 335 | assert jnp.allclose(jnp.diag(C), 0.0) # PHS is 0 at r=0 336 | 337 | def test_kernels_produce_different_matrices(self): 338 | """Different kernels should produce different matrices.""" 339 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]]) 340 | params_range = jnp.array([4.0]) 341 | 342 | C_imq = build_collocation_matrix( 343 | params, params_range, kernel="imq", shape_factor=1.0 344 | ) 345 | C_gauss = build_collocation_matrix( 346 | params, params_range, kernel="gaussian", shape_factor=1.0 347 | ) 348 | C_phs = build_collocation_matrix( 349 | params, params_range, kernel="polyharmonic_spline", kernel_order=3 350 | ) 351 | 352 | # Matrices should be different 353 | assert not jnp.allclose(C_imq, C_gauss) 354 | assert not jnp.allclose(C_imq, C_phs) 355 | assert not jnp.allclose(C_gauss, C_phs) 356 | 357 | 358 | class TestInferenceMatrixKernels: 359 | """Test inference matrix with different kernel types.""" 360 | 361 | def test_imq_kernel(self): 362 | """Test inference matrix with IMQ kernel.""" 363 | train_params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]]) 364 | inf_params = jnp.array([[2.5, 3.5]]) 365 | params_range = jnp.array([4.0]) 366 | 367 | F = build_inference_matrix( 368 | train_params, 369 | inf_params, 370 | params_range, 371 | kernel="imq", 372 | shape_factor=1.0, 373 | kernel_order=3, 374 | ) 375 | 376 | assert F.shape == (2, 5) 377 | assert jnp.all(F > 0) 378 | 379 | def test_gaussian_kernel(self): 380 | """Test inference matrix with Gaussian kernel.""" 381 | train_params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]]) 382 | inf_params = jnp.array([[2.5, 3.5]]) 383 | params_range = jnp.array([4.0]) 384 | 385 | F = build_inference_matrix( 386 | train_params, 387 | inf_params, 388 | params_range, 389 | kernel="gaussian", 390 | shape_factor=1.0, 391 | kernel_order=3, 392 | ) 393 | 394 | assert F.shape == (2, 5) 395 | assert jnp.all(F > 0) 396 | 397 | def test_phs_kernel(self): 398 | """Test inference matrix with polyharmonic spline kernel.""" 399 | train_params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]]) 400 | inf_params = jnp.array([[2.5, 3.5]]) 401 | params_range = jnp.array([4.0]) 402 | 403 | F = build_inference_matrix( 404 | train_params, 405 | inf_params, 406 | params_range, 407 | kernel="polyharmonic_spline", 408 | shape_factor=None, 409 | kernel_order=3, 410 | ) 411 | 412 | assert F.shape == (2, 5) 413 | 414 | def test_at_training_point_all_kernels(self): 415 | """All kernels should have value 1 (IMQ/Gauss) or 0 (PHS) at r=0.""" 416 | 
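# IMQ and Gaussian are normalized to phi(0) = 1, while a polyharmonic
# spline phi(r) = r^k vanishes at r = 0, so probing a point that coincides
# with a training point distinguishes the two kernel families.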
train_params = jnp.array([[1.0, 2.0, 3.0]]) 417 | inf_params = jnp.array([[2.0]]) # Exactly at training point 418 | params_range = jnp.array([2.0]) 419 | 420 | # IMQ 421 | F_imq = build_inference_matrix( 422 | train_params, inf_params, params_range, kernel="imq", shape_factor=1.0 423 | ) 424 | assert jnp.isclose(F_imq[0, 1], 1.0) 425 | 426 | # Gaussian 427 | F_gauss = build_inference_matrix( 428 | train_params, 429 | inf_params, 430 | params_range, 431 | kernel="gaussian", 432 | shape_factor=1.0, 433 | ) 434 | assert jnp.isclose(F_gauss[0, 1], 1.0) 435 | 436 | # PHS 437 | F_phs = build_inference_matrix( 438 | train_params, 439 | inf_params, 440 | params_range, 441 | kernel="polyharmonic_spline", 442 | kernel_order=3, 443 | ) 444 | assert jnp.isclose(F_phs[0, 1], 0.0) 445 | -------------------------------------------------------------------------------- /examples/shape-optimization/optimize_shape.py: -------------------------------------------------------------------------------- 1 | """ 2 | Shape Parameter Optimization via Autodiff 3 | 4 | Demonstrates using JAX autodifferentiation with JAXopt L-BFGS to find 5 | the optimal RBF shape parameter using the Rippa criterion (LOO-CV). 6 | 7 | The Rippa criterion provides a closed-form leave-one-out cross-validation 8 | error that is differentiable with respect to the shape parameter, enabling 9 | gradient-based optimization. 10 | 11 | Key concepts demonstrated: 12 | 1. Rippa criterion (LOO-CV) for RBF shape parameter selection 13 | 2. JAX autodiff through RBF matrix construction 14 | 3. JAXopt L-BFGS optimizer for scientific computing 15 | 4. Comparison with POD-RBF's condition-number based auto-optimization 16 | 17 | Requirements: 18 | pip install jaxopt 19 | 20 | References: 21 | Rippa, S. (1999). "An algorithm for selecting a good value for the 22 | parameter c in radial basis function interpolation." 23 | Advances in Computational Mathematics, 11(2-3), 193-210. 24 | """ 25 | 26 | import time 27 | 28 | import jax 29 | import jax.numpy as jnp 30 | import jaxopt 31 | import matplotlib.pyplot as plt 32 | import numpy as np 33 | 34 | from pod_rbf.rbf import build_collocation_matrix 35 | 36 | # Use CPU for reproducibility 37 | jax.config.update("jax_default_device", jax.devices("cpu")[0]) 38 | 39 | 40 | # ============================================================================= 41 | # Test Function 42 | # ============================================================================= 43 | 44 | 45 | def runge_function(x): 46 | """ 47 | Runge function - a classic test case for interpolation. 48 | 49 | f(x) = 1 / (1 + 25x^2) 50 | 51 | This function has a sharp peak at x=0 and is challenging to interpolate 52 | accurately, making it ideal for demonstrating the importance of shape 53 | parameter selection. 54 | """ 55 | return 1.0 / (1.0 + 25.0 * x**2) 56 | 57 | 58 | # ============================================================================= 59 | # Rippa Criterion (LOO-CV Cost Function) 60 | # ============================================================================= 61 | 62 | 63 | def loocv_cost(shape_factor, x, y, kernel="imq"): 64 | """ 65 | Compute leave-one-out cross-validation error using Rippa's closed-form formula. 66 | 67 | The Rippa criterion computes the LOO-CV error without actually performing 68 | n separate leave-one-out fits. 
For RBF interpolation A @ c = y, the LOO 69 | error at point i is: 70 | 71 | e_i = c_i / A_inv_ii 72 | 73 | where c_i is the i-th interpolation coefficient and A_inv_ii is the i-th 74 | diagonal element of A^{-1}. 75 | 76 | Parameters 77 | ---------- 78 | shape_factor : float 79 | RBF shape parameter to evaluate. 80 | x : Array 81 | Training point locations, shape (n_points,). 82 | y : Array 83 | Training values, shape (n_points,). 84 | kernel : str 85 | Kernel type: 'imq' or 'gaussian'. 86 | 87 | Returns 88 | ------- 89 | float 90 | Mean squared LOO-CV error. 91 | """ 92 | # Build collocation matrix with current shape factor 93 | # pod_rbf expects (n_params, n_points) shape 94 | x_2d = x[None, :] 95 | x_range = jnp.array([jnp.ptp(x)]) 96 | 97 | A = build_collocation_matrix(x_2d, x_range, kernel=kernel, shape_factor=shape_factor) 98 | 99 | # Solve for RBF coefficients: A @ c = y 100 | c = jnp.linalg.solve(A, y) 101 | 102 | # Compute diagonal of A^{-1} 103 | # For numerical stability, we could use the formula: 104 | # diag(A^{-1})_i = e_i^T @ A^{-1} @ e_i 105 | # But direct inversion is fine for moderate problem sizes 106 | A_inv = jnp.linalg.inv(A) 107 | A_inv_diag = jnp.diag(A_inv) 108 | 109 | # Rippa criterion: LOO error at each point 110 | loo_errors = c / A_inv_diag 111 | 112 | # Return mean squared error 113 | return jnp.mean(loo_errors**2) 114 | 115 | 116 | # ============================================================================= 117 | # Shape Parameter Optimization 118 | # ============================================================================= 119 | 120 | 121 | def optimize_shape_parameter(x, y, kernel="imq", initial_guess=1.0, verbose=True): 122 | """ 123 | Find optimal shape parameter using L-BFGS on Rippa criterion. 124 | 125 | Uses JAXopt's L-BFGS optimizer to minimize the LOO-CV error, with gradients 126 | computed automatically via JAX autodiff. 127 | 128 | We optimize in log-space (log_c) to ensure positivity without breaking gradients. 129 | shape_factor = exp(log_c) 130 | 131 | Parameters 132 | ---------- 133 | x : Array 134 | Training point locations. 135 | y : Array 136 | Training values. 137 | kernel : str 138 | Kernel type. 139 | initial_guess : float 140 | Starting value for shape parameter. 141 | verbose : bool 142 | Print optimization progress. 143 | 144 | Returns 145 | ------- 146 | optimal_shape : float 147 | Optimized shape parameter. 148 | history : dict 149 | Optimization history with shape parameters, costs, and gradients. 
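Examples
--------
A minimal sketch, assuming ``x`` and ``y`` are 1D jnp arrays of equal
length (``rbf_interpolate`` is defined later in this script):

>>> c_opt, hist = optimize_shape_parameter(x, y, kernel="imq", initial_guess=0.5, verbose=False)
>>> y_fit = rbf_interpolate(x, y, x, c_opt, kernel="imq")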
150 | """ 151 | history = {"shape_factor": [initial_guess], "cost": [], "grad_log": []} 152 | 153 | # Optimize in log-space for unconstrained optimization with guaranteed positivity 154 | # shape_factor = exp(log_c), so we optimize log_c 155 | def objective(log_c): 156 | shape_factor = jnp.exp(log_c[0]) 157 | return loocv_cost(shape_factor, x, y, kernel) 158 | 159 | # JIT compile for speed 160 | objective_jit = jax.jit(objective) 161 | grad_fn = jax.jit(jax.grad(objective)) 162 | 163 | # Initial evaluation (in log space) 164 | log_c_init = jnp.log(initial_guess) 165 | init_cost = objective_jit(jnp.array([log_c_init])) 166 | init_grad = grad_fn(jnp.array([log_c_init])) 167 | history["cost"].append(float(init_cost)) 168 | history["grad_log"].append(float(init_grad[0])) 169 | 170 | if verbose: 171 | print(f"Initial: shape_factor = {initial_guess:.6f}, LOO-CV cost = {init_cost:.6e}") 172 | 173 | # Create L-BFGS solver 174 | solver = jaxopt.LBFGS( 175 | fun=objective, 176 | maxiter=100, 177 | tol=1e-10, 178 | ) 179 | 180 | # Initialize in log space 181 | log_c = jnp.array([log_c_init]) 182 | 183 | # Run optimization with manual iteration to track history 184 | state = solver.init_state(log_c) 185 | 186 | for i in range(100): 187 | log_c, state = solver.update(log_c, state) 188 | cost = objective_jit(log_c) 189 | grad = grad_fn(log_c) 190 | 191 | shape_factor = float(jnp.exp(log_c[0])) 192 | history["shape_factor"].append(shape_factor) 193 | history["cost"].append(float(cost)) 194 | history["grad_log"].append(float(grad[0])) 195 | 196 | if verbose and (i + 1) % 10 == 0: 197 | print( 198 | f" Iter {i+1:3d}: shape_factor = {shape_factor:.6f}, " 199 | f"cost = {cost:.6e}, |grad| = {jnp.abs(grad[0]):.2e}" 200 | ) 201 | 202 | # Check convergence (gradient in log space) 203 | if jnp.abs(grad[0]) < 1e-10: 204 | if verbose: 205 | print(f" Converged at iteration {i+1}") 206 | break 207 | 208 | optimal_shape = float(jnp.exp(log_c[0])) 209 | 210 | if verbose: 211 | print(f"\nOptimal shape parameter: {optimal_shape:.6f}") 212 | print(f"Final LOO-CV cost: {history['cost'][-1]:.6e}") 213 | 214 | return optimal_shape, history 215 | 216 | 217 | # ============================================================================= 218 | # RBF Interpolation Helper 219 | # ============================================================================= 220 | 221 | 222 | def rbf_interpolate(x_train, y_train, x_eval, shape_factor, kernel="imq"): 223 | """ 224 | Perform RBF interpolation at evaluation points. 225 | 226 | Parameters 227 | ---------- 228 | x_train : Array 229 | Training point locations. 230 | y_train : Array 231 | Training values. 232 | x_eval : Array 233 | Points at which to evaluate the interpolant. 234 | shape_factor : float 235 | RBF shape parameter. 236 | kernel : str 237 | Kernel type. 238 | 239 | Returns 240 | ------- 241 | Array 242 | Interpolated values at x_eval. 
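Notes
-----
This is plain (non-augmented) RBF interpolation: solve the collocation
system A @ c = y_train for the coefficients, then evaluate F @ c, where
F holds the kernel values between evaluation and training points.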
243 | """ 244 | from pod_rbf.rbf import build_inference_matrix 245 | 246 | x_train_2d = x_train[None, :] 247 | x_eval_2d = x_eval[None, :] 248 | x_range = jnp.array([jnp.ptp(x_train)]) 249 | 250 | # Build collocation matrix and solve for coefficients 251 | A = build_collocation_matrix( 252 | x_train_2d, x_range, kernel=kernel, shape_factor=shape_factor 253 | ) 254 | c = jnp.linalg.solve(A, y_train) 255 | 256 | # Build inference matrix and evaluate 257 | F = build_inference_matrix( 258 | x_train_2d, x_eval_2d, x_range, kernel=kernel, shape_factor=shape_factor 259 | ) 260 | 261 | return F @ c 262 | 263 | 264 | # ============================================================================= 265 | # Visualization 266 | # ============================================================================= 267 | 268 | 269 | def plot_loocv_landscape(x, y, optimal_shape, kernel="imq"): 270 | """Plot LOO-CV cost as a function of shape parameter.""" 271 | shape_factors = np.logspace(-2, 1, 100) 272 | costs = [] 273 | 274 | loocv_jit = jax.jit(lambda c: loocv_cost(c, x, y, kernel)) 275 | 276 | for c in shape_factors: 277 | costs.append(float(loocv_jit(c))) 278 | 279 | fig, ax = plt.subplots(figsize=(8, 5)) 280 | ax.semilogy(shape_factors, costs, "b-", linewidth=2, label="LOO-CV cost") 281 | ax.axvline( 282 | optimal_shape, color="r", linestyle="--", linewidth=2, label=f"Optimal: {optimal_shape:.4f}" 283 | ) 284 | ax.set_xlabel("Shape Parameter (c)", fontsize=12) 285 | ax.set_ylabel("LOO-CV Cost (log scale)", fontsize=12) 286 | ax.set_title("LOO-CV Error Landscape", fontsize=14) 287 | ax.legend(fontsize=11) 288 | ax.grid(True, alpha=0.3) 289 | ax.set_xscale("log") 290 | 291 | return fig 292 | 293 | 294 | def plot_interpolation_comparison(x_train, y_train, x_eval, y_true, shape_factors, kernel="imq"): 295 | """Compare interpolation with different shape parameters.""" 296 | n_shapes = len(shape_factors) 297 | fig, axes = plt.subplots(1, n_shapes, figsize=(5 * n_shapes, 4)) 298 | 299 | if n_shapes == 1: 300 | axes = [axes] 301 | 302 | for ax, shape_factor in zip(axes, shape_factors): 303 | # Interpolate 304 | y_interp = rbf_interpolate(x_train, y_train, x_eval, shape_factor, kernel) 305 | 306 | # Compute error 307 | rmse = np.sqrt(np.mean((np.array(y_interp) - np.array(y_true)) ** 2)) 308 | 309 | # Plot 310 | ax.plot(x_eval, y_true, "k-", linewidth=2, label="True function") 311 | ax.plot(x_eval, y_interp, "b--", linewidth=2, label="RBF interpolant") 312 | ax.scatter(x_train, y_train, c="r", s=50, zorder=5, label="Training points") 313 | 314 | ax.set_xlabel("x", fontsize=11) 315 | ax.set_ylabel("f(x)", fontsize=11) 316 | ax.set_title(f"c = {shape_factor:.4f}, RMSE = {rmse:.2e}", fontsize=12) 317 | ax.legend(fontsize=9) 318 | ax.grid(True, alpha=0.3) 319 | 320 | plt.tight_layout() 321 | return fig 322 | 323 | 324 | def plot_convergence(history): 325 | """Plot optimization convergence.""" 326 | fig, axes = plt.subplots(1, 3, figsize=(15, 4)) 327 | 328 | # Cost convergence 329 | ax = axes[0] 330 | ax.semilogy(history["cost"], "b-", linewidth=2, marker="o", markersize=4) 331 | ax.set_xlabel("Iteration", fontsize=12) 332 | ax.set_ylabel("LOO-CV Cost (log scale)", fontsize=12) 333 | ax.set_title("Cost Convergence", fontsize=14) 334 | ax.grid(True, alpha=0.3) 335 | 336 | # Shape parameter evolution 337 | ax = axes[1] 338 | ax.plot(history["shape_factor"], "r-", linewidth=2, marker="o", markersize=4) 339 | ax.set_xlabel("Iteration", fontsize=12) 340 | ax.set_ylabel("Shape Parameter", fontsize=12) 341 | 
ax.set_title("Shape Parameter Evolution", fontsize=14) 342 | ax.grid(True, alpha=0.3) 343 | 344 | # Gradient magnitude (in log space) 345 | ax = axes[2] 346 | grad_key = "grad_log" if "grad_log" in history else "grad" 347 | ax.semilogy(np.abs(history[grad_key]), "g-", linewidth=2, marker="o", markersize=4) 348 | ax.set_xlabel("Iteration", fontsize=12) 349 | ax.set_ylabel("|Gradient| (log scale)", fontsize=12) 350 | ax.set_title("Gradient Magnitude", fontsize=14) 351 | ax.grid(True, alpha=0.3) 352 | 353 | plt.tight_layout() 354 | return fig 355 | 356 | 357 | # ============================================================================= 358 | # Main 359 | # ============================================================================= 360 | 361 | 362 | def find_good_initial_guess(x, y, kernel="imq"): 363 | """Scan a range of shape parameters to find a good starting point.""" 364 | loocv_jit = jax.jit(lambda c: loocv_cost(c, x, y, kernel)) 365 | 366 | # Scan over a range of shape parameters 367 | shape_factors = np.logspace(-2, 1, 50) 368 | costs = [] 369 | 370 | for c in shape_factors: 371 | try: 372 | cost = float(loocv_jit(c)) 373 | if np.isfinite(cost): 374 | costs.append(cost) 375 | else: 376 | costs.append(np.inf) 377 | except Exception: 378 | costs.append(np.inf) 379 | 380 | # Find minimum 381 | best_idx = np.argmin(costs) 382 | return shape_factors[best_idx], shape_factors, costs 383 | 384 | 385 | def main(): 386 | print("=" * 70) 387 | print("Shape Parameter Optimization via Autodiff (Rippa Criterion + L-BFGS)") 388 | print("=" * 70) 389 | 390 | # Setup 391 | np.random.seed(42) 392 | kernel = "imq" 393 | 394 | # Generate training data (Runge function) 395 | n_train = 15 396 | x_train = jnp.linspace(-1, 1, n_train) 397 | y_train = runge_function(x_train) 398 | 399 | # Dense evaluation grid 400 | x_eval = jnp.linspace(-1, 1, 200) 401 | y_true = runge_function(x_eval) 402 | 403 | print(f"\nTest problem: Runge function f(x) = 1/(1 + 25x^2)") 404 | print(f"Training points: {n_train}") 405 | print(f"Kernel: {kernel.upper()}") 406 | 407 | # ========================================================================== 408 | # First, scan the landscape to find a good initial guess 409 | # ========================================================================== 410 | print("\n" + "-" * 70) 411 | print("Scanning LOO-CV landscape for good initial guess...") 412 | print("-" * 70) 413 | 414 | initial_guess, scan_shapes, scan_costs = find_good_initial_guess(x_train, y_train, kernel) 415 | print(f"Best from scan: shape_factor = {initial_guess:.6f}, LOO-CV = {np.min(scan_costs):.6e}") 416 | 417 | # ========================================================================== 418 | # Optimize shape parameter using Rippa criterion + L-BFGS 419 | # ========================================================================== 420 | print("\n" + "-" * 70) 421 | print("Refining with JAXopt L-BFGS...") 422 | print("-" * 70) 423 | 424 | start = time.time() 425 | optimal_shape, history = optimize_shape_parameter( 426 | x_train, y_train, kernel=kernel, initial_guess=initial_guess, verbose=True 427 | ) 428 | opt_time = time.time() - start 429 | print(f"Optimization time: {opt_time:.3f} sec") 430 | 431 | # ========================================================================== 432 | # Compare with different shape parameters 433 | # ========================================================================== 434 | print("\n" + "-" * 70) 435 | print("Comparing interpolation accuracy...") 436 | print("-" * 70) 437 | 438 
| # Test several shape parameters 439 | test_shapes = [0.1, optimal_shape, 2.0] 440 | shape_labels = ["Too small (0.1)", f"Optimal ({optimal_shape:.4f})", "Too large (2.0)"] 441 | 442 | for shape, label in zip(test_shapes, shape_labels): 443 | y_interp = rbf_interpolate(x_train, y_train, x_eval, shape, kernel) 444 | rmse = np.sqrt(np.mean((np.array(y_interp) - np.array(y_true)) ** 2)) 445 | loocv = loocv_cost(shape, x_train, y_train, kernel) 446 | print(f" {label:25s}: RMSE = {rmse:.4e}, LOO-CV = {loocv:.4e}") 447 | 448 | # ========================================================================== 449 | # Demonstrate autodiff explicitly 450 | # ========================================================================== 451 | print("\n" + "-" * 70) 452 | print("Demonstrating autodiff capabilities...") 453 | print("-" * 70) 454 | 455 | # Show gradient computation 456 | grad_fn = jax.jit(jax.grad(lambda c: loocv_cost(c, x_train, y_train, kernel))) 457 | 458 | for c in [0.1, 0.5, optimal_shape, 2.0]: 459 | grad = grad_fn(c) 460 | print(f" d(LOO-CV)/dc at c={c:.4f}: {grad:.6e}") 461 | 462 | # Show Hessian computation 463 | hess_fn = jax.jit(jax.hessian(lambda c: loocv_cost(c, x_train, y_train, kernel))) 464 | hess_at_opt = hess_fn(optimal_shape) 465 | print(f"\n d²(LOO-CV)/dc² at optimal c={optimal_shape:.4f}: {hess_at_opt:.6e}") 466 | print(" (Positive Hessian confirms this is a local minimum)") 467 | 468 | # ========================================================================== 469 | # Visualizations 470 | # ========================================================================== 471 | print("\n" + "-" * 70) 472 | print("Generating plots...") 473 | print("-" * 70) 474 | 475 | # Plot 1: LOO-CV landscape 476 | fig1 = plot_loocv_landscape(x_train, y_train, optimal_shape, kernel) 477 | 478 | # Plot 2: Interpolation comparison 479 | fig2 = plot_interpolation_comparison( 480 | x_train, y_train, x_eval, y_true, test_shapes, kernel 481 | ) 482 | 483 | # Plot 3: Convergence history 484 | fig3 = plot_convergence(history) 485 | 486 | print("\n" + "=" * 70) 487 | print("Done! This example demonstrated:") 488 | print(" 1. Rippa criterion (LOO-CV) for RBF shape parameter selection") 489 | print(" 2. JAX autodiff through build_collocation_matrix") 490 | print(" 3. JAXopt L-BFGS optimization") 491 | print(" 4. 
Gradient and Hessian computation of LOO-CV cost") 492 | print("=" * 70) 493 | 494 | plt.show() 495 | 496 | 497 | if __name__ == "__main__": 498 | main() 499 | -------------------------------------------------------------------------------- /tests/test_core.py: -------------------------------------------------------------------------------- 1 | """Tests for core train/inference functions.""" 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import pytest 7 | 8 | import pod_rbf 9 | from pod_rbf.core import inference, inference_single, train 10 | from pod_rbf.types import ModelState, TrainConfig, TrainResult 11 | 12 | 13 | class TestTrain: 14 | """Test training function.""" 15 | 16 | @pytest.fixture 17 | def linear_data(self): 18 | """Create simple linear test data: f(x, p) = p * x.""" 19 | x = jnp.linspace(0, 1, 50) 20 | params = jnp.linspace(1, 10, 10) 21 | # snapshot[i, j] = params[j] * x[i] 22 | snapshot = x[:, None] * params[None, :] 23 | return snapshot, params, x 24 | 25 | def test_returns_train_result(self, linear_data): 26 | """Train should return TrainResult.""" 27 | snapshot, params, _ = linear_data 28 | result = train(snapshot, params) 29 | 30 | assert isinstance(result, TrainResult) 31 | assert isinstance(result.state, ModelState) 32 | assert result.n_modes > 0 33 | assert isinstance(result.used_eig_decomp, bool) 34 | 35 | def test_model_state_shapes(self, linear_data): 36 | """Model state should have correct shapes.""" 37 | snapshot, params, _ = linear_data 38 | result = train(snapshot, params) 39 | state = result.state 40 | 41 | n_samples, n_snapshots = snapshot.shape 42 | n_modes = result.n_modes 43 | 44 | assert state.basis.shape == (n_samples, n_modes) 45 | assert state.weights.shape == (n_modes, n_snapshots) 46 | assert state.train_params.shape == (1, n_snapshots) 47 | assert state.params_range.shape == (1,) 48 | 49 | def test_custom_config(self, linear_data): 50 | """Should accept custom config.""" 51 | snapshot, params, _ = linear_data 52 | config = TrainConfig(energy_threshold=0.9) 53 | result = train(snapshot, params, config=config) 54 | 55 | assert result.state.truncated_energy >= 0.9 56 | 57 | def test_fixed_shape_factor(self, linear_data): 58 | """Should accept fixed shape factor.""" 59 | snapshot, params, _ = linear_data 60 | result = train(snapshot, params, shape_factor=0.5) 61 | 62 | assert result.state.shape_factor == 0.5 63 | 64 | def test_multi_param(self): 65 | """Should work with multiple parameters.""" 66 | x = jnp.linspace(0, 1, 30) 67 | p1 = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]) 68 | p2 = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5]) 69 | params = jnp.stack([p1, p2], axis=0) 70 | 71 | # f(x, p1, p2) = p1 * x + p2 72 | snapshot = x[:, None] * p1[None, :] + p2[None, :] 73 | 74 | result = train(snapshot, params) 75 | 76 | assert result.state.train_params.shape == (2, 5) 77 | assert result.state.params_range.shape == (2,) 78 | 79 | 80 | class TestInference: 81 | """Test inference functions.""" 82 | 83 | @pytest.fixture 84 | def trained_model(self): 85 | """Create and train a model on linear data.""" 86 | x = jnp.linspace(0, 1, 50) 87 | params = jnp.linspace(1, 10, 10) 88 | snapshot = x[:, None] * params[None, :] 89 | result = train(snapshot, params) 90 | return result.state, x, params, snapshot 91 | 92 | def test_inference_single_shape(self, trained_model): 93 | """inference_single should return 1D array.""" 94 | state, x, _, _ = trained_model 95 | pred = inference_single(state, jnp.array(5.0)) 96 | 97 | assert pred.shape == (len(x),) 98 | 99 | 
def test_inference_batch_shape(self, trained_model): 100 | """inference should return 2D array for batch input.""" 101 | state, x, _, _ = trained_model 102 | pred = inference(state, jnp.array([[2.0, 5.0, 8.0]])) 103 | 104 | assert pred.shape == (len(x), 3) 105 | 106 | def test_interpolation_at_training_points(self, trained_model): 107 | """Inference at training points should match training data.""" 108 | state, x, params, snapshot = trained_model 109 | 110 | for i, p in enumerate(params): 111 | pred = inference_single(state, p) 112 | expected = snapshot[:, i] 113 | assert jnp.allclose(pred, expected, rtol=1e-3), f"Mismatch at training point {i}" 114 | 115 | def test_interpolation_between_points(self, trained_model): 116 | """Inference between training points should be accurate for linear data.""" 117 | state, x, _, _ = trained_model 118 | 119 | # Test at midpoint 120 | pred = inference_single(state, jnp.array(5.5)) 121 | expected = 5.5 * x 122 | 123 | assert jnp.allclose(pred, expected, rtol=1e-2), "Interpolation at midpoint should be accurate" 124 | 125 | def test_scalar_param(self, trained_model): 126 | """Should handle scalar parameter input.""" 127 | state, x, _, _ = trained_model 128 | pred = inference_single(state, 5.0) 129 | 130 | assert pred.shape == (len(x),) 131 | 132 | 133 | class TestGradients: 134 | """Test autodifferentiation capabilities.""" 135 | 136 | @pytest.fixture 137 | def trained_model(self): 138 | """Create trained model for gradient tests.""" 139 | x = jnp.linspace(0, 1, 50) 140 | params = jnp.linspace(1, 10, 10) 141 | snapshot = x[:, None] * params[None, :] 142 | result = train(snapshot, params) 143 | return result.state, x 144 | 145 | def test_grad_wrt_param(self, trained_model): 146 | """Gradient of inference w.r.t. parameter should exist and be non-zero.""" 147 | state, _ = trained_model 148 | 149 | def loss(p): 150 | pred = inference_single(state, p) 151 | return jnp.sum(pred**2) 152 | 153 | grad_fn = jax.grad(loss) 154 | grad = grad_fn(jnp.array(5.0)) 155 | 156 | assert not jnp.isnan(grad), "Gradient should not be NaN" 157 | assert grad != 0.0, "Gradient should be non-zero" 158 | 159 | def test_inverse_problem(self, trained_model): 160 | """Test solving inverse problem via gradient descent.""" 161 | state, x = trained_model 162 | 163 | # Target: parameter = 7.5 164 | target = 7.5 * x 165 | 166 | def loss(p): 167 | pred = inference_single(state, p) 168 | return jnp.mean((pred - target) ** 2) 169 | 170 | # Gradient descent 171 | p = jnp.array(5.0) # Initial guess 172 | lr = 0.5 173 | 174 | for _ in range(50): 175 | g = jax.grad(loss)(p) 176 | p = p - lr * g 177 | 178 | assert jnp.abs(p - 7.5) < 0.1, f"Recovered parameter {p} should be close to 7.5" 179 | 180 | def test_jacobian(self, trained_model): 181 | """Jacobian should have correct shape.""" 182 | state, x = trained_model 183 | 184 | jacobian = jax.jacobian(lambda p: inference_single(state, p))(jnp.array(5.0)) 185 | 186 | assert jacobian.shape == (len(x),), f"Jacobian shape should be ({len(x)},), got {jacobian.shape}" 187 | 188 | def test_value_and_grad(self, trained_model): 189 | """value_and_grad should work.""" 190 | state, _ = trained_model 191 | 192 | def loss(p): 193 | pred = inference_single(state, p) 194 | return jnp.sum(pred**2) 195 | 196 | val, grad = jax.value_and_grad(loss)(jnp.array(5.0)) 197 | 198 | assert not jnp.isnan(val) 199 | assert not jnp.isnan(grad) 200 | 201 | 202 | class TestSchurComplement: 203 | """Test Schur complement solver integration.""" 204 | 205 | @pytest.fixture 206 | def 
linear_data(self): 207 | """Create simple linear test data: f(x, p) = p * x.""" 208 | x = jnp.linspace(0, 1, 50) 209 | params = jnp.linspace(1, 10, 10) 210 | snapshot = x[:, None] * params[None, :] 211 | return snapshot, params, x 212 | 213 | def test_model_state_has_poly_fields(self, linear_data): 214 | """Model state should have polynomial coefficient fields.""" 215 | snapshot, params, _ = linear_data 216 | result = train(snapshot, params) 217 | state = result.state 218 | 219 | assert hasattr(state, "poly_coeffs"), "ModelState should have poly_coeffs field" 220 | assert hasattr(state, "poly_degree"), "ModelState should have poly_degree field" 221 | assert state.poly_degree == 2, "Default poly_degree should be 2" 222 | assert state.poly_coeffs is not None, "poly_coeffs should not be None with default config" 223 | 224 | def test_poly_coeffs_shape(self, linear_data): 225 | """Polynomial coefficients should have correct shape.""" 226 | snapshot, params, _ = linear_data 227 | result = train(snapshot, params) 228 | state = result.state 229 | 230 | n_modes = result.n_modes 231 | # For 1D params with degree 2: 3 polynomial terms [1, p, p²] 232 | n_poly = 3 233 | assert state.poly_coeffs.shape == (n_modes, n_poly), f"Expected ({n_modes}, {n_poly}), got {state.poly_coeffs.shape}" 234 | 235 | def test_poly_degree_0_fallback(self, linear_data): 236 | """poly_degree=0 should use pinv fallback.""" 237 | snapshot, params, _ = linear_data 238 | config = TrainConfig(poly_degree=0) 239 | result = train(snapshot, params, config=config) 240 | state = result.state 241 | 242 | assert state.poly_degree == 0 243 | assert state.poly_coeffs is None, "poly_coeffs should be None when poly_degree=0" 244 | 245 | def test_interpolation_with_schur(self, linear_data): 246 | """Schur complement solver should interpolate training points accurately.""" 247 | snapshot, params, x = linear_data 248 | result = train(snapshot, params) 249 | state = result.state 250 | 251 | for i, p in enumerate(params): 252 | pred = inference_single(state, p) 253 | expected = snapshot[:, i] 254 | assert jnp.allclose(pred, expected, rtol=1e-3), f"Mismatch at training point {i}" 255 | 256 | def test_interpolation_between_points_schur(self, linear_data): 257 | """Schur solver should interpolate accurately between training points.""" 258 | snapshot, params, x = linear_data 259 | result = train(snapshot, params) 260 | state = result.state 261 | 262 | pred = inference_single(state, jnp.array(5.5)) 263 | expected = 5.5 * x 264 | 265 | assert jnp.allclose(pred, expected, rtol=1e-2), "Interpolation at midpoint should be accurate" 266 | 267 | def test_grad_with_schur(self, linear_data): 268 | """Autodiff should work through Schur complement solver.""" 269 | snapshot, params, _ = linear_data 270 | result = train(snapshot, params) 271 | state = result.state 272 | 273 | def loss(p): 274 | pred = inference_single(state, p) 275 | return jnp.sum(pred**2) 276 | 277 | grad_fn = jax.grad(loss) 278 | grad = grad_fn(jnp.array(5.0)) 279 | 280 | assert not jnp.isnan(grad), "Gradient should not be NaN" 281 | assert grad != 0.0, "Gradient should be non-zero" 282 | 283 | def test_multi_param_with_schur(self): 284 | """Schur solver should work with multiple parameters.""" 285 | x = jnp.linspace(0, 1, 30) 286 | # Use uncorrelated parameters to avoid rank-deficient polynomial basis 287 | p1 = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]) 288 | p2 = jnp.array([0.5, 0.1, 0.4, 0.2, 0.3]) 289 | params = jnp.stack([p1, p2], axis=0) 290 | 291 | snapshot = x[:, None] * p1[None, :] + 
p2[None, :] 292 | 293 | result = train(snapshot, params) 294 | state = result.state 295 | 296 | # For 2D params with degree 2: 6 polynomial terms 297 | n_poly = 6 298 | assert state.poly_coeffs.shape[1] == n_poly, f"Expected {n_poly} poly terms, got {state.poly_coeffs.shape[1]}" 299 | 300 | # Test interpolation at training points 301 | for i in range(len(p1)): 302 | pred = inference_single(state, jnp.array([p1[i], p2[i]])) 303 | expected = snapshot[:, i] 304 | assert jnp.allclose(pred, expected, rtol=1e-3), f"Mismatch at training point {i}" 305 | 306 | 307 | class TestJIT: 308 | """Test JIT compilation.""" 309 | 310 | @pytest.fixture 311 | def trained_model(self): 312 | """Create trained model for JIT tests.""" 313 | x = jnp.linspace(0, 1, 50) 314 | params = jnp.linspace(1, 10, 10) 315 | snapshot = x[:, None] * params[None, :] 316 | result = train(snapshot, params) 317 | return result.state 318 | 319 | def test_inference_single_jit(self, trained_model): 320 | """inference_single should JIT compile via closure pattern.""" 321 | state = trained_model 322 | 323 | # Create a closure that captures state (recommended pattern for JAX) 324 | @jax.jit 325 | def jitted_inference(p): 326 | return inference_single(state, p) 327 | 328 | # First call compiles 329 | pred1 = jitted_inference(jnp.array(5.0)) 330 | # Second call uses cached compilation 331 | pred2 = jitted_inference(jnp.array(6.0)) 332 | 333 | assert pred1.shape == (50,) 334 | assert pred2.shape == (50,) 335 | 336 | def test_inference_batch_jit(self, trained_model): 337 | """inference should JIT compile via closure pattern.""" 338 | state = trained_model 339 | 340 | @jax.jit 341 | def jitted_inference(params): 342 | return inference(state, params) 343 | 344 | pred = jitted_inference(jnp.array([[2.0, 5.0, 8.0]])) 345 | 346 | assert pred.shape == (50, 3) 347 | 348 | def test_grad_jit(self, trained_model): 349 | """Gradient computation should JIT compile.""" 350 | state = trained_model 351 | 352 | @jax.jit 353 | def loss_and_grad(p): 354 | pred = inference_single(state, p) 355 | loss_val = jnp.sum(pred**2) 356 | return loss_val 357 | 358 | grad_fn = jax.jit(jax.grad(loss_and_grad)) 359 | 360 | loss = loss_and_grad(jnp.array(5.0)) 361 | grad = grad_fn(jnp.array(5.0)) 362 | 363 | assert not jnp.isnan(loss) 364 | assert not jnp.isnan(grad) 365 | 366 | 367 | class TestKernelTypes: 368 | """Test training and inference with different kernel types.""" 369 | 370 | @pytest.fixture 371 | def linear_data(self): 372 | """Create simple linear test data: f(x, p) = p * x.""" 373 | x = jnp.linspace(0, 1, 50) 374 | params = jnp.linspace(1, 10, 10) 375 | snapshot = x[:, None] * params[None, :] 376 | return snapshot, params, x 377 | 378 | def test_train_with_gaussian_kernel(self, linear_data): 379 | """Train with Gaussian kernel.""" 380 | snapshot, params, _ = linear_data 381 | config = TrainConfig(kernel="gaussian") 382 | result = train(snapshot, params, config=config) 383 | 384 | assert result.state.kernel == "gaussian" 385 | assert result.state.shape_factor is not None 386 | assert result.n_modes > 0 387 | 388 | def test_train_with_phs_kernel(self, linear_data): 389 | """Train with polyharmonic spline kernel.""" 390 | snapshot, params, _ = linear_data 391 | config = TrainConfig(kernel="polyharmonic_spline", kernel_order=3) 392 | result = train(snapshot, params, config=config) 393 | 394 | assert result.state.kernel == "polyharmonic_spline" 395 | assert result.state.kernel_order == 3 396 | assert result.state.shape_factor is None # PHS doesn't use shape 
parameter 397 | assert result.n_modes > 0 398 | 399 | def test_train_with_imq_kernel_explicit(self, linear_data): 400 | """Train with IMQ kernel explicitly specified.""" 401 | snapshot, params, _ = linear_data 402 | config = TrainConfig(kernel="imq") 403 | result = train(snapshot, params, config=config) 404 | 405 | assert result.state.kernel == "imq" 406 | assert result.state.shape_factor is not None 407 | assert result.n_modes > 0 408 | 409 | def test_inference_with_gaussian_kernel(self, linear_data): 410 | """Inference should work with Gaussian kernel.""" 411 | snapshot, params, x = linear_data 412 | config = TrainConfig(kernel="gaussian") 413 | result = train(snapshot, params, config=config) 414 | 415 | # Test inference at training points 416 | pred = inference_single(result.state, params[5]) 417 | expected = snapshot[:, 5] 418 | assert jnp.allclose(pred, expected, rtol=1e-3) 419 | 420 | def test_inference_with_phs_kernel(self, linear_data): 421 | """Inference should work with PHS kernel.""" 422 | snapshot, params, x = linear_data 423 | config = TrainConfig(kernel="polyharmonic_spline", kernel_order=3) 424 | result = train(snapshot, params, config=config) 425 | 426 | # Test inference at training points 427 | pred = inference_single(result.state, params[5]) 428 | expected = snapshot[:, 5] 429 | assert jnp.allclose(pred, expected, rtol=1e-3) 430 | 431 | def test_all_kernels_interpolate_training_data(self, linear_data): 432 | """All kernels should interpolate training data accurately.""" 433 | snapshot, params, x = linear_data 434 | 435 | kernels = [ 436 | ("imq", {}), 437 | ("gaussian", {}), 438 | ("polyharmonic_spline", {"kernel_order": 3}), 439 | ] 440 | 441 | for kernel, kwargs in kernels: 442 | config = TrainConfig(kernel=kernel, **kwargs) 443 | result = train(snapshot, params, config=config) 444 | 445 | # Check interpolation at a few training points 446 | for i in [0, 5, 9]: 447 | pred = inference_single(result.state, params[i]) 448 | expected = snapshot[:, i] 449 | assert jnp.allclose( 450 | pred, expected, rtol=1e-3 451 | ), f"Kernel {kernel} failed to interpolate training point {i}" 452 | 453 | def test_phs_different_orders(self, linear_data): 454 | """Test PHS with different orders.""" 455 | snapshot, params, x = linear_data 456 | 457 | for order in [1, 3, 5]: 458 | config = TrainConfig(kernel="polyharmonic_spline", kernel_order=order) 459 | result = train(snapshot, params, config=config) 460 | 461 | assert result.state.kernel_order == order 462 | assert result.state.shape_factor is None 463 | 464 | # Should still interpolate training data 465 | pred = inference_single(result.state, params[5]) 466 | expected = snapshot[:, 5] 467 | assert jnp.allclose(pred, expected, rtol=1e-3), f"PHS order {order} failed" 468 | 469 | 470 | class TestKernelGradients: 471 | """Test autodiff with different kernels.""" 472 | 473 | @pytest.fixture 474 | def linear_data(self): 475 | """Create simple linear test data.""" 476 | x = jnp.linspace(0, 1, 50) 477 | params = jnp.linspace(1, 10, 10) 478 | snapshot = x[:, None] * params[None, :] 479 | return snapshot, params, x 480 | 481 | def test_grad_with_gaussian_kernel(self, linear_data): 482 | """Gradients should work with Gaussian kernel.""" 483 | snapshot, params, x = linear_data 484 | config = TrainConfig(kernel="gaussian") 485 | result = train(snapshot, params, config=config) 486 | 487 | def loss(p): 488 | pred = inference_single(result.state, p) 489 | return jnp.sum(pred**2) 490 | 491 | grad = jax.grad(loss)(jnp.array(5.0)) 492 | assert not 
jnp.isnan(grad) 493 | assert grad != 0.0 494 | 495 | def test_grad_with_phs_kernel(self, linear_data): 496 | """Gradients should work with PHS kernel.""" 497 | snapshot, params, x = linear_data 498 | config = TrainConfig(kernel="polyharmonic_spline", kernel_order=3) 499 | result = train(snapshot, params, config=config) 500 | 501 | def loss(p): 502 | pred = inference_single(result.state, p) 503 | return jnp.sum(pred**2) 504 | 505 | grad = jax.grad(loss)(jnp.array(5.0)) 506 | assert not jnp.isnan(grad) 507 | assert grad != 0.0 508 | 509 | def test_inverse_problem_all_kernels(self, linear_data): 510 | """Inverse problem should work with all kernels.""" 511 | snapshot, params, x = linear_data 512 | target_param = 7.5 513 | target = target_param * x 514 | 515 | kernels = [ 516 | ("imq", {}), 517 | ("gaussian", {}), 518 | ("polyharmonic_spline", {"kernel_order": 3}), 519 | ] 520 | 521 | for kernel, kwargs in kernels: 522 | config = TrainConfig(kernel=kernel, **kwargs) 523 | result = train(snapshot, params, config=config) 524 | 525 | def loss(p): 526 | pred = inference_single(result.state, p) 527 | return jnp.mean((pred - target) ** 2) 528 | 529 | # Gradient descent 530 | p = jnp.array(5.0) 531 | lr = 0.5 532 | 533 | for _ in range(50): 534 | g = jax.grad(loss)(p) 535 | p = p - lr * g 536 | 537 | assert jnp.abs(p - target_param) < 0.2, f"Kernel {kernel}: recovered {p}, expected {target_param}" 538 | 539 | 540 | class TestBackwardsCompatibility: 541 | """Test backwards compatibility with default IMQ kernel.""" 542 | 543 | def test_default_kernel_is_imq(self): 544 | """Default config should use IMQ kernel.""" 545 | config = TrainConfig() 546 | assert config.kernel == "imq" 547 | assert config.kernel_order == 3 548 | 549 | def test_train_without_config_uses_imq(self): 550 | """Training without config should use IMQ kernel.""" 551 | x = jnp.linspace(0, 1, 50) 552 | params = jnp.linspace(1, 10, 10) 553 | snapshot = x[:, None] * params[None, :] 554 | 555 | result = train(snapshot, params) 556 | assert result.state.kernel == "imq" 557 | assert result.state.shape_factor is not None 558 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. 
Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 
88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 
146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 
209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. 
This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. 
But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 
387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. 
You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. 
"Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 
564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 
628 | 
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 | 
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year> <name of author>
636 | 
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 | 
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 | 
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 | 
650 | Also add information on how to contact you by electronic and paper mail.
651 | 
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 | 
655 | <program> Copyright (C) <year> <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 | 
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 | 
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 | 
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/philosophy/why-not-lgpl.html>.
675 | 
--------------------------------------------------------------------------------