├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── test_io.py
│   ├── test_decomposition.py
│   ├── test_kernels.py
│   ├── test_rbf.py
│   └── test_core.py
├── examples
│   ├── lid-driven-cavity
│   │   ├── eq-snapshot.png
│   │   ├── results-re-450.png
│   │   ├── results-re-50.png
│   │   ├── generatePODdata.sim
│   │   ├── eq-snapshot-matrix.png
│   │   └── lid-driven-cavity.ipynb
│   ├── heat-conduction
│   │   ├── 2d-heat.py
│   │   └── design-optimization.py
│   └── shape-optimization
│       └── optimize_shape.py
├── .github
│   ├── codecov.yml
│   └── workflows
│       ├── docs.yml
│       ├── tests.yml
│       └── publish.yml
├── docs
│   ├── api
│   │   └── index.md
│   ├── examples.md
│   ├── getting-started.md
│   ├── user-guide
│   │   ├── inference.md
│   │   ├── io.md
│   │   ├── autodiff.md
│   │   └── training.md
│   └── index.md
├── pyproject.toml
├── pod_rbf
│   ├── __init__.py
│   ├── types.py
│   ├── kernels.py
│   ├── shape_optimization.py
│   ├── decomposition.py
│   ├── io.py
│   ├── core.py
│   └── rbf.py
├── mkdocs.yml
├── .gitignore
├── CLAUDE.md
├── README.md
└── LICENSE
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/lid-driven-cavity/eq-snapshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kylebeggs/POD-RBF/HEAD/examples/lid-driven-cavity/eq-snapshot.png
--------------------------------------------------------------------------------
/examples/lid-driven-cavity/results-re-450.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kylebeggs/POD-RBF/HEAD/examples/lid-driven-cavity/results-re-450.png
--------------------------------------------------------------------------------
/examples/lid-driven-cavity/results-re-50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kylebeggs/POD-RBF/HEAD/examples/lid-driven-cavity/results-re-50.png
--------------------------------------------------------------------------------
/examples/lid-driven-cavity/generatePODdata.sim:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kylebeggs/POD-RBF/HEAD/examples/lid-driven-cavity/generatePODdata.sim
--------------------------------------------------------------------------------
/examples/lid-driven-cavity/eq-snapshot-matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kylebeggs/POD-RBF/HEAD/examples/lid-driven-cavity/eq-snapshot-matrix.png
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Pytest configuration and shared fixtures."""
2 |
3 | import jax
4 |
5 | # Ensure float64 is enabled for all tests
6 | jax.config.update("jax_enable_x64", True)
7 |
--------------------------------------------------------------------------------
/.github/codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 | status:
3 | project:
4 | default:
5 | target: auto
6 | threshold: 1%
7 | patch:
8 | default:
9 | target: auto
10 | threshold: 1%
11 |
12 | comment:
13 | layout: "reach,diff,flags,files,footer"
14 | behavior: default
15 | require_changes: false
16 |
17 | ignore:
18 | - "tests/"
19 | - "examples/"
20 | - "setup.py"
21 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: docs
2 |
3 | on:
4 | push:
5 | branches: [master]
6 | workflow_dispatch:
7 |
8 | permissions:
9 | contents: write
10 |
11 | jobs:
12 | deploy:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v4
16 |
17 | - uses: actions/setup-python@v5
18 | with:
19 | python-version: "3.12"
20 |
21 | - name: Install dependencies
22 | run: |
23 | pip install mkdocs-material mkdocstrings[python]
24 | pip install -e .
25 |
26 | - name: Build and deploy docs
27 | run: mkdocs gh-deploy --force
28 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | push:
5 | branches: [master]
6 | pull_request:
7 | branches: [master]
8 |
9 | jobs:
10 | test:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | fail-fast: false
14 | matrix:
15 | python-version: ["3.10", "3.11", "3.12"]
16 |
17 | steps:
18 | - name: Checkout code
19 | uses: actions/checkout@v4
20 |
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v5
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | cache: 'pip'
26 | cache-dependency-path: 'pyproject.toml'
27 |
28 | - name: Install JAX (CPU)
29 | run: |
30 | pip install --upgrade pip
31 | pip install "jax[cpu]>=0.4.0"
32 |
33 | - name: Install package with dev dependencies
34 | run: pip install -e ".[dev]"
35 |
36 | - name: Run tests with coverage
37 | run: pytest tests/ -v --cov=pod_rbf --cov-report=xml --cov-report=term
38 |
39 | - name: Upload coverage to Codecov
40 | uses: codecov/codecov-action@v4
41 | with:
42 | file: ./coverage.xml
43 | flags: unittests
44 | name: codecov-umbrella
45 | fail_ci_if_error: false
46 | env:
47 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
48 |
--------------------------------------------------------------------------------
/docs/api/index.md:
--------------------------------------------------------------------------------
1 | # API Reference
2 |
3 | ## Core Functions
4 |
5 | ### train
6 |
7 | ::: pod_rbf.train
8 | options:
9 | show_root_heading: true
10 | show_source: false
11 |
12 | ### inference
13 |
14 | ::: pod_rbf.inference
15 | options:
16 | show_root_heading: true
17 | show_source: false
18 |
19 | ### inference_single
20 |
21 | ::: pod_rbf.inference_single
22 | options:
23 | show_root_heading: true
24 | show_source: false
25 |
26 | ## I/O Functions
27 |
28 | ### build_snapshot_matrix
29 |
30 | ::: pod_rbf.build_snapshot_matrix
31 | options:
32 | show_root_heading: true
33 | show_source: false
34 |
35 | ### save_model
36 |
37 | ::: pod_rbf.save_model
38 | options:
39 | show_root_heading: true
40 | show_source: false
41 |
42 | ### load_model
43 |
44 | ::: pod_rbf.load_model
45 | options:
46 | show_root_heading: true
47 | show_source: false
48 |
49 | ## Types
50 |
51 | ### TrainConfig
52 |
53 | ::: pod_rbf.TrainConfig
54 | options:
55 | show_root_heading: true
56 | show_source: false
57 |
58 | ### ModelState
59 |
60 | ::: pod_rbf.ModelState
61 | options:
62 | show_root_heading: true
63 | show_source: false
64 |
65 | ### TrainResult
66 |
67 | ::: pod_rbf.TrainResult
68 | options:
69 | show_root_heading: true
70 | show_source: false
71 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0", "setuptools-scm>=8.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "pod_rbf"
7 | dynamic = ["version"]
8 | authors = [{name = "Kyle Beggs", email = "beggskw@gmail.com"}]
9 | description = "JAX-based POD-RBF for autodiff-enabled reduced order modeling."
10 | readme = "README.md"
11 | license = {text = "MIT"}
12 | requires-python = ">=3.10"
13 | classifiers = [
14 | "Programming Language :: Python :: 3",
15 | "Programming Language :: Python :: 3.10",
16 | "Programming Language :: Python :: 3.11",
17 | "Programming Language :: Python :: 3.12",
18 | "License :: OSI Approved :: MIT License",
19 | "Operating System :: OS Independent",
20 | ]
21 | dependencies = [
22 | "jax>=0.4.0",
23 | "jaxlib>=0.4.0",
24 | "numpy",
25 | "tqdm",
26 | ]
27 |
28 | [project.optional-dependencies]
29 | dev = ["pytest>=7.0", "pytest-cov"]
30 | docs = [
31 | "mkdocs-material>=9.5",
32 | "mkdocstrings[python]>=0.24",
33 | ]
34 |
35 | [project.urls]
36 | Homepage = "https://github.com/kylebeggs/POD-RBF"
37 | Repository = "https://github.com/kylebeggs/POD-RBF"
38 | Documentation = "https://kylebeggs.github.io/POD-RBF"
39 |
40 | [tool.setuptools_scm]
41 |
42 | [tool.uv.sources]
43 | pod-rbf = { workspace = true }
44 |
45 | [dependency-groups]
46 | dev = [
47 | "pod-rbf",
48 | ]
49 |
--------------------------------------------------------------------------------
/pod_rbf/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | POD-RBF: Proper Orthogonal Decomposition - Radial Basis Function Network.
3 |
4 | A JAX-based implementation enabling autodifferentiation for:
5 | - Gradient optimization
6 | - Sensitivity analysis
7 | - Inverse problems
8 |
9 | Usage
10 | -----
11 | >>> import pod_rbf
12 | >>> import jax.numpy as jnp
13 | >>>
14 | >>> # Train model
15 | >>> result = pod_rbf.train(snapshot, params)
16 | >>>
17 | >>> # Inference
18 | >>> pred = pod_rbf.inference_single(result.state, jnp.array(450.0))
19 | >>>
20 | >>> # Autodiff
21 | >>> import jax
22 | >>> grad_fn = jax.grad(lambda p: jnp.sum(pod_rbf.inference_single(result.state, p)**2))
23 | >>> gradient = grad_fn(jnp.array(450.0))
24 | """
25 |
26 | import jax
27 |
28 | # Enable float64 for numerical stability (SVD, condition numbers)
29 | jax.config.update("jax_enable_x64", True)
30 |
31 | from .core import inference, inference_single, train
32 | from .io import build_snapshot_matrix, load_model, save_model
33 | from .types import ModelState, TrainConfig, TrainResult
34 |
35 | try:
36 | from importlib.metadata import version
37 |
38 | __version__ = version("pod_rbf")
39 | except Exception:
40 | __version__ = "unknown"
41 |
42 | __all__ = [
43 | # Core functions
44 | "train",
45 | "inference",
46 | "inference_single",
47 | # Types
48 | "ModelState",
49 | "TrainConfig",
50 | "TrainResult",
51 | # I/O
52 | "build_snapshot_matrix",
53 | "save_model",
54 | "load_model",
55 | ]
56 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: POD-RBF
2 | site_description: JAX-based POD-RBF for autodiff-enabled reduced order modeling
3 | site_url: https://kylebeggs.github.io/POD-RBF
4 | repo_url: https://github.com/kylebeggs/POD-RBF
5 | repo_name: kylebeggs/POD-RBF
6 |
7 | theme:
8 | name: material
9 | palette:
10 | - media: "(prefers-color-scheme: light)"
11 | scheme: default
12 | primary: indigo
13 | toggle:
14 | icon: material/brightness-7
15 | name: Switch to dark mode
16 | - media: "(prefers-color-scheme: dark)"
17 | scheme: slate
18 | primary: indigo
19 | toggle:
20 | icon: material/brightness-4
21 | name: Switch to light mode
22 | features:
23 | - navigation.instant
24 | - navigation.sections
25 | - navigation.top
26 | - search.highlight
27 | - content.code.copy
28 |
29 | plugins:
30 | - search
31 | - mkdocstrings:
32 | handlers:
33 | python:
34 | options:
35 | show_source: true
36 | show_root_heading: true
37 |
38 | nav:
39 | - Home: index.md
40 | - Getting Started: getting-started.md
41 | - User Guide:
42 | - Training Models: user-guide/training.md
43 | - Inference: user-guide/inference.md
44 | - Autodifferentiation: user-guide/autodiff.md
45 | - Saving & Loading: user-guide/io.md
46 | - API Reference: api/index.md
47 | - Examples: examples.md
48 |
49 | markdown_extensions:
50 | - pymdownx.highlight:
51 | anchor_linenums: true
52 | - pymdownx.superfences
53 | - admonition
54 | - pymdownx.details
55 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to PyPI
2 |
3 | on:
4 | push:
5 | tags:
6 | - 'v*'
7 |
8 | jobs:
9 | build:
10 | name: Build distribution
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v4
14 | with:
15 | persist-credentials: false
16 | fetch-depth: 0 # Required for setuptools-scm to get version from tags
17 |
18 | - name: Set up Python
19 | uses: actions/setup-python@v5
20 | with:
21 | python-version: "3.12"
22 |
23 | - name: Install build tools
24 | run: python -m pip install build --user
25 |
26 | - name: Build wheel and sdist
27 | run: python -m build
28 |
29 | - name: Upload artifacts
30 | uses: actions/upload-artifact@v4
31 | with:
32 | name: python-package-distributions
33 | path: dist/
34 |
35 | publish-to-pypi:
36 | name: Publish to PyPI
37 | needs: build
38 | runs-on: ubuntu-latest
39 | environment:
40 | name: pypi
41 | url: https://pypi.org/p/pod_rbf
42 | permissions:
43 | id-token: write
44 | steps:
45 | - name: Download artifacts
46 | uses: actions/download-artifact@v4
47 | with:
48 | name: python-package-distributions
49 | path: dist/
50 |
51 | - name: Publish to PyPI
52 | uses: pypa/gh-action-pypi-publish@release/v1
53 |
54 | create-github-release:
55 | name: Create GitHub Release
56 | needs: publish-to-pypi
57 | runs-on: ubuntu-latest
58 | permissions:
59 | contents: write
60 | steps:
61 | - name: Checkout code
62 | uses: actions/checkout@v4
63 |
64 | - name: Create Release
65 | env:
66 | GH_TOKEN: ${{ github.token }}
67 | run: |
68 | gh release create ${{ github.ref_name }} \
69 | --title "Release ${{ github.ref_name }}" \
70 | --generate-notes
71 |
--------------------------------------------------------------------------------
/pod_rbf/types.py:
--------------------------------------------------------------------------------
1 | """
2 | Data structures for POD-RBF.
3 |
4 | All types are NamedTuples for JAX pytree compatibility.
5 | """
6 |
7 | from typing import NamedTuple
8 |
9 | from jax import Array
10 |
11 |
12 | class TrainConfig(NamedTuple):
13 | """Immutable training configuration."""
14 |
15 | energy_threshold: float = 0.99
16 | mem_limit_gb: float = 16.0
17 | cond_range: tuple[float, float] = (1e11, 1e12)
18 | max_bisection_iters: int = 50
19 | c_low_init: float = 0.011
20 | c_high_init: float = 1.0
21 | c_high_step: float = 0.01
22 | c_high_search_iters: int = 200
23 | poly_degree: int = 2 # Polynomial augmentation degree (0=none, 1=linear, 2=quadratic)
24 | kernel: str = "imq" # RBF kernel type: 'imq', 'gaussian', 'polyharmonic_spline'
25 | kernel_order: int = 3 # For polyharmonic splines only (order of r^k)
26 |
27 |
28 | class ModelState(NamedTuple):
29 | """Immutable trained model state - a valid JAX pytree."""
30 |
31 | basis: Array # Truncated POD basis (n_samples, n_modes)
32 | weights: Array # RBF network weights (n_modes, n_train_points)
33 | shape_factor: float | None # Optimized RBF shape parameter (None for PHS)
34 | train_params: Array # Training parameters (n_params, n_train_points)
35 | params_range: Array # Parameter ranges for normalization (n_params,)
36 | truncated_energy: float # Energy retained after truncation
37 | cumul_energy: Array # Cumulative energy per mode
38 | poly_coeffs: Array | None # Polynomial coefficients (n_modes, n_poly) or None
39 | poly_degree: int # Polynomial degree used (0=none)
40 | kernel: str # Kernel type used for training
41 | kernel_order: int # PHS order (ignored for other kernels)
42 |
43 |
44 | class TrainResult(NamedTuple):
45 | """Result from training, includes diagnostics."""
46 |
47 | state: ModelState
48 | n_modes: int
49 | used_eig_decomp: bool # True if eigendecomposition was used
50 |
--------------------------------------------------------------------------------
/docs/examples.md:
--------------------------------------------------------------------------------
1 | # Examples
2 |
3 | ## Example Gallery
4 |
5 | Explore these examples (notebooks and scripts) to see POD-RBF in action:
6 |
7 | ### Lid-Driven Cavity
8 |
9 | A complete walkthrough using CFD data from a 2D lid-driven cavity simulation at various Reynolds numbers.
10 |
11 | [:octicons-mark-github-16: View on GitHub](https://github.com/kylebeggs/POD-RBF/tree/master/examples/lid-driven-cavity){ .md-button }
12 |
13 | **What you'll learn:**
14 |
15 | - Building a snapshot matrix from CSV files
16 | - Training a single-parameter model
17 | - Visualizing predictions vs. ground truth
18 |
19 | ### Multi-Parameter Example
20 |
21 | Training a model with two input parameters.
22 |
23 | [:octicons-mark-github-16: View on GitHub](https://github.com/kylebeggs/POD-RBF/blob/master/examples/2-parameters.ipynb){ .md-button }
24 |
25 | **What you'll learn:**
26 |
27 | - Setting up multi-parameter training data
28 | - Inference with multiple parameters
29 | - Parameter space exploration
30 |
31 | ### Heat Conduction
32 |
33 | A simple heat conduction problem on a unit square.
34 |
35 | [:octicons-mark-github-16: View on GitHub](https://github.com/kylebeggs/POD-RBF/tree/master/examples/heat-conduction){ .md-button }
36 |
37 | **What you'll learn:**
38 |
39 | - Basic POD-RBF workflow
40 | - Working with thermal simulation data
41 |
42 | ### Shape Parameter Optimization
43 |
44 | Exploring RBF shape parameter selection.
45 |
46 | [:octicons-mark-github-16: View on GitHub](https://github.com/kylebeggs/POD-RBF/tree/master/examples/shape-optimization){ .md-button }
47 |
48 | **What you'll learn:**
49 |
50 | - How shape parameters affect interpolation
51 | - Automatic vs. manual shape parameter selection
52 |
53 | ## Running the Examples
54 |
55 | Clone the repository and install the package:
56 |
57 | ```bash
58 | git clone https://github.com/kylebeggs/POD-RBF.git
59 | cd POD-RBF
60 | pip install -e .
61 | ```
62 |
63 | Then open the Jupyter notebooks in the `examples/` directory:
64 |
65 | ```bash
66 | jupyter notebook examples/
67 | ```
68 |
--------------------------------------------------------------------------------
/docs/getting-started.md:
--------------------------------------------------------------------------------
1 | # Getting Started
2 |
3 | ## Installation
4 |
5 | Install POD-RBF using pip:
6 |
7 | ```bash
8 | pip install pod-rbf
9 | ```
10 |
11 | Or using uv:
12 |
13 | ```bash
14 | uv add pod-rbf
15 | ```
16 |
17 | ## Basic Workflow
18 |
19 | POD-RBF follows a simple three-step workflow:
20 |
21 | 1. **Build a snapshot matrix** - Collect solution data at different parameter values
22 | 2. **Train the model** - Compute POD basis and RBF interpolation weights
23 | 3. **Inference** - Predict solutions at new parameter values
24 |
25 | ## Minimal Example
26 |
27 | ```python
28 | import pod_rbf
29 | import jax.numpy as jnp
30 | import numpy as np
31 |
32 | # 1. Define training parameters
33 | params = np.array([1, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000])
34 |
35 | # 2. Build snapshot matrix from CSV files
36 | # Each CSV file contains one solution snapshot
37 | snapshot = pod_rbf.build_snapshot_matrix("path/to/data/")
38 |
39 | # 3. Train the model
40 | result = pod_rbf.train(snapshot, params)
41 |
42 | # 4. Predict at a new parameter value
43 | prediction = pod_rbf.inference_single(result.state, jnp.array(450.0))
44 | ```
45 |
46 | ## Understanding the Snapshot Matrix
47 |
48 | The snapshot matrix `X` has shape `(n_samples, n_snapshots)`:
49 |
50 | - Each **column** is one solution snapshot at a specific parameter value
51 | - Each **row** corresponds to a spatial location or degree of freedom
52 | - `n_samples` is the number of points in your solution (e.g., mesh nodes)
53 | - `n_snapshots` is the number of parameter values you trained on
54 |
55 | For example, if you solve a problem on a 400-node mesh at 10 different parameter values, your snapshot matrix is `(400, 10)`.
56 |
57 | !!! note "Parameter Ordering"
58 | The order of columns in the snapshot matrix must match the order of your parameter array. If column 5 contains the solution at Re=500, then `params[4]` must equal 500.
59 |
60 | ## What's Next?
61 |
62 | - [Training Models](user-guide/training.md) - Learn about training configuration options
63 | - [Inference](user-guide/inference.md) - Single-point and batch predictions
64 | - [Autodifferentiation](user-guide/autodiff.md) - Use JAX for gradients and optimization
65 |
--------------------------------------------------------------------------------
/docs/user-guide/inference.md:
--------------------------------------------------------------------------------
1 | # Inference
2 |
3 | After training, use the model to predict solutions at new parameter values.
4 |
5 | ## Single-Point Inference
6 |
7 | For predicting at a single parameter value:
8 |
9 | ```python
10 | import pod_rbf
11 | import jax.numpy as jnp
12 |
13 | # Train the model
14 | result = pod_rbf.train(snapshot, params)
15 |
16 | # Predict at a new parameter
17 | prediction = pod_rbf.inference_single(result.state, jnp.array(450.0))
18 | ```
19 |
20 | The output shape is `(n_samples,)` - the same as one column of your snapshot matrix.
21 |
22 | ## Batch Inference
23 |
24 | For predicting at multiple parameter values simultaneously:
25 |
26 | ```python
27 | # Predict at multiple parameters
28 | new_params = jnp.array([350.0, 450.0, 550.0])
29 | predictions = pod_rbf.inference(result.state, new_params)
30 | ```
31 |
32 | The output shape is `(n_samples, n_points)` where `n_points` is the number of parameter values.
33 |
34 | ## Multi-Parameter Inference
35 |
36 | For models trained with multiple parameters:
37 |
38 | ```python
39 | # Single point with 2 parameters
40 | param = jnp.array([450.0, 0.15]) # [Re, Ma]
41 | prediction = pod_rbf.inference_single(result.state, param)
42 |
43 | # Batch with 2 parameters
44 | params = jnp.array([
45 | [350.0, 0.1],
46 | [450.0, 0.15],
47 | [550.0, 0.2],
48 | ])
49 | predictions = pod_rbf.inference(result.state, params)
50 | ```
51 |
52 | ## Using a Saved Model
53 |
54 | Load a previously saved model and use it for inference:
55 |
56 | ```python
57 | state = pod_rbf.load_model("model.pkl")
58 | prediction = pod_rbf.inference_single(state, jnp.array(450.0))
59 | ```
60 |
61 | ## Performance Tips
62 |
63 | 1. **Use batch inference** when predicting at multiple parameter values - it's more efficient than calling `inference_single` in a loop.
64 |
65 | 2. **JIT compilation** - The inference functions are JAX-compatible and can be JIT-compiled for faster repeated calls:
66 |
67 | ```python
68 | import jax
69 |
70 | inference_jit = jax.jit(lambda p: pod_rbf.inference_single(state, p))
71 | prediction = inference_jit(jnp.array(450.0))
72 | ```
73 |
74 | 3. **GPU acceleration** - If JAX is configured with GPU support, inference will automatically use the GPU.
75 |
--------------------------------------------------------------------------------
/docs/user-guide/io.md:
--------------------------------------------------------------------------------
1 | # Saving & Loading
2 |
3 | ## Loading Snapshot Data
4 |
5 | ### From CSV Files
6 |
7 | Load snapshots from a directory of CSV files:
8 |
9 | ```python
10 | import pod_rbf
11 |
12 | snapshot = pod_rbf.build_snapshot_matrix("path/to/data/")
13 | ```
14 |
15 | By default, this:
16 |
17 | - Loads all CSV files from the directory in alphanumeric order
18 | - Skips the first row (header)
19 | - Uses the first column
20 |
21 | Customize with optional parameters:
22 |
23 | ```python
24 | snapshot = pod_rbf.build_snapshot_matrix(
25 | "path/to/data/",
26 | skiprows=1, # Skip first N rows (default: 1 for header)
27 | usecols=0, # Column index to use (default: 0)
28 | verbose=True, # Show progress bar (default: True)
29 | )
30 | ```
31 |
32 | ### From NumPy Arrays
33 |
34 | If your data is already in memory:
35 |
36 | ```python
37 | import numpy as np
38 |
39 | # Combine individual solutions into a snapshot matrix
40 | # Each column is one snapshot
41 | snapshot = np.column_stack([sol1, sol2, sol3, sol4, sol5])
42 | ```
43 |
44 | ## Saving Models
45 |
46 | Save a trained model to disk:
47 |
48 | ```python
49 | import pod_rbf
50 |
51 | result = pod_rbf.train(snapshot, params)
52 | pod_rbf.save_model("model.pkl", result.state)
53 | ```
54 |
55 | The model is saved as a pickle file containing the `ModelState` NamedTuple.
56 |
57 | ## Loading Models
58 |
59 | Load a previously saved model:
60 |
61 | ```python
62 | state = pod_rbf.load_model("model.pkl")
63 |
64 | # Use for inference
65 | prediction = pod_rbf.inference_single(state, jnp.array(450.0))
66 | ```
67 |
68 | ## Model State Contents
69 |
70 | The saved `ModelState` contains everything needed for inference:
71 |
72 | | Field | Description |
73 | |-------|-------------|
74 | | `basis` | Truncated POD basis matrix |
75 | | `weights` | RBF interpolation weights |
76 | | `shape_factor` | Optimized RBF shape parameter |
77 | | `train_params` | Training parameter values |
78 | | `params_range` | Parameter ranges for normalization |
79 | | `truncated_energy` | Energy retained after truncation |
80 | | `cumul_energy` | Cumulative energy per mode |
81 | | `poly_coeffs` | Polynomial coefficients (if used) |
82 | | `poly_degree` | Polynomial degree used |
83 | | `kernel` | Kernel type used |
84 | | `kernel_order` | PHS order (for polyharmonic splines) |
85 |
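For example, a quick inspection of a loaded state (the printed values here are illustrative):

```python
import pod_rbf

state = pod_rbf.load_model("model.pkl")
print(state.kernel, state.poly_degree)  # e.g. imq 2
print(state.basis.shape)                # (n_samples, n_modes)
print(state.weights.shape)              # (n_modes, n_train_points)
```
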
86 | ## File Format Notes
87 |
88 | - Models are saved using Python's `pickle` module
89 | - Files are portable across machines with the same Python/JAX versions
90 | - File size depends on the number of modes and training points
91 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # POD-RBF
2 |
3 | [](https://github.com/kylebeggs/POD-RBF/actions/workflows/tests.yml)
4 | [](https://codecov.io/gh/kylebeggs/POD-RBF)
5 | [](https://www.python.org/downloads/)
6 |
7 | A Python package for building Reduced Order Models (ROMs) from high-dimensional data using Proper Orthogonal Decomposition combined with Radial Basis Function interpolation.
8 |
9 | 
10 |
11 | ## Features
12 |
13 | - **JAX-based** - Enables autodifferentiation for gradient optimization, sensitivity analysis, and inverse problems
14 | - **Shape parameter optimization** - Automatic tuning of RBF shape parameters
15 | - **Memory-aware algorithms** - Switches between eigenvalue decomposition and SVD based on memory requirements
16 |
17 | ## Quick Install
18 |
19 | ```bash
20 | pip install pod-rbf
21 | ```
22 |
23 | ## Quick Example
24 |
25 | ```python
26 | import pod_rbf
27 | import jax.numpy as jnp
28 | import numpy as np
29 |
30 | # Define training parameters
31 | Re = np.array([1, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000])
32 |
33 | # Build snapshot matrix from CSV files
34 | train_snapshot = pod_rbf.build_snapshot_matrix("data/train/")
35 |
36 | # Train the model
37 | result = pod_rbf.train(train_snapshot, Re)
38 |
39 | # Inference on unseen parameter
40 | sol = pod_rbf.inference_single(result.state, jnp.array(450.0))
41 | ```
42 |
43 | ## Next Steps
44 |
45 | - [Getting Started](getting-started.md) - Installation and first steps
46 | - [User Guide](user-guide/training.md) - Detailed usage instructions
47 | - [API Reference](api/index.md) - Complete API documentation
48 | - [Examples](examples.md) - Jupyter notebook examples
49 |
50 | ## References
51 |
52 | This implementation is based on the following papers:
53 |
54 | 1. [Solving inverse heat conduction problems using trained POD-RBF network inverse method](https://www.tandfonline.com/doi/full/10.1080/17415970701198290) - Ostrowski, Bialecki, Kassab (2008)
55 | 2. [RBF-trained POD-accelerated CFD analysis of wind loads on PV systems](https://www.emerald.com/insight/content/doi/10.1108/HFF-03-2016-0083/full/html) - Huayamave et al. (2017)
56 | 3. [Real-Time Thermomechanical Modeling of PV Cell Fabrication via a POD-Trained RBF Interpolation Network](https://www.techscience.com/CMES/v122n3/38374) - Das et al. (2020)
57 |
--------------------------------------------------------------------------------
/examples/lid-driven-cavity/lid-driven-cavity.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": "import numpy as np\nimport matplotlib.pyplot as plt\nimport pod_rbf\n\nprint(\"using version: {}\".format(pod_rbf.__version__))\n\nRe = np.linspace(0, 1000, num=11)\nRe[0] = 1\n\ncoords_path = \"data/train/re-0001.csv\"\nx, y = np.loadtxt(\n coords_path,\n delimiter=\",\",\n skiprows=1,\n usecols=(1, 2),\n unpack=True,\n)\n\n# make snapshot matrix from csv files\ntrain_snapshot = pod_rbf.build_snapshot_matrix(\"data/train\")"
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 |    "source": "import jax.numpy as jnp\n\n# load validation\nval = np.loadtxt(\n    \"data/validation/re-0450.csv\",\n    # \"data/validation/re-0050.csv\",\n    delimiter=\",\",\n    skiprows=1,\n    usecols=(0),\n    unpack=True,\n)\n\n# train the model\nconfig = pod_rbf.TrainConfig(energy_threshold=0.9, poly_degree=2) # poly_degree: 0=none, 1=linear, 2=quadratic\nresult = pod_rbf.train(train_snapshot, Re, config)\nstate = result.state\nprint(\"Energy kept after truncating = {:.2%}\".format(state.truncated_energy))\n\n# plot the energy decay\nplt.plot(state.cumul_energy)\n\n\n# run inference on an unseen parameter\nsol = pod_rbf.inference_single(state, jnp.array(450.0))\n\n# calculate and plot the difference between the prediction and the target\ndiff = np.nan_to_num(np.abs(sol - val))\nprint(\"Mean absolute error = {}\".format(np.mean(diff)))\n\n# plot the predicted solution\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(22, 9))\nax1.set_title(\"POD-RBF\", fontsize=40)\ncntr1 = ax1.tricontourf(\n    x, y, sol, levels=np.linspace(0, 1, num=20), cmap=\"viridis\", extend=\"both\"\n)\nax1.set_xticks([])\nax1.set_yticks([])\n# plot the actual solution\nax2.set_title(\"Target\", fontsize=40)\ncntr2 = ax2.tricontourf(\n    x, y, val, levels=np.linspace(0, 1, num=20), cmap=\"viridis\", extend=\"both\"\n)\ncbar_ax = fig.add_axes([0.485, 0.15, 0.025, 0.7])\nfig.colorbar(cntr2, cax=cbar_ax)\nax2.set_xticks([])\nax2.set_yticks([])\n# fig.tight_layout()\n\nfig2, ax = plt.subplots(1, 1, figsize=(12, 9))\ncntr = ax.tricontourf(x, y, diff, cmap=\"viridis\", extend=\"both\")\nfig2.colorbar(cntr)\nplt.show()"
16 | }
17 | ],
18 | "metadata": {
19 | "interpreter": {
20 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
21 | },
22 | "kernelspec": {
23 | "display_name": "Python 3.8.5 64-bit",
24 | "name": "python3"
25 | },
26 | "language_info": {
27 | "name": "python",
28 | "version": ""
29 | },
35 | "orig_nbformat": 2
36 | },
37 | "nbformat": 4,
38 | "nbformat_minor": 2
39 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # VS Code
2 | *.code-workspace
3 | .vscode
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # uv
102 | # For libraries, uv.lock should not be committed as it locks dependencies for development only.
103 | # Library users will resolve dependencies based on pyproject.toml.
104 | uv.lock
105 |
106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
107 | __pypackages__/
108 |
109 | # Celery stuff
110 | celerybeat-schedule
111 | celerybeat.pid
112 |
113 | # SageMath parsed files
114 | *.sage.py
115 |
116 | # Environments
117 | .env
118 | .venv
119 | env/
120 | venv/
121 | ENV/
122 | env.bak/
123 | venv.bak/
124 |
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 |
129 | # Rope project settings
130 | .ropeproject
131 |
132 | # mkdocs documentation
133 | /site
134 |
135 | # mypy
136 | .mypy_cache/
137 | .dmypy.json
138 | dmypy.json
139 |
140 | # Pyre type checker
141 | .pyre/
142 |
143 | # pytype static type analyzer
144 | .pytype/
145 |
146 | # Cython debug symbols
147 | cython_debug/
148 |
--------------------------------------------------------------------------------
/examples/heat-conduction/2d-heat.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | def buildSnapshotMatrix(params, num_points):
8 | """
9 | Assemble the snapshot matrix
10 | """
11 | print("making the snapshot matrix... ", end="")
12 | start = time.time()
13 |
14 | # evaluate the analytical solution
15 | num_terms = 50
16 | L = 1
17 | T_L = params[0, :]
18 |
19 | n = np.arange(0, num_terms)
20 | # calculate lambdas
21 | lambs = np.pi * (2 * n + 1) / (2 * L)
22 |
23 | # define points
24 | x = np.linspace(0, L, num=num_points)
25 | X, Y = np.meshgrid(x, x, indexing="xy")
26 |
27 | snapshot = np.zeros((num_points ** 2, len(T_L)))
28 | for i in range(len(T_L)):
29 | # calculate constants
30 | C = (
31 | 8
32 | * T_L[i]
33 | * (2 * (-1) ** n / (lambs * L) - 1)
34 | / ((lambs * L) ** 2 * np.cosh(lambs * L))
35 | )
36 | T = np.zeros_like(X)
37 | for j in range(0, num_terms):
38 | T = T + C[j] * np.cosh(lambs[j] * X) * np.cos(lambs[j] * Y)
39 | snapshot[:, i] = T.flatten()
40 |
41 | print("took {:3.3f} sec".format(time.time() - start))
42 |
43 | return snapshot
44 |
45 |
46 | if __name__ == "__main__":
47 |
48 | import jax
49 | import jax.numpy as jnp
50 | import pod_rbf
51 |
52 | jax.config.update('jax_default_device', jax.devices('cpu')[0]) # Change to 'gpu' or 'tpu' for accelerators
53 |
54 | T_L = np.linspace(1, 100, num=11)
55 | T_L = np.expand_dims(T_L, axis=0)
56 | T_L_test = 55.0
57 | num_points = 41
58 |
59 | # make snapshot matrix
60 | snapshot = buildSnapshotMatrix(T_L, num_points)
61 |
62 | # calculate 'test' solution
63 | # evaluate the analytical solution
64 | num_terms = 50
65 | L = 1
66 | n = np.arange(0, num_terms)
67 | # calculate lambdas
68 | lambs = np.pi * (2 * n + 1) / (2 * L)
69 | # define points
70 | x = np.linspace(0, L, num=num_points)
71 | X, Y = np.meshgrid(x, x, indexing="xy")
72 | # calculate constants
73 | C = (
74 | 8
75 | * T_L_test
76 | * (2 * (-1) ** n / (lambs * L) - 1)
77 | / ((lambs * L) ** 2 * np.cosh(lambs * L))
78 | )
79 | T_test = np.zeros_like(X)
80 | for n in range(0, num_terms):
81 | T_test = T_test + C[n] * np.cosh(lambs[n] * X) * np.cos(lambs[n] * Y)
82 |
83 | # train the POD-RBF model
84 | config = pod_rbf.TrainConfig(energy_threshold=0.5, poly_degree=2)
85 | result = pod_rbf.train(snapshot, T_L, config)
86 | state = result.state
87 |
88 |     # run inference with the trained model
89 | sol = pod_rbf.inference_single(state, jnp.array(T_L_test))
90 |
91 |     print("Energy kept after truncating = {:.2%}".format(state.truncated_energy))
92 |     print("Cumulative energy per mode = {}".format(state.cumul_energy))
93 |
94 | fig = plt.figure(figsize=(12, 9))
95 | c = plt.pcolormesh(T_test, cmap="magma")
96 | fig.colorbar(c)
97 |
98 | fig = plt.figure(figsize=(12, 9))
99 | c = plt.pcolormesh(sol.reshape((num_points, num_points)), cmap="magma")
100 | fig.colorbar(c)
101 |
102 | fig = plt.figure(figsize=(12, 9))
103 | diff = np.abs(sol.reshape((num_points, num_points)) - T_test) / T_test * 100
104 | c = plt.pcolormesh(diff, cmap="magma")
105 | fig.colorbar(c)
106 |
107 | plt.show()
108 |
--------------------------------------------------------------------------------
/docs/user-guide/autodiff.md:
--------------------------------------------------------------------------------
1 | # Autodifferentiation
2 |
3 | POD-RBF is built on JAX, enabling automatic differentiation through the inference functions. This is useful for optimization, sensitivity analysis, and inverse problems.
4 |
5 | ## Computing Gradients
6 |
7 | Use `jax.grad` to compute gradients with respect to parameters:
8 |
9 | ```python
10 | import jax
11 | import jax.numpy as jnp
12 | import pod_rbf
13 |
14 | # Train model
15 | result = pod_rbf.train(snapshot, params)
16 | state = result.state
17 |
18 | # Define an objective function
19 | def objective(param):
20 | prediction = pod_rbf.inference_single(state, param)
21 | return jnp.sum(prediction ** 2)
22 |
23 | # Compute gradient
24 | grad_fn = jax.grad(objective)
25 | gradient = grad_fn(jnp.array(450.0))
26 | ```
27 |
28 | ## Optimization Example
29 |
30 | Find the parameter value that minimizes a cost function:
31 |
32 | ```python
33 | import jax
34 | import jax.numpy as jnp
35 | from jax import grad
36 |
37 | def cost_function(param, target):
38 | prediction = pod_rbf.inference_single(state, param)
39 | return jnp.mean((prediction - target) ** 2)
40 |
41 | # Gradient descent
42 | param = jnp.array(500.0) # Initial guess
43 | learning_rate = 10.0
44 |
45 | for i in range(100):
46 | grad_val = grad(cost_function)(param, target_solution)
47 | param = param - learning_rate * grad_val
48 |
49 | print(f"Optimal parameter: {param}")
50 | ```
51 |
52 | ## Inverse Problems
53 |
54 | For inverse problems where you want to find the parameter that produced an observed solution:
55 |
56 | ```python
57 | import jax
58 | import jax.numpy as jnp
59 | from jax.scipy.optimize import minimize
60 |
61 | def inverse_objective(param):
62 | prediction = pod_rbf.inference_single(state, param)
63 | return jnp.sum((prediction - observed_solution) ** 2)
64 |
65 | # Use BFGS optimization
66 | result = minimize(
67 | inverse_objective,
68 | x0=jnp.array(500.0),
69 | method="BFGS",
70 | )
71 |
72 | recovered_param = result.x
73 | ```
74 |
75 | ## Sensitivity Analysis
76 |
77 | Compute how sensitive the solution is to parameter changes:
78 |
79 | ```python
80 | import jax
81 | import jax.numpy as jnp
82 |
83 | # Jacobian: how each output point changes with the parameter
84 | jacobian_fn = jax.jacobian(
85 | lambda p: pod_rbf.inference_single(state, p)
86 | )
87 | sensitivity = jacobian_fn(jnp.array(450.0))
88 |
89 | # sensitivity shape: (n_samples,) for single parameter
90 | # Positive values indicate the solution increases with the parameter
91 | ```
92 |
93 | ## Multi-Parameter Gradients
94 |
95 | For models with multiple parameters:
96 |
97 | ```python
98 | def objective(params):
99 | # params: [Re, Ma]
100 | prediction = pod_rbf.inference_single(state, params)
101 | return jnp.sum(prediction ** 2)
102 |
103 | # Gradient with respect to all parameters
104 | grad_fn = jax.grad(objective)
105 | gradients = grad_fn(jnp.array([450.0, 0.15]))
106 | # gradients shape: (2,) - one gradient per parameter
107 | ```
108 |
109 | ## JIT Compilation
110 |
111 | For performance, JIT-compile your gradient functions:
112 |
113 | ```python
114 | @jax.jit
115 | def compute_gradient(param):
116 | return jax.grad(objective)(param)
117 |
118 | # First call compiles; subsequent calls are fast
119 | gradient = compute_gradient(jnp.array(450.0))
120 | ```
121 |
122 | ## Higher-Order Derivatives
123 |
124 | JAX supports higher-order derivatives:
125 |
126 | ```python
127 | # Second derivative (Hessian for scalar output)
128 | hessian_fn = jax.hessian(objective)
129 | hessian = hessian_fn(jnp.array(450.0))
130 | ```
131 |
--------------------------------------------------------------------------------
/CLAUDE.md:
--------------------------------------------------------------------------------
1 | # CLAUDE.md
2 |
3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4 |
5 | ## Project Overview
6 |
7 | POD-RBF is a JAX-based Python library for building Reduced Order Models (ROMs) using Proper Orthogonal Decomposition combined with Radial Basis Function interpolation. It enables autodifferentiation for gradient optimization, sensitivity analysis, and inverse problems.
8 |
9 | ## Development Commands
10 |
11 | ```bash
12 | # Install for development
13 | pip install -e ".[dev]"
14 |
15 | # Run tests
16 | pytest tests/ -v
17 |
18 | # Run single test file
19 | pytest tests/test_core.py -v
20 |
21 | # Run specific test
22 | pytest tests/test_core.py::TestGradients::test_inverse_problem -v
23 | ```
24 |
25 | ## Architecture
26 |
27 | ### Module Structure
28 |
29 | ```
30 | pod_rbf/
31 | __init__.py # Public API, enables float64
32 | types.py # ModelState, TrainConfig, TrainResult (NamedTuples)
33 | core.py # train(), inference(), inference_single()
34 | rbf.py # build_collocation_matrix(), build_inference_matrix()
35 | decomposition.py # compute_pod_basis_svd(), compute_pod_basis_eig()
36 | shape_optimization.py # find_optimal_shape_param() (fixed-iteration bisection)
37 | io.py # build_snapshot_matrix(), save_model(), load_model()
38 | ```
39 |
40 | ### Key Types
41 |
42 | ```python
43 | class ModelState(NamedTuple):
44 |     basis: Array                # (n_samples, n_modes)
45 |     weights: Array              # (n_modes, n_train_points)
46 |     shape_factor: float | None  # None for polyharmonic splines
47 |     train_params: Array         # (n_params, n_train_points)
48 |     params_range: Array         # (n_params,)
49 |     truncated_energy: float
50 |     cumul_energy: Array
51 |     # ...plus poly_coeffs, poly_degree, kernel, kernel_order (see types.py)
52 |
53 | class TrainConfig(NamedTuple):
54 |     energy_threshold: float = 0.99
55 |     cond_range: tuple[float, float] = (1e11, 1e12)
56 |     max_bisection_iters: int = 50  # ...plus mem_limit_gb, kernel, and c-search options
57 | ```
58 |
59 | ### API
60 |
61 | ```python
62 | import pod_rbf
63 | import jax
64 | import jax.numpy as jnp
65 |
66 | # Train with default config (energy_threshold=0.99)
67 | result = pod_rbf.train(snapshot, params)
68 | state = result.state
69 |
70 | # Train with custom config
71 | config = pod_rbf.TrainConfig(energy_threshold=0.9)
72 | result = pod_rbf.train(snapshot, params, config)
73 |
74 | # Inference (single point)
75 | pred = pod_rbf.inference_single(state, jnp.array(450.0))
76 |
77 | # Inference (batch)
78 | preds = pod_rbf.inference(state, jnp.array([400.0, 450.0, 500.0]))
79 |
80 | # Autodiff
81 | grad_fn = jax.grad(lambda p: jnp.sum(pod_rbf.inference_single(state, p)**2))
82 | gradient = grad_fn(jnp.array(450.0))
83 |
84 | # I/O
85 | snapshot = pod_rbf.build_snapshot_matrix("data/train/") # load CSVs from directory
86 | pod_rbf.save_model("model.pkl", state)
87 | state = pod_rbf.load_model("model.pkl")
88 | ```
89 |
90 | ### Data Shape Conventions
91 |
92 | - **Snapshot matrix**: `(n_samples, n_snapshots)` - each column is one parameter's solution
93 | - **Parameters**: 1D `(n_snapshots,)` or 2D `(n_params, n_snapshots)`
94 | - **Inference output**: `(n_samples,)` for single, `(n_samples, n_points)` for batch
95 |
96 | ### Key Algorithms
97 |
98 | 1. **POD truncation**: keeps modes until cumulative energy exceeds `energy_threshold`
99 | 2. **RBF kernels** (see the sketch below): IMQ (default) `1/√(r²/c² + 1)`, Gaussian `exp(-r²/c²)`, polyharmonic splines `r^k` (odd k) or `r^k·log(r)` (even k)
100 | 3. **Shape optimization**: fixed-iteration bisection (50 iters) targeting a condition number in [10^11, 10^12]; skipped for polyharmonic splines
101 |
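A minimal sketch of the three kernels, mirroring `pod_rbf/kernels.py` (the squared distances below are made-up values):

```python
import jax.numpy as jnp

r2 = jnp.array([0.0, 0.25, 1.0])  # squared distances between normalized parameters
c = 0.5                           # shape parameter

imq = 1.0 / jnp.sqrt(r2 / c**2 + 1.0)  # kernel_imq
gaussian = jnp.exp(-r2 / c**2)         # kernel_gaussian
phs3 = jnp.power(r2, 3 / 2.0)          # kernel_polyharmonic_spline, odd k=3
```
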
102 | ## Dependencies
103 |
104 | - `jax>=0.4.0`, `jaxlib>=0.4.0` - Autodiff and JIT compilation
105 | - `numpy` - File I/O operations
106 | - `tqdm` - Progress bars
107 | - Python ≥ 3.10
108 |
--------------------------------------------------------------------------------
/docs/user-guide/training.md:
--------------------------------------------------------------------------------
1 | # Training Models
2 |
3 | ## Building the Snapshot Matrix
4 |
5 | The snapshot matrix contains your training data. Each column is a solution snapshot at a specific parameter value.
6 |
7 | ### From CSV Files
8 |
9 | If your snapshots are stored as individual CSV files in a directory:
10 |
11 | ```python
12 | import pod_rbf
13 |
14 | snapshot = pod_rbf.build_snapshot_matrix("path/to/data/")
15 | ```
16 |
17 | Files are loaded in alphanumeric order. The function expects one value per row in each CSV file.
18 |
19 | !!! tip "File Organization"
20 | Keep all training snapshots in a dedicated directory. Files are sorted alphanumerically, so use consistent naming like `snapshot_001.csv`, `snapshot_002.csv`, etc.
21 |
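A minimal sketch of producing such a directory (the file names, header label, and values are hypothetical placeholders):

```python
import os

import numpy as np

os.makedirs("data", exist_ok=True)

# One CSV per parameter value: a single header row, then one value per row.
# Zero-padded names keep the alphanumeric load order aligned with `params`.
for i, re in enumerate([100, 200, 300]):
    field = re * np.random.rand(400)  # placeholder solution field, 400 samples
    np.savetxt(f"data/snapshot_{i:03d}.csv", field, header="u", comments="")
```
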
22 | ### From Arrays
23 |
24 | If you already have your data in memory:
25 |
26 | ```python
27 | import numpy as np
28 |
29 | # Shape: (n_samples, n_snapshots)
30 | snapshot = np.column_stack([solution1, solution2, solution3, ...])
31 | ```
32 |
33 | ## Training
34 |
35 | ### Basic Training
36 |
37 | ```python
38 | import pod_rbf
39 | import numpy as np
40 |
41 | params = np.array([100, 200, 300, 400, 500])
42 | result = pod_rbf.train(snapshot, params)
43 | ```
44 |
45 | The `result` object contains:
46 |
47 | - `result.state` - The trained model state (use this for inference)
48 | - `result.n_modes` - Number of POD modes retained
49 | - `result.used_eig_decomp` - Whether eigendecomposition was used (vs SVD)
50 |
51 | ### Training Configuration
52 |
53 | Customize training with `TrainConfig`:
54 |
55 | ```python
56 | from pod_rbf import TrainConfig
57 |
58 | config = TrainConfig(
59 | energy_threshold=0.99, # Keep modes until 99% energy retained
60 | kernel="imq", # RBF kernel: 'imq', 'gaussian', 'polyharmonic_spline'
61 | poly_degree=2, # Polynomial augmentation: 0=none, 1=linear, 2=quadratic
62 | )
63 |
64 | result = pod_rbf.train(snapshot, params, config)
65 | ```
66 |
67 | ### Configuration Options
68 |
69 | | Parameter | Default | Description |
70 | |-----------|---------|-------------|
71 | | `energy_threshold` | 0.99 | POD truncation threshold (0-1) |
72 | | `kernel` | `"imq"` | RBF kernel type |
73 | | `poly_degree` | 2 | Polynomial augmentation degree |
74 | | `mem_limit_gb` | 16.0 | Memory limit for algorithm selection |
75 | | `cond_range` | (1e11, 1e12) | Target condition number range |
76 | | `max_bisection_iters` | 50 | Max iterations for shape optimization |
77 |
78 | ### Kernel Options
79 |
80 | POD-RBF supports three RBF kernels:
81 |
82 | - **`imq`** (Inverse Multi-Quadrics) - Default, good general-purpose choice
83 | - **`gaussian`** - Smoother interpolation, requires careful shape parameter tuning
84 | - **`polyharmonic_spline`** - No shape parameter needed, use with `kernel_order`
85 |
86 | ```python
87 | # Using polyharmonic splines (no shape parameter optimization)
88 | config = TrainConfig(kernel="polyharmonic_spline", kernel_order=3)
89 | result = pod_rbf.train(snapshot, params, config)
90 | ```
91 |
92 | ## Multi-Parameter Training
93 |
94 | For problems with multiple parameters:
95 |
96 | ```python
97 | import numpy as np
98 |
99 | # Parameters shape: (n_params, n_snapshots)
100 | params = np.array([
101 | [100, 200, 300, 400], # Parameter 1 (e.g., Reynolds number)
102 | [0.1, 0.1, 0.2, 0.2], # Parameter 2 (e.g., Mach number)
103 | ])
104 |
105 | result = pod_rbf.train(snapshot, params)
106 | ```
107 |
108 | See the [2-parameter example](https://github.com/kylebeggs/POD-RBF/blob/master/examples/2-parameters.ipynb) for a complete walkthrough.
109 |
110 | ## Manual Shape Parameter
111 |
112 | If you want to specify the RBF shape parameter instead of using automatic optimization:
113 |
114 | ```python
115 | result = pod_rbf.train(snapshot, params, shape_factor=0.5)
116 | ```
117 |
118 | ## Understanding POD Truncation
119 |
120 | POD extracts the dominant modes from your snapshot data. The `energy_threshold` controls how many modes are kept:
121 |
122 | - `0.99` (default) - Keep modes until 99% of total energy is captured
123 | - Higher values retain more modes (more accurate, slower inference)
124 | - Lower values retain fewer modes (faster inference, may lose accuracy)
125 |
126 | After training, check how much energy was retained:
127 |
128 | ```python
129 | print(f"Modes retained: {result.n_modes}")
130 | print(f"Energy retained: {result.state.truncated_energy:.4f}")
131 | ```
132 |
--------------------------------------------------------------------------------
/pod_rbf/kernels.py:
--------------------------------------------------------------------------------
1 | """
2 | RBF kernel functions and dispatcher.
3 |
4 | Supports multiple kernel types:
5 | - Inverse Multi-Quadrics (IMQ): phi(r) = 1 / sqrt(r²/c² + 1)
6 | - Gaussian: phi(r) = exp(-r²/c²)
7 | - Polyharmonic Splines (PHS): phi(r) = r^k or r^k*log(r)
8 | """
9 |
10 | from enum import Enum
11 |
12 | import jax.numpy as jnp
13 | from jax import Array
14 |
15 |
16 | class KernelType(str, Enum):
17 | """RBF kernel types (internal use only)."""
18 |
19 | IMQ = "imq"
20 | GAUSSIAN = "gaussian"
21 | POLYHARMONIC_SPLINE = "polyharmonic_spline"
22 |
23 |
24 | def kernel_imq(r2: Array, shape_factor: float) -> Array:
25 | """
26 | Inverse Multiquadrics kernel.
27 |
28 | phi(r) = 1 / sqrt(r²/c² + 1)
29 |
30 | Parameters
31 | ----------
32 | r2 : Array
33 | Squared distances between points.
34 | shape_factor : float
35 | Shape parameter c.
36 |
37 | Returns
38 | -------
39 | Array
40 | Kernel values, same shape as r2.
41 | """
42 | return 1.0 / jnp.sqrt(r2 / (shape_factor**2) + 1.0)
43 |
44 |
45 | def kernel_gaussian(r2: Array, shape_factor: float) -> Array:
46 | """
47 | Gaussian kernel.
48 |
49 | phi(r) = exp(-r²/c²)
50 |
51 | Parameters
52 | ----------
53 | r2 : Array
54 | Squared distances between points.
55 | shape_factor : float
56 | Shape parameter c.
57 |
58 | Returns
59 | -------
60 | Array
61 | Kernel values, same shape as r2.
62 | """
63 | return jnp.exp(-r2 / (shape_factor**2))
64 |
65 |
66 | def kernel_polyharmonic_spline(r2: Array, order: int) -> Array:
67 | """
68 | Polyharmonic spline kernel.
69 |
70 | - Odd order k: phi(r) = r^k
71 | - Even order k: phi(r) = r^k * log(r)
72 |
73 | Parameters
74 | ----------
75 | r2 : Array
76 | Squared distances between points.
77 | order : int
78 | Polynomial order (typically 1-5).
79 | Odd: r, r³, r⁵
80 | Even: r²log(r), r⁴log(r)
81 |
82 | Returns
83 | -------
84 | Array
85 | Kernel values, same shape as r2.
86 |
87 | Notes
88 | -----
89 | For even orders, handles r=0 case to avoid log(0) singularity.
90 | Uses r2-based formulation to avoid gradient singularity from sqrt at r2=0.
91 | """
92 | if order % 2 == 1:
93 | # Odd order: r^k = (r2)^(k/2)
94 | # Using power of r2 directly avoids sqrt gradient singularity at r2=0
95 | return jnp.power(r2, order / 2.0)
96 | else:
97 | # Even order: r^k * log(r) = (r2)^(k/2) * log(sqrt(r2))
98 | # = (r2)^(k/2) * (1/2) * log(r2)
99 | # Handle r2=0 case: set to 0 when r2 < threshold
100 | return jnp.where(
101 | r2 > 1e-30,
102 | jnp.power(r2, order / 2.0) * 0.5 * jnp.log(r2),
103 | 0.0,
104 | )
105 |
106 |
107 | def apply_kernel(
108 | r2: Array,
109 | kernel: str,
110 | shape_factor: float | None,
111 | kernel_order: int,
112 | ) -> Array:
113 | """
114 | Apply RBF kernel to distance matrix.
115 |
116 | Dispatcher function that selects and applies the appropriate kernel
117 | based on the kernel type string.
118 |
119 | Parameters
120 | ----------
121 | r2 : Array
122 | Squared distances between points.
123 | kernel : str
124 | Kernel type: 'imq', 'gaussian', or 'polyharmonic_spline'.
125 | shape_factor : float | None
126 | Shape parameter for IMQ and Gaussian kernels.
127 | Ignored for polyharmonic splines.
128 | kernel_order : int
129 | Order for polyharmonic splines.
130 | Ignored for other kernels.
131 |
132 | Returns
133 | -------
134 | Array
135 | Kernel values applied to distance matrix.
136 |
137 | Raises
138 | ------
139 | ValueError
140 | If kernel type is not recognized.
141 | """
142 | kernel_type = KernelType(kernel)
143 |
144 | if kernel_type == KernelType.IMQ:
145 | return kernel_imq(r2, shape_factor)
146 | elif kernel_type == KernelType.GAUSSIAN:
147 | return kernel_gaussian(r2, shape_factor)
148 | elif kernel_type == KernelType.POLYHARMONIC_SPLINE:
149 | return kernel_polyharmonic_spline(r2, kernel_order)
150 | else:
151 | raise ValueError(f"Unknown kernel type: {kernel}")
152 |
153 |
154 | # Kernel-specific defaults for shape parameter optimization
155 | KERNEL_SHAPE_DEFAULTS = {
156 | "imq": {"c_low": 0.011, "c_high": 1.0, "c_step": 0.01},
157 | "gaussian": {"c_low": 0.1, "c_high": 10.0, "c_step": 0.1},
158 | }
159 |
--------------------------------------------------------------------------------
/pod_rbf/shape_optimization.py:
--------------------------------------------------------------------------------
1 | """
2 | Shape parameter optimization for RBF interpolation.
3 |
4 | Uses fixed-iteration bisection to find optimal shape parameter c such that
5 | the collocation matrix condition number falls within a target range.
6 | """
7 |
8 | import jax
9 | import jax.numpy as jnp
10 | from jax import Array
11 |
12 | from .kernels import KERNEL_SHAPE_DEFAULTS
13 | from .rbf import build_collocation_matrix
14 |
15 |
16 | def find_optimal_shape_param(
17 | train_params: Array,
18 | params_range: Array,
19 | kernel: str = "imq",
20 | kernel_order: int = 3,
21 | cond_range: tuple[float, float] = (1e11, 1e12),
22 | max_iters: int = 50,
23 | c_low_init: float | None = None,
24 | c_high_init: float | None = None,
25 | c_high_step: float | None = None,
26 | c_high_search_iters: int = 200,
27 | ) -> float | None:
28 | """
29 | Find optimal RBF shape parameter via fixed-iteration bisection.
30 |
31 | Target: condition number in [cond_range[0], cond_range[1]].
32 |
33 | Parameters
34 | ----------
35 | train_params : Array
36 | Training parameters, shape (n_params, n_train_points).
37 | params_range : Array
38 | Range of each parameter for normalization, shape (n_params,).
39 | kernel : str, optional
40 | Kernel type: 'imq', 'gaussian', or 'polyharmonic_spline'.
41 | Default is 'imq'.
42 | kernel_order : int, optional
43 | Order for polyharmonic splines (default 3).
44 | Ignored for other kernels.
45 | cond_range : tuple
46 | Target condition number range (lower, upper).
47 | max_iters : int
48 | Maximum bisection iterations.
49 | c_low_init : float | None, optional
50 | Initial lower bound for shape parameter.
51 | If None, uses kernel-specific default.
52 | c_high_init : float | None, optional
53 | Initial upper bound for shape parameter.
54 | If None, uses kernel-specific default.
55 | c_high_step : float | None, optional
56 | Step size for expanding upper bound search.
57 | If None, uses kernel-specific default.
58 | c_high_search_iters : int
59 | Maximum iterations for upper bound search.
60 |
61 | Returns
62 | -------
63 | float | None
64 | Optimal shape parameter.
65 | Returns None for kernels that don't use shape parameters (e.g., PHS).
66 | """
67 | # PHS doesn't use shape parameters
68 | if kernel == "polyharmonic_spline":
69 | return None
70 |
71 | # Use kernel-specific defaults if not provided
72 | defaults = KERNEL_SHAPE_DEFAULTS.get(kernel, KERNEL_SHAPE_DEFAULTS["imq"])
73 | c_low_init = c_low_init or defaults["c_low"]
74 | c_high_init = c_high_init or defaults["c_high"]
75 | c_high_step = c_high_step or defaults["c_step"]
76 |
77 | cond_low, cond_high = cond_range
78 |
79 | # Step 1: Find upper bound where cond >= cond_low
80 | def search_c_high_iter(i: int, carry: tuple) -> tuple:
81 | c_high, found = carry
82 | C = build_collocation_matrix(
83 | train_params, params_range, kernel, c_high, kernel_order
84 | )
85 | cond = jnp.linalg.cond(C)
86 | should_continue = (~found) & (cond < cond_low)
87 | new_c_high = jnp.where(should_continue, c_high + c_high_step, c_high)
88 | new_found = found | (cond >= cond_low)
89 | return (new_c_high, new_found)
90 |
91 | c_high, _ = jax.lax.fori_loop(
92 | 0, c_high_search_iters, search_c_high_iter, (c_high_init, False)
93 | )
94 |
95 | # Step 2: Bisection to find optimal c in range
96 | def bisection_iter(i: int, carry: tuple) -> tuple:
97 | c_low_bound, c_high_bound, optim_c, found = carry
98 |
99 | mid_c = (c_low_bound + c_high_bound) / 2.0
100 | C = build_collocation_matrix(
101 | train_params, params_range, kernel, mid_c, kernel_order
102 | )
103 | cond = jnp.linalg.cond(C)
104 |
105 | # Check if condition number is in target range
106 | in_range = (cond >= cond_low) & (cond <= cond_high)
107 | below_range = cond < cond_low
108 |
109 | # Update bounds based on condition number (only if not yet found)
110 | new_c_low = jnp.where(below_range & ~found, mid_c, c_low_bound)
111 | new_c_high = jnp.where((~below_range) & (~in_range) & ~found, mid_c, c_high_bound)
112 | new_optim_c = jnp.where(in_range & ~found, mid_c, optim_c)
113 | new_found = found | in_range
114 |
115 | return (new_c_low, new_c_high, new_optim_c, new_found)
116 |
117 | initial_guess = (c_low_init + c_high) / 2.0
118 | _, _, optim_c, _ = jax.lax.fori_loop(
119 | 0, max_iters, bisection_iter, (c_low_init, c_high, initial_guess, False)
120 | )
121 |
122 | return optim_c
123 |
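124 | # Usage sketch (illustrative): find a shape parameter for a single-parameter
125 | # training set; the 11 Reynolds-number points here are placeholder values.
126 | #
127 | #     import jax.numpy as jnp
128 | #     train_params = jnp.linspace(1.0, 1000.0, num=11)[None, :]  # (n_params, n_train)
129 | #     params_range = jnp.ptp(train_params, axis=1)               # (n_params,)
130 | #     c = find_optimal_shape_param(train_params, params_range, kernel="imq")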
--------------------------------------------------------------------------------
/pod_rbf/decomposition.py:
--------------------------------------------------------------------------------
1 | """
2 | POD basis computation via SVD or eigendecomposition.
3 | """
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | from jax import Array
8 |
9 |
10 | def compute_pod_basis_svd(
11 | snapshot: Array,
12 | energy_threshold: float,
13 | ) -> tuple[Array, Array, float]:
14 | """
15 | Compute truncated POD basis via SVD.
16 |
17 | Use for smaller datasets (< mem_limit_gb).
18 |
19 | Parameters
20 | ----------
21 | snapshot : Array
22 | Snapshot matrix, shape (n_samples, n_snapshots).
23 | energy_threshold : float
24 | Minimum fraction of total energy to retain (0 < threshold <= 1).
25 |
26 | Returns
27 | -------
28 | basis : Array
29 | Truncated POD basis, shape (n_samples, n_modes).
30 | cumul_energy : Array
31 | Cumulative energy fraction per mode.
32 | truncated_energy : float
33 | Actual energy fraction retained.
34 | """
35 | U, S, _ = jnp.linalg.svd(snapshot, full_matrices=False)
36 |
37 |     cumul_energy = jnp.cumsum(S**2) / jnp.sum(S**2)  # energy = squared singular values, matching the eig path
38 |
39 | # Handle energy_threshold >= 1 (keep all modes)
40 | keep_all = energy_threshold >= 1.0
41 | # Find first index where cumul_energy > threshold
42 | mask = cumul_energy > energy_threshold
43 | trunc_id = jnp.where(
44 | keep_all,
45 | len(S) - 1,
46 | jnp.where(jnp.any(mask), jnp.argmax(mask), len(S) - 1),
47 | )
48 |
49 | truncated_energy = cumul_energy[trunc_id]
50 |
51 | # Dynamic slice to get truncated basis
52 | basis = jax.lax.dynamic_slice(U, (0, 0), (U.shape[0], trunc_id + 1))
53 |
54 | return basis, cumul_energy, truncated_energy
55 |
56 |
57 | def compute_pod_basis_eig(
58 | snapshot: Array,
59 | energy_threshold: float,
60 | ) -> tuple[Array, Array, float]:
61 | """
62 | Compute truncated POD basis via eigendecomposition.
63 |
64 | More memory-efficient for large datasets (>= mem_limit_gb).
65 | Computes (n_snapshots x n_snapshots) covariance instead of full SVD.
66 |
67 | Parameters
68 | ----------
69 | snapshot : Array
70 | Snapshot matrix, shape (n_samples, n_snapshots).
71 | energy_threshold : float
72 | Minimum fraction of total energy to retain (0 < threshold <= 1).
73 |
74 | Returns
75 | -------
76 | basis : Array
77 | Truncated POD basis, shape (n_samples, n_modes).
78 | cumul_energy : Array
79 | Cumulative energy fraction per mode.
80 | truncated_energy : float
81 | Actual energy fraction retained.
82 | """
83 | # Covariance matrix (n_snapshots x n_snapshots)
84 | cov = snapshot.T @ snapshot
85 | eig_vals, eig_vecs = jnp.linalg.eigh(cov)
86 |
87 | # eigh returns ascending order, reverse to descending
88 | eig_vals = jnp.abs(eig_vals[::-1])
89 | eig_vecs = eig_vecs[:, ::-1]
90 |
91 | cumul_energy = jnp.cumsum(eig_vals) / jnp.sum(eig_vals)
92 |
93 | # Handle energy_threshold >= 1 (keep all modes)
94 | keep_all = energy_threshold >= 1.0
95 | mask = cumul_energy > energy_threshold
96 | trunc_id = jnp.where(
97 | keep_all,
98 | len(eig_vals) - 1,
99 | jnp.where(jnp.any(mask), jnp.argmax(mask), len(eig_vals) - 1),
100 | )
101 |
102 | truncated_energy = cumul_energy[trunc_id]
103 |
104 | # Truncate eigenvalues and eigenvectors
105 | eig_vals_trunc = jax.lax.dynamic_slice(eig_vals, (0,), (trunc_id + 1,))
106 | eig_vecs_trunc = jax.lax.dynamic_slice(
107 | eig_vecs, (0, 0), (eig_vecs.shape[0], trunc_id + 1)
108 | )
109 |
110 | # Compute POD basis from eigenvectors
111 | basis = (snapshot @ eig_vecs_trunc) / jnp.sqrt(eig_vals_trunc)
112 |
113 | return basis, cumul_energy, truncated_energy
114 |
115 |
116 | def compute_pod_basis(
117 | snapshot: Array,
118 | energy_threshold: float,
119 | use_eig: bool = False,
120 | ) -> tuple[Array, Array, float]:
121 | """
122 | Compute truncated POD basis.
123 |
124 | Dispatches to SVD or eigendecomposition based on use_eig flag.
125 | The flag should be determined BEFORE JIT compilation based on memory.
126 |
127 | Parameters
128 | ----------
129 | snapshot : Array
130 | Snapshot matrix, shape (n_samples, n_snapshots).
131 | energy_threshold : float
132 | Minimum fraction of total energy to retain (0 < threshold <= 1).
133 | use_eig : bool
134 | If True, use eigendecomposition (memory efficient for large data).
135 | If False, use SVD (faster for smaller data).
136 |
137 | Returns
138 | -------
139 | basis : Array
140 | Truncated POD basis, shape (n_samples, n_modes).
141 | cumul_energy : Array
142 | Cumulative energy fraction per mode.
143 | truncated_energy : float
144 | Actual energy fraction retained.
145 | """
146 | if use_eig:
147 | return compute_pod_basis_eig(snapshot, energy_threshold)
148 | return compute_pod_basis_svd(snapshot, energy_threshold)
149 |
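150 | # Usage sketch (illustrative; random placeholder data): build a basis retaining
151 | # 99% of the energy.
152 | #
153 | #     import numpy as np
154 | #     import jax.numpy as jnp
155 | #     snapshot = jnp.array(np.random.randn(400, 10))   # (n_samples, n_snapshots)
156 | #     basis, cumul_energy, retained = compute_pod_basis(snapshot, 0.99)
157 | #     # basis.shape == (400, n_modes) with n_modes <= 10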
--------------------------------------------------------------------------------
/pod_rbf/io.py:
--------------------------------------------------------------------------------
1 | """
2 | File I/O utilities for POD-RBF.
3 |
4 | Uses NumPy for file operations (not differentiable).
5 | """
6 |
7 | import os
8 | import pickle
9 |
10 | import jax.numpy as jnp
11 | import numpy as np
12 | from tqdm import tqdm
13 |
14 | from .types import ModelState
15 |
16 |
17 | def build_snapshot_matrix(
18 | dirpath: str,
19 | skiprows: int = 1,
20 | usecols: int | tuple[int, ...] = 0,
21 | verbose: bool = True,
22 | ) -> np.ndarray:
23 | """
24 | Load snapshot matrix from CSV files in directory.
25 |
26 | Files are loaded in alphanumeric order. Ensure parameter array
27 | matches this ordering.
28 |
29 | Parameters
30 | ----------
31 | dirpath : str
32 | Directory containing CSV files.
33 | skiprows : int
34 | Number of header rows to skip in each file.
35 | usecols : int or tuple
36 | Column(s) to read from each file.
37 | verbose : bool
38 | Show progress bar.
39 |
40 | Returns
41 | -------
42 | np.ndarray
43 | Snapshot matrix, shape (n_samples, n_snapshots).
44 | Returns NumPy array - convert to JAX as needed.
45 | """
46 | files = sorted(
47 | [
48 | f
49 | for f in os.listdir(dirpath)
50 | if os.path.isfile(os.path.join(dirpath, f)) and f.endswith(".csv")
51 | ]
52 | )
53 |
54 | if not files:
55 | raise ValueError(f"No CSV files found in {dirpath}")
56 |
57 | # Get dimensions from first file
58 | first_data = np.loadtxt(
59 | os.path.join(dirpath, files[0]),
60 | delimiter=",",
61 | skiprows=skiprows,
62 | usecols=usecols,
63 | )
64 | n_samples = len(first_data) if first_data.ndim > 0 else 1
65 | n_snapshots = len(files)
66 |
67 | snapshot = np.zeros((n_samples, n_snapshots))
68 |
69 | iterator = tqdm(files, desc="Loading snapshots") if verbose else files
70 | for i, f in enumerate(iterator):
71 | data = np.loadtxt(
72 | os.path.join(dirpath, f),
73 | delimiter=",",
74 | skiprows=skiprows,
75 | usecols=usecols,
76 | )
77 | data_len = len(data) if data.ndim > 0 else 1
78 | assert data_len == n_samples, f"Inconsistent samples in {f}: got {data_len}, expected {n_samples}"
79 | snapshot[:, i] = data
80 |
81 | return snapshot
82 |
83 |
84 | def save_model(filename: str, state: ModelState) -> None:
85 | """
86 | Save model state to file.
87 |
88 | Parameters
89 | ----------
90 | filename : str
91 | Output filename.
92 | state : ModelState
93 | Trained model state.
94 | """
95 | # Convert JAX arrays to NumPy for pickling
96 | state_dict = {
97 | "basis": np.asarray(state.basis),
98 | "weights": np.asarray(state.weights),
99 | "shape_factor": float(state.shape_factor) if state.shape_factor is not None else None,
100 | "train_params": np.asarray(state.train_params),
101 | "params_range": np.asarray(state.params_range),
102 | "truncated_energy": float(state.truncated_energy),
103 | "cumul_energy": np.asarray(state.cumul_energy),
104 | "poly_coeffs": np.asarray(state.poly_coeffs) if state.poly_coeffs is not None else None,
105 | "poly_degree": int(state.poly_degree),
106 | "kernel": state.kernel,
107 | "kernel_order": int(state.kernel_order),
108 | }
109 | with open(filename, "wb") as f:
110 | pickle.dump(state_dict, f)
111 |
112 |
113 | def load_model(filename: str) -> ModelState:
114 | """
115 | Load model state from file.
116 |
117 | Parameters
118 | ----------
119 | filename : str
120 | Input filename.
121 |
122 | Returns
123 | -------
124 | ModelState
125 | Loaded model state with JAX arrays.
126 | """
127 | with open(filename, "rb") as f:
128 | state_dict = pickle.load(f)
129 |
130 | # Handle backward compatibility for models saved without poly fields
131 | poly_coeffs = state_dict.get("poly_coeffs")
132 | if poly_coeffs is not None:
133 | poly_coeffs = jnp.array(poly_coeffs)
134 | poly_degree = state_dict.get("poly_degree", 0)
135 |
136 | # Handle backward compatibility for models saved without kernel fields
137 | # Legacy models used IMQ kernel only
138 | kernel = state_dict.get("kernel", "imq")
139 | kernel_order = state_dict.get("kernel_order", 3)
140 |
141 | return ModelState(
142 | basis=jnp.array(state_dict["basis"]),
143 | weights=jnp.array(state_dict["weights"]),
144 | shape_factor=state_dict["shape_factor"],
145 | train_params=jnp.array(state_dict["train_params"]),
146 | params_range=jnp.array(state_dict["params_range"]),
147 | truncated_energy=state_dict["truncated_energy"],
148 | cumul_energy=jnp.array(state_dict["cumul_energy"]),
149 | poly_coeffs=poly_coeffs,
150 | poly_degree=poly_degree,
151 | kernel=kernel,
152 | kernel_order=kernel_order,
153 | )
154 |
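155 | # Usage sketch (illustrative; "snapshots/" and "model.pkl" are placeholder paths):
156 | #
157 | #     snapshot = build_snapshot_matrix("snapshots/", skiprows=1, usecols=0)
158 | #     # ...train a model to get a ModelState `state`...
159 | #     save_model("model.pkl", state)
160 | #     state = load_model("model.pkl")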
--------------------------------------------------------------------------------
/tests/test_io.py:
--------------------------------------------------------------------------------
1 | """Tests for I/O utilities."""
2 |
3 | import os
4 | import tempfile
5 |
6 | import jax.numpy as jnp
7 | import numpy as np
8 | import pytest
9 |
10 | from pod_rbf.io import build_snapshot_matrix, load_model, save_model
11 | from pod_rbf.types import ModelState
12 |
13 |
14 | class TestSaveLoadModel:
15 | """Test model serialization."""
16 |
17 | @pytest.fixture
18 | def sample_state(self):
19 | """Create sample model state."""
20 | return ModelState(
21 | basis=jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
22 | weights=jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
23 | shape_factor=0.5,
24 | train_params=jnp.array([[1.0, 2.0, 3.0]]),
25 | params_range=jnp.array([2.0]),
26 | truncated_energy=0.99,
27 | cumul_energy=jnp.array([0.9, 0.99]),
28 | poly_coeffs=jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
29 | poly_degree=2,
30 | kernel="imq",
31 | kernel_order=3,
32 | )
33 |
34 | def test_save_load_roundtrip(self, sample_state):
35 | """Saved and loaded state should match."""
36 | with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f:
37 | filename = f.name
38 |
39 | try:
40 | save_model(filename, sample_state)
41 | loaded = load_model(filename)
42 |
43 | assert jnp.allclose(loaded.basis, sample_state.basis)
44 | assert jnp.allclose(loaded.weights, sample_state.weights)
45 | assert loaded.shape_factor == sample_state.shape_factor
46 | assert jnp.allclose(loaded.train_params, sample_state.train_params)
47 | assert jnp.allclose(loaded.params_range, sample_state.params_range)
48 | assert loaded.truncated_energy == sample_state.truncated_energy
49 | assert jnp.allclose(loaded.cumul_energy, sample_state.cumul_energy)
50 | assert jnp.allclose(loaded.poly_coeffs, sample_state.poly_coeffs)
51 | assert loaded.poly_degree == sample_state.poly_degree
52 | finally:
53 | os.unlink(filename)
54 |
55 | def test_loaded_state_is_model_state(self, sample_state):
56 | """Loaded state should be ModelState instance."""
57 | with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f:
58 | filename = f.name
59 |
60 | try:
61 | save_model(filename, sample_state)
62 | loaded = load_model(filename)
63 |
64 | assert isinstance(loaded, ModelState)
65 | finally:
66 | os.unlink(filename)
67 |
68 |
69 | class TestBuildSnapshotMatrix:
70 | """Test snapshot matrix loading from CSV files."""
71 |
72 | @pytest.fixture
73 | def csv_dir(self):
74 | """Create temporary directory with CSV files."""
75 | with tempfile.TemporaryDirectory() as tmpdir:
76 | # Create 3 CSV files with 5 data points each
77 | for i in range(3):
78 | data = np.arange(5) * (i + 1) # [0,1,2,3,4] * (i+1)
79 | filepath = os.path.join(tmpdir, f"data_{i:03d}.csv")
80 | np.savetxt(filepath, data[:, None], delimiter=",", header="value", comments="")
81 |
82 | yield tmpdir
83 |
84 | def test_loads_all_files(self, csv_dir):
85 | """Should load all CSV files in directory."""
86 | snapshot = build_snapshot_matrix(csv_dir, verbose=False)
87 |
88 | assert snapshot.shape == (5, 3), f"Expected (5, 3), got {snapshot.shape}"
89 |
90 | def test_correct_values(self, csv_dir):
91 | """Should load correct values."""
92 | snapshot = build_snapshot_matrix(csv_dir, verbose=False)
93 |
94 | # File data_000.csv has [0,1,2,3,4]
95 | assert np.allclose(snapshot[:, 0], [0, 1, 2, 3, 4])
96 | # File data_001.csv has [0,2,4,6,8]
97 | assert np.allclose(snapshot[:, 1], [0, 2, 4, 6, 8])
98 | # File data_002.csv has [0,3,6,9,12]
99 | assert np.allclose(snapshot[:, 2], [0, 3, 6, 9, 12])
100 |
101 | def test_alphanumeric_order(self, csv_dir):
102 | """Files should be loaded in alphanumeric order."""
103 | snapshot = build_snapshot_matrix(csv_dir, verbose=False)
104 |
105 | # First file should be data_000.csv
106 | # Last file should be data_002.csv
107 | assert snapshot[1, 0] == 1 # data_000: index 1 = 1
108 | assert snapshot[1, 2] == 3 # data_002: index 1 = 3
109 |
110 | def test_empty_dir_raises(self):
111 | """Should raise ValueError for empty directory."""
112 | with tempfile.TemporaryDirectory() as tmpdir:
113 | with pytest.raises(ValueError, match="No CSV files"):
114 | build_snapshot_matrix(tmpdir, verbose=False)
115 |
116 | def test_inconsistent_samples_raises(self):
117 | """Should raise AssertionError for inconsistent sample counts."""
118 | with tempfile.TemporaryDirectory() as tmpdir:
119 | # Create files with different row counts
120 | np.savetxt(os.path.join(tmpdir, "a.csv"), [1, 2, 3], delimiter=",", header="x", comments="")
121 | np.savetxt(os.path.join(tmpdir, "b.csv"), [1, 2], delimiter=",", header="x", comments="")
122 |
123 | with pytest.raises(AssertionError, match="Inconsistent"):
124 | build_snapshot_matrix(tmpdir, verbose=False)
125 |
--------------------------------------------------------------------------------
/tests/test_decomposition.py:
--------------------------------------------------------------------------------
1 | """Tests for POD basis decomposition."""
2 |
3 | import jax.numpy as jnp
4 | import numpy as np
5 | import pytest
6 |
7 | from pod_rbf.decomposition import (
8 | compute_pod_basis,
9 | compute_pod_basis_eig,
10 | compute_pod_basis_svd,
11 | )
12 |
13 |
14 | class TestComputePodBasisSvd:
15 | """Test SVD-based POD basis computation."""
16 |
17 | def test_orthonormality(self):
18 | """SVD-based basis should be orthonormal."""
19 | np.random.seed(42)
20 | snapshot = jnp.array(np.random.randn(100, 10))
21 |
22 | basis, _, _ = compute_pod_basis_svd(snapshot, 0.99)
23 |
24 | gram = basis.T @ basis
25 | identity = jnp.eye(gram.shape[0])
26 | assert jnp.allclose(gram, identity, atol=1e-10), "Basis should be orthonormal"
27 |
28 | def test_energy_threshold(self):
29 | """Retained energy should meet threshold."""
30 | np.random.seed(42)
31 | snapshot = jnp.array(np.random.randn(50, 10))
32 |
33 | basis, cumul_energy, truncated_energy = compute_pod_basis_svd(snapshot, 0.95)
34 |
35 | assert truncated_energy >= 0.95, f"Truncated energy {truncated_energy} < 0.95"
36 |
37 | def test_basis_shape(self):
38 | """Basis should have correct shape."""
39 | np.random.seed(42)
40 | snapshot = jnp.array(np.random.randn(100, 10))
41 |
42 | basis, _, _ = compute_pod_basis_svd(snapshot, 0.99)
43 |
44 | assert basis.shape[0] == 100, "Basis should have n_samples rows"
45 | assert basis.shape[1] <= 10, "Basis should have at most n_snapshots columns"
46 | assert basis.shape[1] >= 1, "Basis should have at least 1 column"
47 |
48 | def test_cumul_energy_monotonic(self):
49 | """Cumulative energy should be monotonically increasing."""
50 | np.random.seed(42)
51 | snapshot = jnp.array(np.random.randn(50, 15))
52 |
53 | _, cumul_energy, _ = compute_pod_basis_svd(snapshot, 0.99)
54 |
55 | diffs = jnp.diff(cumul_energy)
56 | assert jnp.all(diffs >= 0), "Cumulative energy should be monotonically increasing"
57 |
58 | def test_keep_all_energy(self):
59 | """With threshold >= 1, should keep all modes."""
60 | np.random.seed(42)
61 | snapshot = jnp.array(np.random.randn(50, 10))
62 |
63 | basis, _, truncated_energy = compute_pod_basis_svd(snapshot, 1.0)
64 |
65 | assert basis.shape[1] == 10, "Should keep all modes when threshold=1.0"
66 | assert jnp.isclose(truncated_energy, 1.0, atol=1e-10)
67 |
68 |
69 | class TestComputePodBasisEig:
70 | """Test eigendecomposition-based POD basis computation."""
71 |
72 | def test_orthonormality(self):
73 | """Eig-based basis should be orthonormal."""
74 | np.random.seed(42)
75 | snapshot = jnp.array(np.random.randn(100, 10))
76 |
77 | basis, _, _ = compute_pod_basis_eig(snapshot, 0.99)
78 |
79 | gram = basis.T @ basis
80 | identity = jnp.eye(gram.shape[0])
81 | assert jnp.allclose(gram, identity, atol=1e-8), "Basis should be orthonormal"
82 |
83 | def test_energy_threshold(self):
84 | """Retained energy should meet threshold."""
85 | np.random.seed(42)
86 | snapshot = jnp.array(np.random.randn(50, 10))
87 |
88 | basis, cumul_energy, truncated_energy = compute_pod_basis_eig(snapshot, 0.95)
89 |
90 | assert truncated_energy >= 0.95, f"Truncated energy {truncated_energy} < 0.95"
91 |
92 | def test_basis_shape(self):
93 | """Basis should have correct shape."""
94 | np.random.seed(42)
95 | snapshot = jnp.array(np.random.randn(100, 10))
96 |
97 | basis, _, _ = compute_pod_basis_eig(snapshot, 0.99)
98 |
99 | assert basis.shape[0] == 100, "Basis should have n_samples rows"
100 | assert basis.shape[1] <= 10, "Basis should have at most n_snapshots columns"
101 | assert basis.shape[1] >= 1, "Basis should have at least 1 column"
102 |
103 |
104 | class TestSvdEigEquivalence:
105 | """Test that SVD and eigendecomposition produce equivalent results."""
106 |
107 | def test_span_equivalence(self):
108 | """SVD and eig bases should span the same subspace."""
109 | np.random.seed(42)
110 | snapshot = jnp.array(np.random.randn(100, 10))
111 |
112 | basis_svd, _, _ = compute_pod_basis_svd(snapshot, 0.99)
113 | basis_eig, _, _ = compute_pod_basis_eig(snapshot, 0.99)
114 |
115 | # Same number of modes (may differ by 1 due to numerical differences)
116 | n_modes_svd = basis_svd.shape[1]
117 | n_modes_eig = basis_eig.shape[1]
118 | assert abs(n_modes_svd - n_modes_eig) <= 1, "Mode counts should be similar"
119 |
120 | # Project onto common subspace - projection matrices should be similar
121 | proj_svd = basis_svd @ basis_svd.T
122 | proj_eig = basis_eig @ basis_eig.T
123 |
124 | # They span similar subspaces if the projections are close
125 | # (accounting for potentially different number of modes)
126 | min_modes = min(n_modes_svd, n_modes_eig)
127 | basis_svd_trunc = basis_svd[:, :min_modes]
128 | basis_eig_trunc = basis_eig[:, :min_modes]
129 | proj_svd_trunc = basis_svd_trunc @ basis_svd_trunc.T
130 | proj_eig_trunc = basis_eig_trunc @ basis_eig_trunc.T
131 |
132 | assert jnp.allclose(proj_svd_trunc, proj_eig_trunc, atol=1e-6), "Projections should be similar"
133 |
134 | def test_energy_equivalence(self):
135 | """SVD and eig should both meet energy threshold."""
136 | np.random.seed(42)
137 | snapshot = jnp.array(np.random.randn(100, 10))
138 |
139 | _, _, energy_svd = compute_pod_basis_svd(snapshot, 0.95)
140 | _, _, energy_eig = compute_pod_basis_eig(snapshot, 0.95)
141 |
142 | # Both should meet threshold (may differ due to discrete truncation)
143 | assert energy_svd >= 0.95, f"SVD energy {energy_svd} below threshold"
144 | assert energy_eig >= 0.95, f"Eig energy {energy_eig} below threshold"
145 |
146 |
147 | class TestComputePodBasis:
148 | """Test dispatch function."""
149 |
150 | def test_dispatch_svd(self):
151 | """use_eig=False should use SVD."""
152 | np.random.seed(42)
153 | snapshot = jnp.array(np.random.randn(50, 10))
154 |
155 | basis, _, _ = compute_pod_basis(snapshot, 0.99, use_eig=False)
156 | basis_svd, _, _ = compute_pod_basis_svd(snapshot, 0.99)
157 |
158 | assert jnp.allclose(basis, basis_svd), "use_eig=False should match SVD"
159 |
160 | def test_dispatch_eig(self):
161 | """use_eig=True should use eigendecomposition."""
162 | np.random.seed(42)
163 | snapshot = jnp.array(np.random.randn(50, 10))
164 |
165 | basis, _, _ = compute_pod_basis(snapshot, 0.99, use_eig=True)
166 | basis_eig, _, _ = compute_pod_basis_eig(snapshot, 0.99)
167 |
168 | assert jnp.allclose(basis, basis_eig), "use_eig=True should match eig"
169 |
--------------------------------------------------------------------------------
/pod_rbf/core.py:
--------------------------------------------------------------------------------
1 | """
2 | Core POD-RBF training and inference functions.
3 |
4 | Pure functional interface for JAX autodiff compatibility.
5 | """
6 |
7 | import jax.numpy as jnp
8 | from jax import Array
9 |
10 | from .decomposition import compute_pod_basis
11 | from .rbf import (
12 | build_collocation_matrix,
13 | build_inference_matrix,
14 | build_polynomial_basis,
15 | solve_augmented_system_direct,
16 | solve_augmented_system_schur,
17 | )
18 | from .shape_optimization import find_optimal_shape_param
19 | from .types import ModelState, TrainConfig, TrainResult
20 |
21 |
22 | def _normalize_params(params: Array) -> Array:
23 | """Ensure params is 2D: (n_params, n_points)."""
24 | if params.ndim == 1:
25 | return params[None, :]
26 | return params
27 |
28 |
29 | def train(
30 | snapshot: Array,
31 | train_params: Array,
32 | config: TrainConfig = TrainConfig(),
33 | shape_factor: float | None = None,
34 | ) -> TrainResult:
35 | """
36 | Train POD-RBF model.
37 |
38 | Parameters
39 | ----------
40 | snapshot : Array
41 | Solution snapshots, shape (n_samples, n_snapshots).
42 | Each column is a snapshot at a different parameter value.
43 | train_params : Array
44 | Parameter values, shape (n_snapshots,) or (n_params, n_snapshots).
45 | config : TrainConfig
46 | Training configuration.
47 | shape_factor : float, optional
48 | RBF shape parameter. If None, automatically optimized.
49 |
50 | Returns
51 | -------
52 | TrainResult
53 | Training result containing model state and diagnostics.
54 | """
55 | train_params = _normalize_params(jnp.asarray(train_params))
56 | snapshot = jnp.asarray(snapshot)
57 | n_params, n_snapshots = train_params.shape
58 |
59 | assert snapshot.shape[1] == n_snapshots, (
60 | f"Mismatch: {snapshot.shape[1]} snapshots vs {n_snapshots} params"
61 | )
62 |
63 | # Compute parameter ranges for normalization
64 | params_range = jnp.ptp(train_params, axis=1)
65 |
66 | # Determine decomposition method based on memory
67 | memory_gb = snapshot.nbytes / 1e9
68 | use_eig = memory_gb >= config.mem_limit_gb
69 |
70 | # Find optimal shape factor if not provided
71 | if shape_factor is None:
72 | shape_factor = find_optimal_shape_param(
73 | train_params,
74 | params_range,
75 | kernel=config.kernel,
76 | kernel_order=config.kernel_order,
77 | cond_range=config.cond_range,
78 | max_iters=config.max_bisection_iters,
79 | c_low_init=config.c_low_init,
80 | c_high_init=config.c_high_init,
81 | c_high_step=config.c_high_step,
82 | c_high_search_iters=config.c_high_search_iters,
83 | )
84 |
85 | # Compute truncated POD basis
86 | basis, cumul_energy, truncated_energy = compute_pod_basis(
87 | snapshot, config.energy_threshold, use_eig=use_eig
88 | )
89 |
90 | # Build collocation matrix
91 | F = build_collocation_matrix(
92 | train_params,
93 | params_range,
94 | kernel=config.kernel,
95 | shape_factor=shape_factor,
96 | kernel_order=config.kernel_order,
97 | )
98 | A = basis.T @ snapshot # (n_modes, n_train)
99 |
100 | # Compute weights using Schur complement solver or fallback to pinv
101 | poly_degree = config.poly_degree
102 | if poly_degree > 0:
103 | P = build_polynomial_basis(train_params, params_range, poly_degree)
104 | # Use direct solver for PHS (F is not SPD), Schur for others (F is SPD)
105 | if config.kernel == "polyharmonic_spline":
106 | weights, poly_coeffs = solve_augmented_system_direct(F, P, A)
107 | else:
108 | weights, poly_coeffs = solve_augmented_system_schur(F, P, A)
109 | else:
110 | weights = A @ jnp.linalg.pinv(F.T)
111 | poly_coeffs = None
112 |
113 | state = ModelState(
114 | basis=basis,
115 | weights=weights,
116 | shape_factor=float(shape_factor) if shape_factor is not None else None,
117 | train_params=train_params,
118 | params_range=params_range,
119 | truncated_energy=float(truncated_energy),
120 | cumul_energy=cumul_energy,
121 | poly_coeffs=poly_coeffs,
122 | poly_degree=poly_degree,
123 | kernel=config.kernel,
124 | kernel_order=config.kernel_order,
125 | )
126 |
127 | return TrainResult(
128 | state=state,
129 | n_modes=basis.shape[1],
130 | used_eig_decomp=use_eig,
131 | )
132 |
133 |
134 | def _inference_impl(
135 | basis: Array,
136 | weights: Array,
137 | train_params: Array,
138 | params_range: Array,
139 | shape_factor: float | None,
140 | poly_coeffs: Array | None,
141 | poly_degree: int,
142 | inf_params: Array,
143 | kernel: str,
144 | kernel_order: int,
145 | ) -> Array:
146 | """Core inference implementation for JIT compilation."""
147 | F = build_inference_matrix(
148 | train_params,
149 | inf_params,
150 | params_range,
151 | kernel=kernel,
152 | shape_factor=shape_factor,
153 | kernel_order=kernel_order,
154 | )
155 |
156 | # RBF contribution
157 | A = weights @ F.T # (n_modes, n_inf)
158 |
159 | # Add polynomial contribution if used
160 | if poly_coeffs is not None:
161 | P_inf = build_polynomial_basis(inf_params, params_range, poly_degree)
162 | A = A + poly_coeffs @ P_inf.T # (n_modes, n_inf)
163 |
164 | return basis @ A
165 |
166 |
167 | def inference(state: ModelState, inf_params: Array) -> Array:
168 | """
169 |     Evaluate the trained model at multiple parameter points.
170 |
171 | Parameters
172 | ----------
173 | state : ModelState
174 | Trained model state from train().
175 | inf_params : Array
176 | Inference parameters, shape (n_params, n_points) or (n_points,).
177 |
178 | Returns
179 | -------
180 | Array
181 | Predicted solutions, shape (n_samples, n_points).
182 | """
183 | inf_params = _normalize_params(jnp.asarray(inf_params))
184 |
185 | # Extract poly_degree as Python int before JIT tracing
186 | poly_degree = int(state.poly_degree) if state.poly_coeffs is not None else 0
187 |
188 | return _inference_impl(
189 | state.basis,
190 | state.weights,
191 | state.train_params,
192 | state.params_range,
193 | state.shape_factor,
194 | state.poly_coeffs,
195 | poly_degree,
196 | inf_params,
197 | kernel=state.kernel,
198 | kernel_order=state.kernel_order,
199 | )
200 |
201 |
202 | def inference_single(state: ModelState, inf_param: Array) -> Array:
203 | """
204 |     Evaluate the trained model at a single parameter point.
205 |
206 | More convenient for gradient computation.
207 |
208 | Parameters
209 | ----------
210 | state : ModelState
211 | Trained model state from train().
212 | inf_param : Array
213 | Single inference parameter, scalar or shape (n_params,).
214 |
215 | Returns
216 | -------
217 | Array
218 | Predicted solution, shape (n_samples,).
219 | """
220 | inf_param = jnp.asarray(inf_param)
221 |
222 | # Handle scalar input
223 | if inf_param.ndim == 0:
224 | inf_param = inf_param[None]
225 |
226 | # Shape to (n_params, 1) for inference
227 | inf_params = inf_param[:, None]
228 |
229 | result = inference(state, inf_params)
230 | return result[:, 0]
231 |
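232 | # Usage sketch (illustrative; the snapshot data is a synthetic placeholder):
233 | # train on a two-parameter grid and evaluate at an unseen point.
234 | #
235 | #     import jax.numpy as jnp
236 | #     p1, p2 = jnp.meshgrid(jnp.linspace(1.0, 5.0, 5), jnp.linspace(1.0, 4.0, 4))
237 | #     params = jnp.stack([p1.ravel(), p2.ravel()])                    # (2, 20)
238 | #     snapshot = jnp.linspace(0.0, 1.0, 50)[:, None] * params[0] + params[1]  # (50, 20)
239 | #     result = train(snapshot, params, TrainConfig(energy_threshold=0.99))
240 | #     sol = inference_single(result.state, jnp.array([2.5, 3.0]))     # (50,)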
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # POD-RBF
2 |
3 | [](https://github.com/kylebeggs/POD-RBF/actions/workflows/tests.yml)
4 | [](https://codecov.io/gh/kylebeggs/POD-RBF)
5 | [](https://www.python.org/downloads/)
6 |
7 | 
8 |
9 | A Python package for building a Reduced Order Model (ROM) from high-dimensional data using a Proper
10 | Orthogonal Decomposition - Radial Basis Function (POD-RBF) Network.
11 |
12 | Given a 'snapshot' matrix of the data points with varying parameters, this code contains functions
13 | to find the truncated POD basis and interpolate using a RBF network for new parameters.
14 |
15 | This code is based on the following papers implementing the method:
16 |
17 | 1. [Solving inverse heat conduction problems using trained POD-RBF network inverse method - Ostrowski, Bialecki, Kassab (2008)](https://www.tandfonline.com/doi/full/10.1080/17415970701198290)
18 | 2. [RBF-trained POD-accelerated CFD analysis of wind loads on PV systems - Huayamave, Ceballos, Barriento, Seigneur, Barkaszi, Divo, and Kassab (2017)](https://www.emerald.com/insight/content/doi/10.1108/HFF-03-2016-0083/full/html)
19 | 3. [Real-Time Thermomechanical Modeling of PV Cell Fabrication via a POD-Trained RBF Interpolation Network - Das, Khoury, Divo, Huayamave, Ceballos, Eaglin, Kassab, Payne, Yelundur, and Seigneur (2020)](https://www.techscience.com/CMES/v122n3/38374)
20 |
21 | Features:
22 |
23 | * **JAX-based** - enables autodifferentiation for gradient optimization, sensitivity analysis, and inverse problems
24 | * Shape parameter optimization for the Radial Basis Functions (RBFs)
25 | * Algorithm switching based on memory requirements (eigenvalue decomposition vs. SVD)
26 |
27 | ## Installation
28 |
29 | ```bash
30 | pip install pod-rbf
31 | # or
32 | uv add pod-rbf
33 | ```
34 |
35 | ## Example
36 |
37 | In the [example](https://github.com/kylebeggs/POD-RBF/tree/master/examples) folder you can find two
38 | examples that demonstrate how to use the package. The first is a simple heat conduction problem
39 | on a unit square.
40 |
41 | The other example will be demonstrated step-by-step here. We seek to build a ROM of the 2D
42 | lid-driven cavity problem. For the impatient, here is the full code to run this example. I will
43 | break down each line in the sections below.
44 |
45 | If you wish to build a ROM with multiple parameters, see this basic [2-parameter example](https://github.com/kylebeggs/POD-RBF/tree/master/examples/2-parameters.ipynb).
46 |
47 | ```python
48 | import pod_rbf
49 | import jax.numpy as jnp
50 | import numpy as np
51 |
52 | Re = np.linspace(0, 1000, num=11)
53 | Re[0] = 1
54 |
55 | # make snapshot matrix from csv files
56 | train_snapshot = pod_rbf.build_snapshot_matrix("examples/lid-driven-cavity/data/train/")
57 |
58 | # train the model (keeps 99% energy in POD modes by default)
59 | result = pod_rbf.train(train_snapshot, Re)
60 |
61 | # inference on an unseen parameter
62 | sol = pod_rbf.inference_single(result.state, jnp.array(450.0))
63 | ```
64 |
65 | ### Building the snapshot matrix
66 |
67 | First, we need to build the snapshot matrix, X, which contains the data we are training on. It must be arranged so that each column is the k-th 'snapshot' of the solution field given some
68 | parameter, p_k, with n samples in the snapshot at locations x_n. A single snapshot is shown below
69 |
70 | 
71 |
72 | and the snapshot matrix would then look like
73 |
74 | 
75 |
76 | where m is the total number of snapshots.
77 |
78 | For example, suppose our lid-driven cavity was solved on a mesh with 400 cells and we varied the
79 | parameter of interest (Re number in this case) 10 times. We would have a matrix of size (n,m) =
80 | (400,10).
81 |
82 | For our example, solutions were generated using STAR-CCM+ for Reynolds numbers of 1-1000 in
83 | increments of 100. The Re number will serve as our single parameter in this case. Snapshots were
84 | generated as a separate .csv file for each. To make it easier to combine them all into the snapshot
85 | matrix, there is a function which takes the path to a directory and loads every .csv file it finds
86 | there, in alphanumeric order. The files for this example follow the pattern re-%04d.csv
87 | (re-0001.csv, re-0100.csv, ...), so they load in order of increasing Reynolds number, and we issue
88 | a command as
89 |
90 | ```python
91 | >>> import pod_rbf
92 | >>> train_snapshot = pod_rbf.build_snapshot_matrix("examples/lid-driven-cavity/data/train/")
93 | ```
94 |
95 | ---
96 | Note: if you are using this approach where each snapshot is contained in a different csv file,
97 | please group all of them into a directory of their own.
98 |
99 | ---
100 |
101 | If you notice, these files are contained in the train folder, as I also generated some more
102 | snapshots for validation (which, as you probably guessed, are in the /data/validation folder). Now we
103 | need to generate the array of input parameters that correspond to each snapshot.
104 |
105 | ```python
106 | >>> Re = np.linspace(0, 1000, num=11)
107 | >>> Re[0] = 1
108 | ```
109 |
110 | ---
111 | Note: it is extremely important that each input parameter maps to the same element number of the
112 | snapshot matrix. For example, if a snapshot is in the 5th column (index 4), then the input parameter used to generate
113 | that snapshot should be what you find in the 5th element (index 4) of the array, e.g.
114 | ```train_snapshot[:,4] -> Re[4]```. The csv files are loaded in alpha-numeric order so that is why
115 | the input parameter array goes from 1 -> 1000.
116 |
117 | ---
118 |
119 | Here, ```Re``` is the array of input parameters that we are training the model on. Next, we train
120 | the model with a single function call. We choose to keep 99% of the energy in POD modes (this is
121 | the default, so you don't have to set that).
122 |
123 | ```python
124 | >>> result = pod_rbf.train(train_snapshot, Re)
125 | >>> # Or with custom config:
126 | >>> config = pod_rbf.TrainConfig(energy_threshold=0.99)
127 | >>> result = pod_rbf.train(train_snapshot, Re, config)
128 | ```
129 |
130 | Now that the weights and truncated POD basis have been calculated and stored in `result.state`, we
131 | can run inference on the model using any input parameter.
132 |
133 | ```python
134 | >>> import jax.numpy as jnp
135 | >>> sol = pod_rbf.inference_single(result.state, jnp.array(450.0))
136 | ```
137 |
138 | and we can plot the results below, comparing the prediction against the target
139 |
140 | 
141 |
142 | and for a Reynolds number of 50:
143 |
144 | 
145 |
146 |
147 | ### Saving and loading models
148 | You can save and load the trained model state:
149 |
150 | ```python
151 | >>> pod_rbf.save_model("model.pkl", result.state)
152 | >>> state = pod_rbf.load_model("model.pkl")
153 | >>> sol = pod_rbf.inference_single(state, jnp.array(450.0))
154 | ```
155 |
156 | ### Autodifferentiation
157 |
158 | Since POD-RBF is built on JAX, you can compute gradients for optimization and inverse problems:
159 |
160 | ```python
161 | >>> import jax
162 | >>> grad_fn = jax.grad(lambda p: jnp.sum(pod_rbf.inference_single(result.state, p)**2))
163 | >>> gradient = grad_fn(jnp.array(450.0))
164 | ```
165 |
--------------------------------------------------------------------------------
/examples/heat-conduction/design-optimization.py:
--------------------------------------------------------------------------------
1 | """
2 | Design Optimization with Autodiff
3 |
4 | Demonstrates using JAX autodifferentiation through POD-RBF for design optimization.
5 |
6 | Problem: 1D nonlinear heat conduction with temperature-dependent thermal conductivity
7 | k(T) = k₀(1 + βT)
8 |
9 | The nonlinear dependence on boundary temperature T_L creates a non-trivial optimization
10 | landscape requiring multiple POD modes to capture.
11 |
12 | Objective: Find boundary temperature T_L that achieves a target average temperature.
13 | """
14 |
15 | import time
16 | import numpy as np
17 | import matplotlib.pyplot as plt
18 |
19 | import jax
20 | import jax.numpy as jnp
21 | import pod_rbf
22 |
23 |
24 | jax.config.update("jax_default_device", jax.devices("cpu")[0])
25 |
26 | # Physical parameters
27 | T_0 = 300.0 # Left boundary temperature (K)
28 | BETA = 0.002 # Temperature coefficient for conductivity
29 | L = 1.0 # Domain length
30 |
31 |
32 | def analytical_solution(x, T_L):
33 | """
34 | Analytical solution for 1D steady heat conduction with k(T) = k₀(1 + βT).
35 |
36 | Uses Kirchhoff transformation: θ = T + (β/2)T²
37 | The transformed variable θ satisfies linear diffusion, giving θ(x) linear in x.
38 | Inverting the quadratic yields T(x).
39 | """
40 | theta_0 = T_0 + (BETA / 2) * T_0**2
41 | theta_L = T_L + (BETA / 2) * T_L**2
42 | theta = theta_0 + (theta_L - theta_0) * (x / L)
43 | return (-1 + np.sqrt(1 + 2 * BETA * theta)) / BETA
44 |
45 |
46 | def build_snapshot_matrix(T_L_values, x):
47 | """Build snapshot matrix from analytical solutions at different T_L values."""
48 | print("Building snapshot matrix... ", end="")
49 | start = time.time()
50 |
51 | n_points = len(x)
52 | n_snapshots = len(T_L_values)
53 | snapshot = np.zeros((n_points, n_snapshots))
54 |
55 | for i, T_L in enumerate(T_L_values):
56 | snapshot[:, i] = analytical_solution(x, T_L)
57 |
58 | print(f"took {time.time() - start:.3f} sec")
59 | return snapshot
60 |
61 |
62 | def run_optimization():
63 | # Spatial discretization
64 | n_points = 100
65 | x = np.linspace(0, L, n_points)
66 |
67 | # Training: sample T_L over range [350, 600] K
68 | T_L_train = np.linspace(350, 600, num=20)
69 | snapshot = build_snapshot_matrix(T_L_train, x)
70 |
71 | # Train ROM
72 | config = pod_rbf.TrainConfig(energy_threshold=0.9999, poly_degree=2)
73 | result = pod_rbf.train(snapshot, T_L_train, config)
74 | state = result.state
75 |
76 | print(f"Trained with {result.n_modes} modes, energy retained: {state.truncated_energy:.6f}")
77 | print(f"Cumulative energy per mode: {state.cumul_energy}")
78 |
79 | # Target: achieve average temperature of 400 K
80 | T_target = 400.0
81 |
82 | def objective(T_L):
83 | """Squared error between predicted average temp and target."""
84 | pred = pod_rbf.inference_single(state, T_L)
85 | avg_temp = jnp.mean(pred)
86 | return (avg_temp - T_target) ** 2
87 |
88 | # JIT-compile functions
89 | grad_fn = jax.jit(jax.grad(objective))
90 | obj_fn = jax.jit(objective)
91 |
92 | # Also track average temperature
93 | @jax.jit
94 | def avg_temp_fn(T_L):
95 | return jnp.mean(pod_rbf.inference_single(state, T_L))
96 |
97 | # Optimization settings
98 | T_L_init = jnp.array(550.0) # Start far from optimal
99 | T_L = T_L_init
100 | lr = 1.0 # Learning rate
101 | n_iters = 30
102 | T_L_min, T_L_max = 350.0, 600.0 # Valid parameter range
103 |
104 | # Track history
105 | history = {
106 | "T_L": [float(T_L)],
107 | "objective": [float(obj_fn(T_L))],
108 | "avg_temp": [float(avg_temp_fn(T_L))],
109 | "grad": [],
110 | }
111 |
112 | print(f"\nOptimizing: find T_L to achieve average temperature = {T_target} K")
113 | print(f"Initial: T_L = {T_L_init:.1f} K, avg_temp = {history['avg_temp'][0]:.2f} K")
114 |
115 | # Gradient descent
116 | for i in range(n_iters):
117 | grad = grad_fn(T_L)
118 | history["grad"].append(float(grad))
119 |
120 | T_L = T_L - lr * grad
121 | T_L = jnp.clip(T_L, T_L_min, T_L_max)
122 |
123 | obj_val = obj_fn(T_L)
124 | avg_temp = avg_temp_fn(T_L)
125 | history["T_L"].append(float(T_L))
126 | history["objective"].append(float(obj_val))
127 | history["avg_temp"].append(float(avg_temp))
128 |
129 | if (i + 1) % 5 == 0:
130 | print(
131 | f" Iter {i+1:3d}: T_L = {T_L:.2f} K, "
132 | f"avg_temp = {avg_temp:.2f} K, loss = {obj_val:.4e}"
133 | )
134 |
135 | T_L_opt = float(T_L)
136 | print(f"\nOptimal T_L = {T_L_opt:.2f} K (avg_temp = {history['avg_temp'][-1]:.2f} K)")
137 |
138 | # Get solutions for visualization
139 | pred_init = pod_rbf.inference_single(state, T_L_init)
140 | pred_opt = pod_rbf.inference_single(state, jnp.array(T_L_opt))
141 |
142 | # Analytical solutions for comparison
143 | T_analytical_init = analytical_solution(x, float(T_L_init))
144 | T_analytical_opt = analytical_solution(x, T_L_opt)
145 |
146 | return (
147 | history,
148 | x,
149 | pred_init,
150 | pred_opt,
151 | T_analytical_init,
152 | T_analytical_opt,
153 | T_L_init,
154 | T_L_opt,
155 | T_target,
156 | )
157 |
158 |
159 | def plot_results(
160 | history, x, pred_init, pred_opt, T_anal_init, T_anal_opt, T_L_init, T_L_opt, T_target
161 | ):
162 | fig, axes = plt.subplots(2, 2, figsize=(12, 10))
163 |
164 | # Objective convergence
165 | ax = axes[0, 0]
166 | ax.semilogy(history["objective"], "b-", linewidth=2)
167 | ax.set_xlabel("Iteration")
168 | ax.set_ylabel("Objective (squared error)")
169 | ax.set_title("Optimization Convergence")
170 | ax.grid(True, alpha=0.3)
171 |
172 | # Average temperature convergence
173 | ax = axes[0, 1]
174 | ax.plot(history["avg_temp"], "r-", linewidth=2, label="Predicted avg temp")
175 | ax.axhline(y=T_target, color="g", linestyle="--", linewidth=2, label=f"Target: {T_target} K")
176 | ax.set_xlabel("Iteration")
177 | ax.set_ylabel("Average Temperature (K)")
178 | ax.set_title("Average Temperature vs Target")
179 | ax.legend()
180 | ax.grid(True, alpha=0.3)
181 |
182 | # Initial temperature profile
183 | ax = axes[1, 0]
184 | ax.plot(x, pred_init, "b-", linewidth=2, label="ROM prediction")
185 | ax.plot(x, T_anal_init, "r--", linewidth=2, label="Analytical")
186 | ax.axhline(
187 | y=float(jnp.mean(pred_init)),
188 | color="g",
189 | linestyle=":",
190 | label=f"Avg: {float(jnp.mean(pred_init)):.1f} K",
191 | )
192 | ax.set_xlabel("x")
193 | ax.set_ylabel("Temperature (K)")
194 | ax.set_title(f"Initial: T_L = {float(T_L_init):.1f} K")
195 | ax.legend()
196 | ax.grid(True, alpha=0.3)
197 |
198 | # Optimized temperature profile
199 | ax = axes[1, 1]
200 | ax.plot(x, pred_opt, "b-", linewidth=2, label="ROM prediction")
201 | ax.plot(x, T_anal_opt, "r--", linewidth=2, label="Analytical")
202 | ax.axhline(
203 | y=float(jnp.mean(pred_opt)),
204 | color="g",
205 | linestyle=":",
206 | label=f"Avg: {float(jnp.mean(pred_opt)):.1f} K",
207 | )
208 | ax.axhline(y=T_target, color="orange", linestyle="--", alpha=0.7, label=f"Target: {T_target} K")
209 | ax.set_xlabel("x")
210 | ax.set_ylabel("Temperature (K)")
211 | ax.set_title(f"Optimized: T_L = {T_L_opt:.1f} K")
212 | ax.legend()
213 | ax.grid(True, alpha=0.3)
214 |
215 | plt.tight_layout()
216 | plt.show()
217 |
218 |
219 | if __name__ == "__main__":
220 | results = run_optimization()
221 | plot_results(*results)
222 |
--------------------------------------------------------------------------------
/tests/test_kernels.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for RBF kernel functions.
3 | """
4 |
5 | import jax.numpy as jnp
6 | import pytest
7 |
8 | from pod_rbf.kernels import (
9 | KernelType,
10 | apply_kernel,
11 | kernel_gaussian,
12 | kernel_imq,
13 | kernel_polyharmonic_spline,
14 | )
15 |
16 |
17 | class TestKernelIMQ:
18 | """Test Inverse Multiquadrics kernel."""
19 |
20 | def test_imq_at_zero(self):
21 | """IMQ kernel should be 1 at r=0."""
22 | r2 = jnp.array([0.0])
23 | result = kernel_imq(r2, shape_factor=1.0)
24 | assert jnp.allclose(result, 1.0)
25 |
26 | def test_imq_bounded(self):
27 | """IMQ kernel values should be in (0, 1]."""
28 | r2 = jnp.array([0.0, 0.1, 1.0, 10.0, 100.0])
29 | result = kernel_imq(r2, shape_factor=1.0)
30 | assert jnp.all(result > 0)
31 | assert jnp.all(result <= 1.0)
32 |
33 | def test_imq_decay(self):
34 | """IMQ should decay monotonically with distance."""
35 | r2 = jnp.array([0.0, 1.0, 4.0, 9.0])
36 | result = kernel_imq(r2, shape_factor=1.0)
37 | # Check monotonic decrease
38 | assert jnp.all(result[:-1] >= result[1:])
39 |
40 | def test_imq_shape_factor_effect(self):
41 | """Larger shape factor should give slower decay."""
42 | r2 = jnp.array([4.0])
43 | val_small_c = kernel_imq(r2, shape_factor=0.5)
44 | val_large_c = kernel_imq(r2, shape_factor=2.0)
45 | # Larger c means slower decay, so value should be higher
46 | assert val_large_c > val_small_c
47 |
48 |
49 | class TestKernelGaussian:
50 | """Test Gaussian kernel."""
51 |
52 | def test_gaussian_at_zero(self):
53 | """Gaussian kernel should be 1 at r=0."""
54 | r2 = jnp.array([0.0])
55 | result = kernel_gaussian(r2, shape_factor=1.0)
56 | assert jnp.allclose(result, 1.0)
57 |
58 | def test_gaussian_bounded(self):
59 | """Gaussian kernel values should be in (0, 1]."""
60 | r2 = jnp.array([0.0, 0.1, 1.0, 10.0, 100.0])
61 | result = kernel_gaussian(r2, shape_factor=1.0)
62 | assert jnp.all(result > 0)
63 | assert jnp.all(result <= 1.0)
64 |
65 | def test_gaussian_decay(self):
66 | """Gaussian should decay monotonically with distance."""
67 | r2 = jnp.array([0.0, 1.0, 4.0, 9.0])
68 | result = kernel_gaussian(r2, shape_factor=1.0)
69 | # Check monotonic decrease
70 | assert jnp.all(result[:-1] >= result[1:])
71 |
72 | def test_gaussian_exponential_decay(self):
73 | """Gaussian should decay exponentially."""
74 | r2 = jnp.array([0.0, 1.0])
75 | c = 1.0
76 | result = kernel_gaussian(r2, shape_factor=c)
77 | # At r²=1, should be exp(-1/c²) = exp(-1)
78 | assert jnp.allclose(result[1], jnp.exp(-1.0))
79 |
80 | def test_gaussian_shape_factor_effect(self):
81 | """Larger shape factor should give slower decay."""
82 | r2 = jnp.array([4.0])
83 | val_small_c = kernel_gaussian(r2, shape_factor=0.5)
84 | val_large_c = kernel_gaussian(r2, shape_factor=2.0)
85 | # Larger c means slower decay, so value should be higher
86 | assert val_large_c > val_small_c
87 |
88 |
89 | class TestKernelPolyharmonicSpline:
90 | """Test Polyharmonic Spline kernels."""
91 |
92 | def test_phs_order_1(self):
93 | """PHS order 1: phi(r) = r."""
94 | r2 = jnp.array([0.0, 1.0, 4.0, 9.0])
95 | result = kernel_polyharmonic_spline(r2, order=1)
96 | expected = jnp.sqrt(r2)
97 | assert jnp.allclose(result, expected)
98 |
99 | def test_phs_order_3(self):
100 | """PHS order 3: phi(r) = r³."""
101 | r2 = jnp.array([0.0, 1.0, 4.0, 9.0])
102 | result = kernel_polyharmonic_spline(r2, order=3)
103 | r = jnp.sqrt(r2)
104 | expected = r**3
105 | assert jnp.allclose(result, expected)
106 |
107 | def test_phs_order_5(self):
108 | """PHS order 5: phi(r) = r⁵."""
109 | r2 = jnp.array([0.0, 1.0, 4.0, 9.0])
110 | result = kernel_polyharmonic_spline(r2, order=5)
111 | r = jnp.sqrt(r2)
112 | expected = r**5
113 | assert jnp.allclose(result, expected)
114 |
115 | def test_phs_order_2(self):
116 | """PHS order 2: phi(r) = r²*log(r)."""
117 | # Avoid r=0 for log
118 | r2 = jnp.array([1.0, 4.0, 9.0])
119 | result = kernel_polyharmonic_spline(r2, order=2)
120 | r = jnp.sqrt(r2)
121 | expected = (r**2) * jnp.log(r)
122 | assert jnp.allclose(result, expected)
123 |
124 | def test_phs_order_4(self):
125 | """PHS order 4: phi(r) = r⁴*log(r)."""
126 | # Avoid r=0 for log
127 | r2 = jnp.array([1.0, 4.0, 9.0])
128 | result = kernel_polyharmonic_spline(r2, order=4)
129 | r = jnp.sqrt(r2)
130 | expected = (r**4) * jnp.log(r)
131 | assert jnp.allclose(result, expected)
132 |
133 | def test_phs_at_zero_odd(self):
134 | """PHS odd orders should be 0 at r=0."""
135 | r2 = jnp.array([0.0])
136 | for order in [1, 3, 5]:
137 | result = kernel_polyharmonic_spline(r2, order=order)
138 | assert jnp.allclose(result, 0.0)
139 |
140 | def test_phs_at_zero_even(self):
141 | """PHS even orders should be 0 at r=0 (handled by jnp.where)."""
142 | r2 = jnp.array([0.0])
143 | for order in [2, 4]:
144 | result = kernel_polyharmonic_spline(r2, order=order)
145 | # Should be set to 0 by the where clause to avoid log(0)
146 | assert jnp.isfinite(result[0])
147 | assert jnp.allclose(result, 0.0)
148 |
149 | def test_phs_growth(self):
150 | """PHS should grow with distance (unlike IMQ/Gaussian)."""
151 | r2 = jnp.array([0.0, 1.0, 4.0])
152 | result = kernel_polyharmonic_spline(r2, order=3)
153 | # Should increase (not decay)
154 | assert result[0] <= result[1] <= result[2]
155 |
156 |
157 | class TestApplyKernel:
158 | """Test kernel dispatcher."""
159 |
160 | def test_apply_imq(self):
161 | """Dispatcher should correctly apply IMQ kernel."""
162 | r2 = jnp.array([0.0, 1.0, 4.0])
163 | result = apply_kernel(r2, "imq", shape_factor=1.0, kernel_order=3)
164 | expected = kernel_imq(r2, shape_factor=1.0)
165 | assert jnp.allclose(result, expected)
166 |
167 | def test_apply_gaussian(self):
168 | """Dispatcher should correctly apply Gaussian kernel."""
169 | r2 = jnp.array([0.0, 1.0, 4.0])
170 | result = apply_kernel(r2, "gaussian", shape_factor=1.0, kernel_order=3)
171 | expected = kernel_gaussian(r2, shape_factor=1.0)
172 | assert jnp.allclose(result, expected)
173 |
174 | def test_apply_phs(self):
175 | """Dispatcher should correctly apply PHS kernel."""
176 | r2 = jnp.array([0.0, 1.0, 4.0])
177 | result = apply_kernel(r2, "polyharmonic_spline", shape_factor=None, kernel_order=3)
178 | expected = kernel_polyharmonic_spline(r2, order=3)
179 | assert jnp.allclose(result, expected)
180 |
181 | def test_apply_invalid_kernel(self):
182 | """Dispatcher should raise error for invalid kernel."""
183 | r2 = jnp.array([1.0])
184 | with pytest.raises(ValueError, match="is not a valid KernelType"):
185 | apply_kernel(r2, "invalid_kernel", shape_factor=1.0, kernel_order=3)
186 |
187 |
188 | class TestKernelType:
189 | """Test KernelType enum."""
190 |
191 | def test_enum_values(self):
192 | """Test that enum values are correct."""
193 | assert KernelType.IMQ == "imq"
194 | assert KernelType.GAUSSIAN == "gaussian"
195 | assert KernelType.POLYHARMONIC_SPLINE == "polyharmonic_spline"
196 |
197 | def test_enum_from_string(self):
198 | """Test creating enum from string."""
199 | assert KernelType("imq") == KernelType.IMQ
200 | assert KernelType("gaussian") == KernelType.GAUSSIAN
201 | assert KernelType("polyharmonic_spline") == KernelType.POLYHARMONIC_SPLINE
202 |
--------------------------------------------------------------------------------
/pod_rbf/rbf.py:
--------------------------------------------------------------------------------
1 | """
2 | Radial Basis Function (RBF) kernel and matrix construction.
3 |
4 | Supports multiple kernel types:
5 | - Inverse Multi-Quadrics (IMQ): phi(r) = 1 / sqrt(r²/c² + 1)
6 | - Gaussian: phi(r) = exp(-r²/c²)
7 | - Polyharmonic Splines (PHS): phi(r) = r^k or r^k*log(r)
8 | """
9 |
10 | import jax
11 | import jax.numpy as jnp
12 | import jax.scipy.linalg as jla
13 | from jax import Array
14 |
15 | from .kernels import apply_kernel
16 |
17 |
18 | def build_collocation_matrix(
19 | train_params: Array,
20 | params_range: Array,
21 | kernel: str = "imq",
22 | shape_factor: float | None = None,
23 | kernel_order: int = 3,
24 | ) -> Array:
25 | """
26 | Build RBF collocation matrix for training.
27 |
28 | Parameters
29 | ----------
30 | train_params : Array
31 | Training parameters, shape (n_params, n_train_points).
32 | params_range : Array
33 | Range of each parameter for normalization, shape (n_params,).
34 | kernel : str, optional
35 | Kernel type: 'imq', 'gaussian', or 'polyharmonic_spline'.
36 | Default is 'imq'.
37 | shape_factor : float | None, optional
38 | RBF shape parameter c. Required for IMQ and Gaussian kernels.
39 | Ignored for polyharmonic splines.
40 | kernel_order : int, optional
41 | Order for polyharmonic splines (default 3).
42 | Ignored for other kernels.
43 |
44 | Returns
45 | -------
46 | Array
47 | Collocation matrix, shape (n_train_points, n_train_points).
48 | """
49 | n_params, n_train = train_params.shape
50 |
51 | def accumulate_r2(i: int, r2: Array) -> Array:
52 | param_row = train_params[i, :]
53 | diff = param_row[:, None] - param_row[None, :] # (n_train, n_train)
54 | return r2 + (diff / params_range[i]) ** 2
55 |
56 | r2 = jax.lax.fori_loop(0, n_params, accumulate_r2, jnp.zeros((n_train, n_train)))
57 |
58 | return apply_kernel(r2, kernel, shape_factor, kernel_order)
59 |
60 |
61 | def build_inference_matrix(
62 | train_params: Array,
63 | inf_params: Array,
64 | params_range: Array,
65 | kernel: str = "imq",
66 | shape_factor: float | None = None,
67 | kernel_order: int = 3,
68 | ) -> Array:
69 | """
70 | Build RBF inference matrix for prediction at new parameters.
71 |
72 | Parameters
73 | ----------
74 | train_params : Array
75 | Training parameters, shape (n_params, n_train_points).
76 | inf_params : Array
77 | Inference parameters, shape (n_params, n_inf_points).
78 | params_range : Array
79 | Range of each parameter for normalization, shape (n_params,).
80 | kernel : str, optional
81 | Kernel type: 'imq', 'gaussian', or 'polyharmonic_spline'.
82 | Default is 'imq'.
83 | shape_factor : float | None, optional
84 | RBF shape parameter c. Required for IMQ and Gaussian kernels.
85 | Ignored for polyharmonic splines.
86 | kernel_order : int, optional
87 | Order for polyharmonic splines (default 3).
88 | Ignored for other kernels.
89 |
90 | Returns
91 | -------
92 | Array
93 | Inference matrix, shape (n_inf_points, n_train_points).
94 | """
95 | n_params = train_params.shape[0]
96 | n_inf = inf_params.shape[1]
97 | n_train = train_params.shape[1]
98 |
99 | def accumulate_r2(i: int, r2: Array) -> Array:
100 | diff = inf_params[i, :, None] - train_params[i, None, :] # (n_inf, n_train)
101 | return r2 + (diff / params_range[i]) ** 2
102 |
103 | r2 = jax.lax.fori_loop(0, n_params, accumulate_r2, jnp.zeros((n_inf, n_train)))
104 |
105 | return apply_kernel(r2, kernel, shape_factor, kernel_order)
106 |
107 |
108 | def build_polynomial_basis(
109 | params: Array,
110 | params_range: Array,
111 | degree: int = 2,
112 | ) -> Array:
113 | """
114 | Build polynomial basis matrix for RBF augmentation.
115 |
116 | Parameters
117 | ----------
118 | params : Array
119 | Parameters, shape (n_params, n_points).
120 | params_range : Array
121 | Range of each parameter for normalization, shape (n_params,).
122 | degree : int
123 | Polynomial degree (0=constant, 1=linear, 2=quadratic).
124 | Note: This should be a Python int, not a traced value.
125 |
126 | Returns
127 | -------
128 | Array
129 | Polynomial basis matrix, shape (n_points, n_poly).
130 | - degree 0: [1] -> 1 column
131 | - degree 1: [1, p1, p2, ...] -> n_params + 1 columns
132 | - degree 2: [1, p1, ..., pn, p1², ..., pn², p1*p2, ...] -> (n+1)(n+2)/2 cols
133 | """
134 | n_params, n_points = params.shape
135 |
136 | # Normalize parameters by dividing by range (consistent between train/inference)
137 | p_norm = params / params_range[:, None]
138 |
139 | # Build polynomial terms - degree must be a Python int for JIT compatibility
140 | terms = [jnp.ones((n_points,))] # constant term
141 |
142 | if degree >= 1:
143 | # Linear terms
144 | for i in range(n_params):
145 | terms.append(p_norm[i, :])
146 |
147 | if degree >= 2:
148 | # Squared terms
149 | for i in range(n_params):
150 | terms.append(p_norm[i, :] ** 2)
151 | # Cross terms
152 | for i in range(n_params):
153 | for j in range(i + 1, n_params):
154 | terms.append(p_norm[i, :] * p_norm[j, :])
155 |
156 | return jnp.stack(terms, axis=1)
157 |
158 |
159 | def solve_augmented_system_schur(
160 | F: Array,
161 | P: Array,
162 | rhs: Array,
163 | ) -> tuple[Array, Array]:
164 | """
165 | Solve augmented RBF system via Schur complement.
166 |
167 | Solves the saddle-point system:
168 |         [F    P] [λ]   [rhs]
169 |         [P.T  0] [c] = [ 0 ]
170 |
171 | Using Schur complement: S = P.T @ F^{-1} @ P
172 |
173 | Parameters
174 | ----------
175 | F : Array
176 | RBF collocation matrix, shape (n_train, n_train). Symmetric positive definite.
177 | P : Array
178 | Polynomial basis matrix, shape (n_train, n_poly).
179 | rhs : Array
180 | Right-hand side, shape (n_rhs, n_train). Each row is a separate RHS.
181 |
182 | Returns
183 | -------
184 | tuple[Array, Array]
185 | rbf_weights : shape (n_rhs, n_train)
186 | poly_coeffs : shape (n_rhs, n_poly)
187 | """
188 | # Cholesky factorization of F (symmetric positive definite)
189 | cho_F = jla.cho_factor(F)
190 |
191 | # Solve F @ X = P for X = F^{-1} @ P
192 | F_inv_P = jla.cho_solve(cho_F, P) # (n_train, n_poly)
193 |
194 | # Schur complement: S = P.T @ F^{-1} @ P
195 | S = P.T @ F_inv_P # (n_poly, n_poly)
196 |
197 | # Solve F @ Y = rhs.T for Y = F^{-1} @ rhs.T
198 | F_inv_rhs = jla.cho_solve(cho_F, rhs.T) # (n_train, n_rhs)
199 |
200 | # Solve S @ c.T = P.T @ F^{-1} @ rhs.T for polynomial coefficients
201 | schur_rhs = P.T @ F_inv_rhs # (n_poly, n_rhs)
202 | poly_coeffs = jnp.linalg.solve(S, schur_rhs).T # (n_rhs, n_poly)
203 |
204 | # Back-substitute: λ = F^{-1} @ (rhs - P @ c)
205 | rbf_weights = (F_inv_rhs - F_inv_P @ poly_coeffs.T).T # (n_rhs, n_train)
206 |
207 | return rbf_weights, poly_coeffs
208 |
209 |
210 | def solve_augmented_system_direct(
211 | F: Array,
212 | P: Array,
213 | rhs: Array,
214 | ) -> tuple[Array, Array]:
215 | """
216 | Solve augmented RBF system by direct assembly and solve.
217 |
218 | Solves the saddle-point system:
219 |         [F    P ] [λ]   [rhs]
220 |         [P.T  0 ] [c] = [ 0 ]
221 |
222 | This method assembles and solves the full system directly, which works
223 | for kernels where F is not positive definite (e.g., polyharmonic splines).
224 |
225 | Parameters
226 | ----------
227 | F : Array
228 | RBF collocation matrix, shape (n_train, n_train).
229 | P : Array
230 | Polynomial basis matrix, shape (n_train, n_poly).
231 | rhs : Array
232 | Right-hand side, shape (n_rhs, n_train). Each row is a separate RHS.
233 |
234 | Returns
235 | -------
236 | tuple[Array, Array]
237 | rbf_weights : shape (n_rhs, n_train)
238 | poly_coeffs : shape (n_rhs, n_poly)
239 | """
240 | n_train = F.shape[0]
241 | n_poly = P.shape[1]
242 | n_rhs = rhs.shape[0]
243 |
244 | # Assemble full augmented system matrix
245 |     # [F    P ]
246 |     # [P.T  0 ]
247 | top = jnp.hstack([F, P])
248 | bottom = jnp.hstack([P.T, jnp.zeros((n_poly, n_poly))])
249 | A_aug = jnp.vstack([top, bottom])
250 |
251 | # Assemble augmented RHS
252 | # [rhs]
253 | # [0 ]
254 | rhs_aug = jnp.hstack([rhs, jnp.zeros((n_rhs, n_poly))])
255 |
256 | # Solve the full system
257 | solution = jnp.linalg.solve(A_aug, rhs_aug.T).T # (n_rhs, n_train + n_poly)
258 |
259 | # Extract weights and polynomial coefficients
260 | rbf_weights = solution[:, :n_train]
261 | poly_coeffs = solution[:, n_train:]
262 |
263 | return rbf_weights, poly_coeffs
264 |
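265 | 
266 | if __name__ == "__main__":
267 |     # Minimal smoke-test sketch (not part of the library API): build a small
268 |     # 1-D augmented system, then check that the Schur and direct solvers
269 |     # agree and that a linear RHS is absorbed entirely by the polynomial
270 |     # block. The grid, shape factor, and target below are illustrative.
271 |     pts = jnp.linspace(0.0, 1.0, 6)[None, :]  # (n_params=1, n_train=6)
272 |     rng = jnp.array([1.0])  # parameter range for normalization
273 |     F = build_collocation_matrix(pts, rng, shape_factor=0.5)
274 |     P = build_polynomial_basis(pts, rng, degree=1)  # columns [1, p]
275 |     rhs = 2.0 + 3.0 * pts  # f(p) = 2 + 3p, shape (1, 6)
276 | 
277 |     w_s, c_s = solve_augmented_system_schur(F, P, rhs)
278 |     w_d, c_d = solve_augmented_system_direct(F, P, rhs)
279 | 
280 |     # Expect poly coeffs near [2, 3] and RBF weights near 0 (linear target).
281 |     print("schur poly coeffs: ", c_s)
282 |     print("direct poly coeffs:", c_d)
283 |     print("solvers agree:", bool(jnp.allclose(w_s, w_d, atol=1e-5)))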
--------------------------------------------------------------------------------
/tests/test_rbf.py:
--------------------------------------------------------------------------------
1 | """Tests for RBF matrix construction."""
2 |
3 | import jax.numpy as jnp
4 | import pytest
5 |
6 | from pod_rbf.rbf import (
7 | build_collocation_matrix,
8 | build_inference_matrix,
9 | build_polynomial_basis,
10 | solve_augmented_system_schur,
11 | )
12 |
13 |
14 | class TestBuildCollocationMatrix:
15 | """Test RBF collocation matrix construction."""
16 |
17 | def test_symmetry(self):
18 | """Collocation matrix should be symmetric."""
19 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]])
20 | params_range = jnp.array([4.0])
21 | C = build_collocation_matrix(params, params_range, shape_factor=1.0)
22 |
23 | assert jnp.allclose(C, C.T), "Collocation matrix should be symmetric"
24 |
25 | def test_diagonal_ones(self):
26 | """Diagonal should be 1 (r=0 -> phi=1)."""
27 | params = jnp.array([[1.0, 2.0, 3.0]])
28 | params_range = jnp.array([2.0])
29 | C = build_collocation_matrix(params, params_range, shape_factor=1.0)
30 |
31 | assert jnp.allclose(jnp.diag(C), 1.0), "Diagonal elements should be 1"
32 |
33 | def test_shape(self):
34 | """Output shape should be (n_train, n_train)."""
35 | n_train = 7
36 | params = jnp.linspace(0, 10, n_train)[None, :]
37 | params_range = jnp.array([10.0])
38 | C = build_collocation_matrix(params, params_range, shape_factor=0.5)
39 |
40 | assert C.shape == (n_train, n_train), f"Expected ({n_train}, {n_train}), got {C.shape}"
41 |
42 | def test_multi_param(self):
43 | """Should work with multiple parameters."""
44 | n_train = 5
45 | 
46 | params = jnp.array([
47 | [1.0, 2.0, 3.0, 4.0, 5.0],
48 | [0.1, 0.2, 0.3, 0.4, 0.5],
49 | [10.0, 20.0, 30.0, 40.0, 50.0],
50 | ])
51 | params_range = jnp.array([4.0, 0.4, 40.0])
52 | C = build_collocation_matrix(params, params_range, shape_factor=1.0)
53 |
54 | assert C.shape == (n_train, n_train)
55 | assert jnp.allclose(C, C.T), "Multi-param collocation should be symmetric"
56 | assert jnp.allclose(jnp.diag(C), 1.0), "Diagonal should still be 1"
57 |
58 | def test_values_positive(self):
59 | """All values should be positive (IMQ kernel is always positive)."""
60 | params = jnp.array([[1.0, 5.0, 10.0, 20.0]])
61 | params_range = jnp.array([19.0])
62 | C = build_collocation_matrix(params, params_range, shape_factor=0.5)
63 |
64 | assert jnp.all(C > 0), "All values should be positive"
65 |
66 | def test_values_bounded(self):
67 | """Values should be in (0, 1] for IMQ kernel."""
68 | params = jnp.array([[1.0, 5.0, 10.0, 20.0]])
69 | params_range = jnp.array([19.0])
70 | C = build_collocation_matrix(params, params_range, shape_factor=0.5)
71 |
72 | assert jnp.all(C > 0), "All values should be > 0"
73 | assert jnp.all(C <= 1.0), "All values should be <= 1"
74 |
75 |
76 | class TestBuildInferenceMatrix:
77 | """Test RBF inference matrix construction."""
78 |
79 | def test_shape(self):
80 | """Output shape should be (n_inf, n_train)."""
81 | n_train = 5
82 | n_inf = 3
83 | train_params = jnp.linspace(0, 10, n_train)[None, :]
84 | inf_params = jnp.array([[2.5, 5.0, 7.5]])
85 | params_range = jnp.array([10.0])
86 |
87 | F = build_inference_matrix(train_params, inf_params, params_range, shape_factor=1.0)
88 |
89 | assert F.shape == (n_inf, n_train), f"Expected ({n_inf}, {n_train}), got {F.shape}"
90 |
91 | def test_at_training_points(self):
92 | """Inference at training points should have max value 1 at that column."""
93 | train_params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]])
94 | inf_params = jnp.array([[3.0]]) # Exactly at training point
95 | params_range = jnp.array([4.0])
96 |
97 | F = build_inference_matrix(train_params, inf_params, params_range, shape_factor=1.0)
98 |
99 | # At training point index 2, value should be 1
100 | assert jnp.isclose(F[0, 2], 1.0), f"Expected 1.0 at training point, got {F[0, 2]}"
101 |
102 | def test_values_positive(self):
103 | """All values should be positive."""
104 | train_params = jnp.array([[1.0, 5.0, 10.0]])
105 | inf_params = jnp.array([[2.0, 7.0]])
106 | params_range = jnp.array([9.0])
107 |
108 | F = build_inference_matrix(train_params, inf_params, params_range, shape_factor=0.5)
109 |
110 | assert jnp.all(F > 0), "All values should be positive"
111 |
112 | def test_multi_param(self):
113 | """Should work with multiple parameters."""
114 | train_params = jnp.array([
115 | [1.0, 2.0, 3.0],
116 | [0.1, 0.2, 0.3],
117 | ])
118 | inf_params = jnp.array([
119 | [1.5, 2.5],
120 | [0.15, 0.25],
121 | ])
122 | params_range = jnp.array([2.0, 0.2])
123 |
124 | F = build_inference_matrix(train_params, inf_params, params_range, shape_factor=1.0)
125 |
126 | assert F.shape == (2, 3)
127 | assert jnp.all(F > 0)
128 |
129 |
130 | class TestBuildPolynomialBasis:
131 | """Test polynomial basis matrix construction."""
132 |
133 | def test_degree_0_shape(self):
134 | """Degree 0 should return constant column only."""
135 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]])
136 | params_range = jnp.array([4.0])
137 | P = build_polynomial_basis(params, params_range, degree=0)
138 |
139 | assert P.shape == (5, 1), f"Expected (5, 1), got {P.shape}"
140 | assert jnp.allclose(P[:, 0], 1.0), "Constant column should be all ones"
141 |
142 | def test_degree_1_shape(self):
143 | """Degree 1 should return constant + linear terms."""
144 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]])
145 | params_range = jnp.array([4.0])
146 | P = build_polynomial_basis(params, params_range, degree=1)
147 |
148 | # 1D: [1, p] -> 2 columns
149 | assert P.shape == (5, 2), f"Expected (5, 2), got {P.shape}"
150 | assert jnp.allclose(P[:, 0], 1.0), "First column should be all ones"
151 |
152 | def test_degree_2_shape_1d(self):
153 | """Degree 2 with 1 param should return 3 columns."""
154 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]])
155 | params_range = jnp.array([4.0])
156 | P = build_polynomial_basis(params, params_range, degree=2)
157 |
158 | # 1D: [1, p, p²] -> 3 columns
159 | assert P.shape == (5, 3), f"Expected (5, 3), got {P.shape}"
160 |
161 | def test_degree_2_shape_2d(self):
162 | """Degree 2 with 2 params should return 6 columns."""
163 | params = jnp.array([
164 | [1.0, 2.0, 3.0, 4.0, 5.0],
165 | [0.1, 0.2, 0.3, 0.4, 0.5],
166 | ])
167 | params_range = jnp.array([4.0, 0.4])
168 | P = build_polynomial_basis(params, params_range, degree=2)
169 |
170 | # 2D: [1, p1, p2, p1², p2², p1*p2] -> 6 columns
171 | assert P.shape == (5, 6), f"Expected (5, 6), got {P.shape}"
172 |
173 | def test_normalized_values(self):
174 | """Polynomial values should scale appropriately with params_range."""
175 | params = jnp.array([[10.0, 20.0, 30.0, 40.0, 50.0]])
176 | params_range = jnp.array([40.0])
177 | P = build_polynomial_basis(params, params_range, degree=2)
178 |
179 | # Normalized values are params/range, so for [10, 20, 30, 40, 50]/40 = [0.25, 0.5, 0.75, 1.0, 1.25]
180 | # Constant term should be 1
181 | assert jnp.allclose(P[:, 0], 1.0), "Constant column should be all ones"
182 | # Linear terms should be params/range
183 | expected_linear = jnp.array([0.25, 0.5, 0.75, 1.0, 1.25])
184 | assert jnp.allclose(P[:, 1], expected_linear), f"Linear term mismatch: {P[:, 1]} vs {expected_linear}"
185 |
186 | def test_multi_param_cross_terms(self):
187 | """Cross terms should be computed correctly for multi-param case."""
188 | params = jnp.array([
189 | [0.0, 2.0, 4.0],
190 | [0.0, 2.0, 4.0],
191 | ])
192 | params_range = jnp.array([4.0, 4.0])
193 | P = build_polynomial_basis(params, params_range, degree=2)
194 |
195 | # Columns: [1, p1, p2, p1², p2², p1*p2]
196 | # Normalized: [0, 0.5, 1] for both params
197 | assert P.shape == (3, 6)
198 | # Last column is cross term p1*p2
199 | # At normalized values [0, 0.5, 1], cross terms are [0, 0.25, 1]
200 | expected_cross = jnp.array([0.0, 0.25, 1.0])
201 | assert jnp.allclose(P[:, 5], expected_cross), f"Cross term mismatch: {P[:, 5]} vs {expected_cross}"
202 |
203 |
204 | class TestSolveAugmentedSystemSchur:
205 | """Test Schur complement solver."""
206 |
207 | def test_simple_system(self):
208 | """Solve a simple augmented system and verify solution."""
209 | # Simple 3x3 SPD matrix
210 | F = jnp.array([
211 | [4.0, 1.0, 0.5],
212 | [1.0, 3.0, 0.5],
213 | [0.5, 0.5, 2.0],
214 | ])
215 | # Linear polynomial basis
216 | P = jnp.array([
217 | [1.0, 0.0],
218 | [1.0, 0.5],
219 | [1.0, 1.0],
220 | ])
221 | # Single RHS
222 | rhs = jnp.array([[1.0, 2.0, 3.0]])
223 |
224 | rbf_weights, poly_coeffs = solve_augmented_system_schur(F, P, rhs)
225 |
226 | assert rbf_weights.shape == (1, 3), f"RBF weights shape: {rbf_weights.shape}"
227 | assert poly_coeffs.shape == (1, 2), f"Poly coeffs shape: {poly_coeffs.shape}"
228 |
229 | # Verify solution satisfies augmented system
230 | # F @ λ + P @ c = rhs
231 | lhs1 = F @ rbf_weights.T + P @ poly_coeffs.T
232 | assert jnp.allclose(lhs1.T, rhs, rtol=1e-5), f"First equation not satisfied: {lhs1.T} vs {rhs}"
233 |
234 | # P.T @ λ = 0 (orthogonality constraint)
235 | lhs2 = P.T @ rbf_weights.T
236 | assert jnp.allclose(lhs2, 0, atol=1e-10), f"Orthogonality constraint not satisfied: {lhs2}"
237 |
238 | def test_multiple_rhs(self):
239 | """Solve system with multiple right-hand sides."""
240 | F = jnp.array([
241 | [4.0, 1.0, 0.5],
242 | [1.0, 3.0, 0.5],
243 | [0.5, 0.5, 2.0],
244 | ])
245 | P = jnp.array([
246 | [1.0, 0.0],
247 | [1.0, 0.5],
248 | [1.0, 1.0],
249 | ])
250 | # Two RHS
251 | rhs = jnp.array([
252 | [1.0, 2.0, 3.0],
253 | [0.5, 1.0, 1.5],
254 | ])
255 |
256 | rbf_weights, poly_coeffs = solve_augmented_system_schur(F, P, rhs)
257 |
258 | assert rbf_weights.shape == (2, 3)
259 | assert poly_coeffs.shape == (2, 2)
260 |
261 | # Verify both solutions satisfy the system
262 | for i in range(2):
263 | lhs1 = F @ rbf_weights[i, :] + P @ poly_coeffs[i, :]
264 | assert jnp.allclose(lhs1, rhs[i, :], rtol=1e-5), f"RHS {i}: First equation not satisfied"
265 |
266 | lhs2 = P.T @ rbf_weights[i, :]
267 | assert jnp.allclose(lhs2, 0, atol=1e-10), f"RHS {i}: Orthogonality constraint not satisfied"
268 |
269 | def test_polynomial_reproduction(self):
270 | """Schur complement should reproduce polynomials exactly."""
271 | # If RHS is in the span of P, polynomial coeffs should capture it fully
272 | n_points = 5
273 | params = jnp.linspace(0, 1, n_points)[None, :]
274 | params_range = jnp.array([1.0])
275 |
276 | F = build_collocation_matrix(params, params_range, shape_factor=0.5)
277 | P = build_polynomial_basis(params, params_range, degree=1) # [1, p]
278 |
279 | # RHS that's a linear function: f = 2 + 3*p
280 | rhs = 2.0 + 3.0 * params # shape (1, n_points)
281 |
282 | rbf_weights, poly_coeffs = solve_augmented_system_schur(F, P, rhs)
283 |
284 | # RBF weights should be near zero (polynomial captures everything)
285 | assert jnp.allclose(rbf_weights, 0, atol=1e-8), f"RBF weights should be ~0 for polynomial RHS: {rbf_weights}"
286 |
287 | # Poly coeffs should be [2, 3]
288 | assert jnp.allclose(poly_coeffs[0, 0], 2.0, rtol=1e-5), f"Constant coeff: {poly_coeffs[0, 0]}"
289 | assert jnp.allclose(poly_coeffs[0, 1], 3.0, rtol=1e-5), f"Linear coeff: {poly_coeffs[0, 1]}"
290 |
291 |
292 | class TestCollocationMatrixKernels:
293 | """Test collocation matrix with different kernel types."""
294 |
295 | def test_imq_kernel(self):
296 | """Test collocation matrix with IMQ kernel."""
297 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]])
298 | params_range = jnp.array([4.0])
299 | C = build_collocation_matrix(
300 | params, params_range, kernel="imq", shape_factor=1.0, kernel_order=3
301 | )
302 |
303 | assert C.shape == (5, 5)
304 | assert jnp.allclose(C, C.T)
305 | assert jnp.allclose(jnp.diag(C), 1.0)
306 | assert jnp.all(C > 0) and jnp.all(C <= 1.0)
307 |
308 | def test_gaussian_kernel(self):
309 | """Test collocation matrix with Gaussian kernel."""
310 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]])
311 | params_range = jnp.array([4.0])
312 | C = build_collocation_matrix(
313 | params, params_range, kernel="gaussian", shape_factor=1.0, kernel_order=3
314 | )
315 |
316 | assert C.shape == (5, 5)
317 | assert jnp.allclose(C, C.T)
318 | assert jnp.allclose(jnp.diag(C), 1.0)
319 | assert jnp.all(C > 0) and jnp.all(C <= 1.0)
320 |
321 | def test_phs_kernel(self):
322 | """Test collocation matrix with polyharmonic spline kernel."""
323 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]])
324 | params_range = jnp.array([4.0])
325 | C = build_collocation_matrix(
326 | params,
327 | params_range,
328 | kernel="polyharmonic_spline",
329 | shape_factor=None,
330 | kernel_order=3,
331 | )
332 |
333 | assert C.shape == (5, 5)
334 | assert jnp.allclose(C, C.T)
335 | assert jnp.allclose(jnp.diag(C), 0.0) # PHS is 0 at r=0
336 |
337 | def test_kernels_produce_different_matrices(self):
338 | """Different kernels should produce different matrices."""
339 | params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]])
340 | params_range = jnp.array([4.0])
341 |
342 | C_imq = build_collocation_matrix(
343 | params, params_range, kernel="imq", shape_factor=1.0
344 | )
345 | C_gauss = build_collocation_matrix(
346 | params, params_range, kernel="gaussian", shape_factor=1.0
347 | )
348 | C_phs = build_collocation_matrix(
349 | params, params_range, kernel="polyharmonic_spline", kernel_order=3
350 | )
351 |
352 | # Matrices should be different
353 | assert not jnp.allclose(C_imq, C_gauss)
354 | assert not jnp.allclose(C_imq, C_phs)
355 | assert not jnp.allclose(C_gauss, C_phs)
356 |
357 |
358 | class TestInferenceMatrixKernels:
359 | """Test inference matrix with different kernel types."""
360 |
361 | def test_imq_kernel(self):
362 | """Test inference matrix with IMQ kernel."""
363 | train_params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]])
364 | inf_params = jnp.array([[2.5, 3.5]])
365 | params_range = jnp.array([4.0])
366 |
367 | F = build_inference_matrix(
368 | train_params,
369 | inf_params,
370 | params_range,
371 | kernel="imq",
372 | shape_factor=1.0,
373 | kernel_order=3,
374 | )
375 |
376 | assert F.shape == (2, 5)
377 | assert jnp.all(F > 0)
378 |
379 | def test_gaussian_kernel(self):
380 | """Test inference matrix with Gaussian kernel."""
381 | train_params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]])
382 | inf_params = jnp.array([[2.5, 3.5]])
383 | params_range = jnp.array([4.0])
384 |
385 | F = build_inference_matrix(
386 | train_params,
387 | inf_params,
388 | params_range,
389 | kernel="gaussian",
390 | shape_factor=1.0,
391 | kernel_order=3,
392 | )
393 |
394 | assert F.shape == (2, 5)
395 | assert jnp.all(F > 0)
396 |
397 | def test_phs_kernel(self):
398 | """Test inference matrix with polyharmonic spline kernel."""
399 | train_params = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0]])
400 | inf_params = jnp.array([[2.5, 3.5]])
401 | params_range = jnp.array([4.0])
402 |
403 | F = build_inference_matrix(
404 | train_params,
405 | inf_params,
406 | params_range,
407 | kernel="polyharmonic_spline",
408 | shape_factor=None,
409 | kernel_order=3,
410 | )
411 |
412 | assert F.shape == (2, 5)
413 |
414 | def test_at_training_point_all_kernels(self):
415 | """All kernels should have value 1 (IMQ/Gauss) or 0 (PHS) at r=0."""
416 | train_params = jnp.array([[1.0, 2.0, 3.0]])
417 | inf_params = jnp.array([[2.0]]) # Exactly at training point
418 | params_range = jnp.array([2.0])
419 |
420 | # IMQ
421 | F_imq = build_inference_matrix(
422 | train_params, inf_params, params_range, kernel="imq", shape_factor=1.0
423 | )
424 | assert jnp.isclose(F_imq[0, 1], 1.0)
425 |
426 | # Gaussian
427 | F_gauss = build_inference_matrix(
428 | train_params,
429 | inf_params,
430 | params_range,
431 | kernel="gaussian",
432 | shape_factor=1.0,
433 | )
434 | assert jnp.isclose(F_gauss[0, 1], 1.0)
435 |
436 | # PHS
437 | F_phs = build_inference_matrix(
438 | train_params,
439 | inf_params,
440 | params_range,
441 | kernel="polyharmonic_spline",
442 | kernel_order=3,
443 | )
444 | assert jnp.isclose(F_phs[0, 1], 0.0)
445 |
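446 | 
447 | class TestSolveAugmentedSystemDirect:
448 |     """Sketch tests for the direct augmented solver.
449 | 
450 |     A hedged addition: assumes solve_augmented_system_direct (defined
451 |     alongside the Schur variant in pod_rbf.rbf) shares its signature and
452 |     conventions, so both solvers should agree when F is SPD.
453 |     """
454 | 
455 |     def test_matches_schur_on_spd_system(self):
456 |         """Direct and Schur solvers should agree on an SPD system."""
457 |         from pod_rbf.rbf import solve_augmented_system_direct
458 | 
459 |         F = jnp.array([
460 |             [4.0, 1.0, 0.5],
461 |             [1.0, 3.0, 0.5],
462 |             [0.5, 0.5, 2.0],
463 |         ])
464 |         P = jnp.array([
465 |             [1.0, 0.0],
466 |             [1.0, 0.5],
467 |             [1.0, 1.0],
468 |         ])
469 |         rhs = jnp.array([[1.0, 2.0, 3.0]])
470 | 
471 |         w_d, c_d = solve_augmented_system_direct(F, P, rhs)
472 |         w_s, c_s = solve_augmented_system_schur(F, P, rhs)
473 | 
474 |         assert jnp.allclose(w_d, w_s, rtol=1e-6, atol=1e-8)
475 |         assert jnp.allclose(c_d, c_s, rtol=1e-6, atol=1e-8)
476 | 
477 |         # The direct solution must also satisfy both block equations.
478 |         assert jnp.allclose(F @ w_d.T + P @ c_d.T, rhs.T, rtol=1e-5)
479 |         assert jnp.allclose(P.T @ w_d.T, 0.0, atol=1e-8)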
--------------------------------------------------------------------------------
/examples/shape-optimization/optimize_shape.py:
--------------------------------------------------------------------------------
1 | """
2 | Shape Parameter Optimization via Autodiff
3 |
4 | Demonstrates JAX automatic differentiation with JAXopt L-BFGS to find
5 | an optimal RBF shape parameter via the Rippa criterion (LOO-CV).
6 |
7 | The Rippa criterion provides a closed-form leave-one-out cross-validation
8 | error that is differentiable with respect to the shape parameter, enabling
9 | gradient-based optimization.
10 |
11 | Key concepts demonstrated:
12 | 1. Rippa criterion (LOO-CV) for RBF shape parameter selection
13 | 2. JAX autodiff through RBF matrix construction
14 | 3. JAXopt L-BFGS optimizer for scientific computing
15 |     4. Gradient and Hessian computation of the LOO-CV cost via autodiff
16 |
17 | Requirements:
18 | pip install jaxopt
19 |
20 | References:
21 | Rippa, S. (1999). "An algorithm for selecting a good value for the
22 | parameter c in radial basis function interpolation."
23 | Advances in Computational Mathematics, 11(2-3), 193-210.
24 | """
25 |
26 | import time
27 |
28 | import jax
29 | import jax.numpy as jnp
30 | import jaxopt
31 | import matplotlib.pyplot as plt
32 | import numpy as np
33 |
34 | from pod_rbf.rbf import build_collocation_matrix
35 |
36 | # Use CPU for reproducibility
37 | jax.config.update("jax_default_device", jax.devices("cpu")[0])
38 |
39 |
40 | # =============================================================================
41 | # Test Function
42 | # =============================================================================
43 |
44 |
45 | def runge_function(x):
46 | """
47 | Runge function - a classic test case for interpolation.
48 |
49 | f(x) = 1 / (1 + 25x^2)
50 |
51 | This function has a sharp peak at x=0 and is challenging to interpolate
52 | accurately, making it ideal for demonstrating the importance of shape
53 | parameter selection.
54 | """
55 | return 1.0 / (1.0 + 25.0 * x**2)
56 |
57 |
58 | # =============================================================================
59 | # Rippa Criterion (LOO-CV Cost Function)
60 | # =============================================================================
61 |
62 |
63 | def loocv_cost(shape_factor, x, y, kernel="imq"):
64 | """
65 | Compute leave-one-out cross-validation error using Rippa's closed-form formula.
66 |
67 | The Rippa criterion computes the LOO-CV error without actually performing
68 | n separate leave-one-out fits. For RBF interpolation A @ c = y, the LOO
69 | error at point i is:
70 |
71 | e_i = c_i / A_inv_ii
72 |
73 | where c_i is the i-th interpolation coefficient and A_inv_ii is the i-th
74 | diagonal element of A^{-1}.
75 |
76 | Parameters
77 | ----------
78 | shape_factor : float
79 | RBF shape parameter to evaluate.
80 | x : Array
81 | Training point locations, shape (n_points,).
82 | y : Array
83 | Training values, shape (n_points,).
84 | kernel : str
85 | Kernel type: 'imq' or 'gaussian'.
86 |
87 | Returns
88 | -------
89 | float
90 | Mean squared LOO-CV error.
91 | """
92 | # Build collocation matrix with current shape factor
93 | # pod_rbf expects (n_params, n_points) shape
94 | x_2d = x[None, :]
95 | x_range = jnp.array([jnp.ptp(x)])
96 |
97 | A = build_collocation_matrix(x_2d, x_range, kernel=kernel, shape_factor=shape_factor)
98 |
99 | # Solve for RBF coefficients: A @ c = y
100 | c = jnp.linalg.solve(A, y)
101 |
102 | # Compute diagonal of A^{-1}
103 | # For numerical stability, we could use the formula:
104 |     #     diag(A^{-1})_i = u_i^T @ A^{-1} @ u_i  (u_i = i-th unit vector)
105 | # But direct inversion is fine for moderate problem sizes
106 | A_inv = jnp.linalg.inv(A)
107 | A_inv_diag = jnp.diag(A_inv)
108 |
109 | # Rippa criterion: LOO error at each point
110 | loo_errors = c / A_inv_diag
111 |
112 | # Return mean squared error
113 | return jnp.mean(loo_errors**2)
114 |
115 |
116 | # =============================================================================
117 | # Shape Parameter Optimization
118 | # =============================================================================
119 |
120 |
121 | def optimize_shape_parameter(x, y, kernel="imq", initial_guess=1.0, verbose=True):
122 | """
123 | Find optimal shape parameter using L-BFGS on Rippa criterion.
124 |
125 | Uses JAXopt's L-BFGS optimizer to minimize the LOO-CV error, with gradients
126 | computed automatically via JAX autodiff.
127 |
128 | We optimize in log-space (log_c) to ensure positivity without breaking gradients.
129 | shape_factor = exp(log_c)
130 |
131 | Parameters
132 | ----------
133 | x : Array
134 | Training point locations.
135 | y : Array
136 | Training values.
137 | kernel : str
138 | Kernel type.
139 | initial_guess : float
140 | Starting value for shape parameter.
141 | verbose : bool
142 | Print optimization progress.
143 |
144 | Returns
145 | -------
146 | optimal_shape : float
147 | Optimized shape parameter.
148 | history : dict
149 | Optimization history with shape parameters, costs, and gradients.
150 | """
151 | history = {"shape_factor": [initial_guess], "cost": [], "grad_log": []}
152 |
153 | # Optimize in log-space for unconstrained optimization with guaranteed positivity
154 | # shape_factor = exp(log_c), so we optimize log_c
155 | def objective(log_c):
156 | shape_factor = jnp.exp(log_c[0])
157 | return loocv_cost(shape_factor, x, y, kernel)
158 |
159 | # JIT compile for speed
160 | objective_jit = jax.jit(objective)
161 | grad_fn = jax.jit(jax.grad(objective))
162 |
163 | # Initial evaluation (in log space)
164 | log_c_init = jnp.log(initial_guess)
165 | init_cost = objective_jit(jnp.array([log_c_init]))
166 | init_grad = grad_fn(jnp.array([log_c_init]))
167 | history["cost"].append(float(init_cost))
168 | history["grad_log"].append(float(init_grad[0]))
169 |
170 | if verbose:
171 | print(f"Initial: shape_factor = {initial_guess:.6f}, LOO-CV cost = {init_cost:.6e}")
172 |
173 | # Create L-BFGS solver
174 | solver = jaxopt.LBFGS(
175 | fun=objective,
176 | maxiter=100,
177 | tol=1e-10,
178 | )
179 |
180 | # Initialize in log space
181 | log_c = jnp.array([log_c_init])
182 |
183 | # Run optimization with manual iteration to track history
184 | state = solver.init_state(log_c)
185 |
186 | for i in range(100):
187 | log_c, state = solver.update(log_c, state)
188 | cost = objective_jit(log_c)
189 | grad = grad_fn(log_c)
190 |
191 | shape_factor = float(jnp.exp(log_c[0]))
192 | history["shape_factor"].append(shape_factor)
193 | history["cost"].append(float(cost))
194 | history["grad_log"].append(float(grad[0]))
195 |
196 | if verbose and (i + 1) % 10 == 0:
197 | print(
198 | f" Iter {i+1:3d}: shape_factor = {shape_factor:.6f}, "
199 | f"cost = {cost:.6e}, |grad| = {jnp.abs(grad[0]):.2e}"
200 | )
201 |
202 | # Check convergence (gradient in log space)
203 | if jnp.abs(grad[0]) < 1e-10:
204 | if verbose:
205 | print(f" Converged at iteration {i+1}")
206 | break
207 |
208 | optimal_shape = float(jnp.exp(log_c[0]))
209 |
210 | if verbose:
211 | print(f"\nOptimal shape parameter: {optimal_shape:.6f}")
212 | print(f"Final LOO-CV cost: {history['cost'][-1]:.6e}")
213 |
214 | return optimal_shape, history
215 |
216 |
217 | # =============================================================================
218 | # RBF Interpolation Helper
219 | # =============================================================================
220 |
221 |
222 | def rbf_interpolate(x_train, y_train, x_eval, shape_factor, kernel="imq"):
223 | """
224 | Perform RBF interpolation at evaluation points.
225 |
226 | Parameters
227 | ----------
228 | x_train : Array
229 | Training point locations.
230 | y_train : Array
231 | Training values.
232 | x_eval : Array
233 | Points at which to evaluate the interpolant.
234 | shape_factor : float
235 | RBF shape parameter.
236 | kernel : str
237 | Kernel type.
238 |
239 | Returns
240 | -------
241 | Array
242 | Interpolated values at x_eval.
243 | """
244 | from pod_rbf.rbf import build_inference_matrix
245 |
246 | x_train_2d = x_train[None, :]
247 | x_eval_2d = x_eval[None, :]
248 | x_range = jnp.array([jnp.ptp(x_train)])
249 |
250 | # Build collocation matrix and solve for coefficients
251 | A = build_collocation_matrix(
252 | x_train_2d, x_range, kernel=kernel, shape_factor=shape_factor
253 | )
254 | c = jnp.linalg.solve(A, y_train)
255 |
256 | # Build inference matrix and evaluate
257 | F = build_inference_matrix(
258 | x_train_2d, x_eval_2d, x_range, kernel=kernel, shape_factor=shape_factor
259 | )
260 |
261 | return F @ c
262 |
263 |
264 | # =============================================================================
265 | # Visualization
266 | # =============================================================================
267 |
268 |
269 | def plot_loocv_landscape(x, y, optimal_shape, kernel="imq"):
270 | """Plot LOO-CV cost as a function of shape parameter."""
271 | shape_factors = np.logspace(-2, 1, 100)
272 | costs = []
273 |
274 | loocv_jit = jax.jit(lambda c: loocv_cost(c, x, y, kernel))
275 |
276 | for c in shape_factors:
277 | costs.append(float(loocv_jit(c)))
278 |
279 | fig, ax = plt.subplots(figsize=(8, 5))
280 | ax.semilogy(shape_factors, costs, "b-", linewidth=2, label="LOO-CV cost")
281 | ax.axvline(
282 | optimal_shape, color="r", linestyle="--", linewidth=2, label=f"Optimal: {optimal_shape:.4f}"
283 | )
284 | ax.set_xlabel("Shape Parameter (c)", fontsize=12)
285 | ax.set_ylabel("LOO-CV Cost (log scale)", fontsize=12)
286 | ax.set_title("LOO-CV Error Landscape", fontsize=14)
287 | ax.legend(fontsize=11)
288 | ax.grid(True, alpha=0.3)
289 | ax.set_xscale("log")
290 |
291 | return fig
292 |
293 |
294 | def plot_interpolation_comparison(x_train, y_train, x_eval, y_true, shape_factors, kernel="imq"):
295 | """Compare interpolation with different shape parameters."""
296 | n_shapes = len(shape_factors)
297 | fig, axes = plt.subplots(1, n_shapes, figsize=(5 * n_shapes, 4))
298 |
299 | if n_shapes == 1:
300 | axes = [axes]
301 |
302 | for ax, shape_factor in zip(axes, shape_factors):
303 | # Interpolate
304 | y_interp = rbf_interpolate(x_train, y_train, x_eval, shape_factor, kernel)
305 |
306 | # Compute error
307 | rmse = np.sqrt(np.mean((np.array(y_interp) - np.array(y_true)) ** 2))
308 |
309 | # Plot
310 | ax.plot(x_eval, y_true, "k-", linewidth=2, label="True function")
311 | ax.plot(x_eval, y_interp, "b--", linewidth=2, label="RBF interpolant")
312 | ax.scatter(x_train, y_train, c="r", s=50, zorder=5, label="Training points")
313 |
314 | ax.set_xlabel("x", fontsize=11)
315 | ax.set_ylabel("f(x)", fontsize=11)
316 | ax.set_title(f"c = {shape_factor:.4f}, RMSE = {rmse:.2e}", fontsize=12)
317 | ax.legend(fontsize=9)
318 | ax.grid(True, alpha=0.3)
319 |
320 | plt.tight_layout()
321 | return fig
322 |
323 |
324 | def plot_convergence(history):
325 | """Plot optimization convergence."""
326 | fig, axes = plt.subplots(1, 3, figsize=(15, 4))
327 |
328 | # Cost convergence
329 | ax = axes[0]
330 | ax.semilogy(history["cost"], "b-", linewidth=2, marker="o", markersize=4)
331 | ax.set_xlabel("Iteration", fontsize=12)
332 | ax.set_ylabel("LOO-CV Cost (log scale)", fontsize=12)
333 | ax.set_title("Cost Convergence", fontsize=14)
334 | ax.grid(True, alpha=0.3)
335 |
336 | # Shape parameter evolution
337 | ax = axes[1]
338 | ax.plot(history["shape_factor"], "r-", linewidth=2, marker="o", markersize=4)
339 | ax.set_xlabel("Iteration", fontsize=12)
340 | ax.set_ylabel("Shape Parameter", fontsize=12)
341 | ax.set_title("Shape Parameter Evolution", fontsize=14)
342 | ax.grid(True, alpha=0.3)
343 |
344 | # Gradient magnitude (in log space)
345 | ax = axes[2]
346 | grad_key = "grad_log" if "grad_log" in history else "grad"
347 | ax.semilogy(np.abs(history[grad_key]), "g-", linewidth=2, marker="o", markersize=4)
348 | ax.set_xlabel("Iteration", fontsize=12)
349 | ax.set_ylabel("|Gradient| (log scale)", fontsize=12)
350 | ax.set_title("Gradient Magnitude", fontsize=14)
351 | ax.grid(True, alpha=0.3)
352 |
353 | plt.tight_layout()
354 | return fig
355 |
356 |
357 | # =============================================================================
358 | # Main
359 | # =============================================================================
360 |
361 |
362 | def find_good_initial_guess(x, y, kernel="imq"):
363 | """Scan a range of shape parameters to find a good starting point."""
364 | loocv_jit = jax.jit(lambda c: loocv_cost(c, x, y, kernel))
365 |
366 | # Scan over a range of shape parameters
367 | shape_factors = np.logspace(-2, 1, 50)
368 | costs = []
369 |
370 | for c in shape_factors:
371 | try:
372 | cost = float(loocv_jit(c))
373 | if np.isfinite(cost):
374 | costs.append(cost)
375 | else:
376 | costs.append(np.inf)
377 | except Exception:
378 | costs.append(np.inf)
379 |
380 | # Find minimum
381 | best_idx = np.argmin(costs)
382 | return shape_factors[best_idx], shape_factors, costs
383 |
384 |
385 | def main():
386 | print("=" * 70)
387 | print("Shape Parameter Optimization via Autodiff (Rippa Criterion + L-BFGS)")
388 | print("=" * 70)
389 |
390 | # Setup
391 | np.random.seed(42)
392 | kernel = "imq"
393 |
394 | # Generate training data (Runge function)
395 | n_train = 15
396 | x_train = jnp.linspace(-1, 1, n_train)
397 | y_train = runge_function(x_train)
398 |
399 | # Dense evaluation grid
400 | x_eval = jnp.linspace(-1, 1, 200)
401 | y_true = runge_function(x_eval)
402 |
403 |     print("\nTest problem: Runge function f(x) = 1/(1 + 25x^2)")
404 | print(f"Training points: {n_train}")
405 | print(f"Kernel: {kernel.upper()}")
406 |
407 | # ==========================================================================
408 | # First, scan the landscape to find a good initial guess
409 | # ==========================================================================
410 | print("\n" + "-" * 70)
411 | print("Scanning LOO-CV landscape for good initial guess...")
412 | print("-" * 70)
413 |
414 | initial_guess, scan_shapes, scan_costs = find_good_initial_guess(x_train, y_train, kernel)
415 | print(f"Best from scan: shape_factor = {initial_guess:.6f}, LOO-CV = {np.min(scan_costs):.6e}")
416 |
417 | # ==========================================================================
418 | # Optimize shape parameter using Rippa criterion + L-BFGS
419 | # ==========================================================================
420 | print("\n" + "-" * 70)
421 | print("Refining with JAXopt L-BFGS...")
422 | print("-" * 70)
423 |
424 | start = time.time()
425 | optimal_shape, history = optimize_shape_parameter(
426 | x_train, y_train, kernel=kernel, initial_guess=initial_guess, verbose=True
427 | )
428 | opt_time = time.time() - start
429 | print(f"Optimization time: {opt_time:.3f} sec")
430 |
431 | # ==========================================================================
432 | # Compare with different shape parameters
433 | # ==========================================================================
434 | print("\n" + "-" * 70)
435 | print("Comparing interpolation accuracy...")
436 | print("-" * 70)
437 |
438 | # Test several shape parameters
439 | test_shapes = [0.1, optimal_shape, 2.0]
440 | shape_labels = ["Too small (0.1)", f"Optimal ({optimal_shape:.4f})", "Too large (2.0)"]
441 |
442 | for shape, label in zip(test_shapes, shape_labels):
443 | y_interp = rbf_interpolate(x_train, y_train, x_eval, shape, kernel)
444 | rmse = np.sqrt(np.mean((np.array(y_interp) - np.array(y_true)) ** 2))
445 | loocv = loocv_cost(shape, x_train, y_train, kernel)
446 | print(f" {label:25s}: RMSE = {rmse:.4e}, LOO-CV = {loocv:.4e}")
447 |
448 | # ==========================================================================
449 | # Demonstrate autodiff explicitly
450 | # ==========================================================================
451 | print("\n" + "-" * 70)
452 | print("Demonstrating autodiff capabilities...")
453 | print("-" * 70)
454 |
455 | # Show gradient computation
456 | grad_fn = jax.jit(jax.grad(lambda c: loocv_cost(c, x_train, y_train, kernel)))
457 |
458 | for c in [0.1, 0.5, optimal_shape, 2.0]:
459 | grad = grad_fn(c)
460 | print(f" d(LOO-CV)/dc at c={c:.4f}: {grad:.6e}")
461 |
462 | # Show Hessian computation
463 | hess_fn = jax.jit(jax.hessian(lambda c: loocv_cost(c, x_train, y_train, kernel)))
464 | hess_at_opt = hess_fn(optimal_shape)
465 | print(f"\n d²(LOO-CV)/dc² at optimal c={optimal_shape:.4f}: {hess_at_opt:.6e}")
466 | print(" (Positive Hessian confirms this is a local minimum)")
467 |
468 | # ==========================================================================
469 | # Visualizations
470 | # ==========================================================================
471 | print("\n" + "-" * 70)
472 | print("Generating plots...")
473 | print("-" * 70)
474 |
475 | # Plot 1: LOO-CV landscape
476 | fig1 = plot_loocv_landscape(x_train, y_train, optimal_shape, kernel)
477 |
478 | # Plot 2: Interpolation comparison
479 | fig2 = plot_interpolation_comparison(
480 | x_train, y_train, x_eval, y_true, test_shapes, kernel
481 | )
482 |
483 | # Plot 3: Convergence history
484 | fig3 = plot_convergence(history)
485 |
486 | print("\n" + "=" * 70)
487 | print("Done! This example demonstrated:")
488 | print(" 1. Rippa criterion (LOO-CV) for RBF shape parameter selection")
489 | print(" 2. JAX autodiff through build_collocation_matrix")
490 | print(" 3. JAXopt L-BFGS optimization")
491 | print(" 4. Gradient and Hessian computation of LOO-CV cost")
492 | print("=" * 70)
493 |
494 | plt.show()
495 |
496 |
497 | if __name__ == "__main__":
498 | main()
499 |
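500 | 
501 | # =============================================================================
502 | # Appendix: Brute-Force LOO-CV Verification (optional; not called by main)
503 | # =============================================================================
504 | # A sketch for sanity-checking Rippa's closed-form errors against n explicit
505 | # leave-one-out refits. Cost is O(n^4) in the number of points, so keep n
506 | # small. The helper below is illustrative and not part of pod_rbf's API.
507 | 
508 | 
509 | def brute_force_loocv(shape_factor, x, y, kernel="imq"):
510 |     """Mean squared LOO error computed by explicit refits, for validation."""
511 |     from pod_rbf.rbf import build_inference_matrix
512 | 
513 |     n = x.shape[0]
514 |     x_range = jnp.array([jnp.ptp(x)])  # same normalization as loocv_cost
515 |     errors = []
516 |     for i in range(n):
517 |         x_tr = jnp.delete(x, i)
518 |         y_tr = jnp.delete(y, i)
519 |         A = build_collocation_matrix(
520 |             x_tr[None, :], x_range, kernel=kernel, shape_factor=shape_factor
521 |         )
522 |         c = jnp.linalg.solve(A, y_tr)
523 |         F = build_inference_matrix(
524 |             x_tr[None, :], x[None, i : i + 1], x_range,
525 |             kernel=kernel, shape_factor=shape_factor,
526 |         )
527 |         errors.append(float((F @ c)[0] - y[i]))
528 |     return float(np.mean(np.asarray(errors) ** 2))
529 | 
530 | 
531 | # Interactive check (agreement holds up to floating-point solver tolerance):
532 | #   x = jnp.linspace(-1, 1, 15); y = runge_function(x)
533 | #   print(brute_force_loocv(0.5, x, y), float(loocv_cost(0.5, x, y)))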
--------------------------------------------------------------------------------
/tests/test_core.py:
--------------------------------------------------------------------------------
1 | """Tests for core train/inference functions."""
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import numpy as np
6 | import pytest
7 |
8 | import pod_rbf
9 | from pod_rbf.core import inference, inference_single, train
10 | from pod_rbf.types import ModelState, TrainConfig, TrainResult
11 |
12 |
13 | class TestTrain:
14 | """Test training function."""
15 |
16 | @pytest.fixture
17 | def linear_data(self):
18 | """Create simple linear test data: f(x, p) = p * x."""
19 | x = jnp.linspace(0, 1, 50)
20 | params = jnp.linspace(1, 10, 10)
21 | # snapshot[i, j] = params[j] * x[i]
22 | snapshot = x[:, None] * params[None, :]
23 | return snapshot, params, x
24 |
25 | def test_returns_train_result(self, linear_data):
26 | """Train should return TrainResult."""
27 | snapshot, params, _ = linear_data
28 | result = train(snapshot, params)
29 |
30 | assert isinstance(result, TrainResult)
31 | assert isinstance(result.state, ModelState)
32 | assert result.n_modes > 0
33 | assert isinstance(result.used_eig_decomp, bool)
34 |
35 | def test_model_state_shapes(self, linear_data):
36 | """Model state should have correct shapes."""
37 | snapshot, params, _ = linear_data
38 | result = train(snapshot, params)
39 | state = result.state
40 |
41 | n_samples, n_snapshots = snapshot.shape
42 | n_modes = result.n_modes
43 |
44 | assert state.basis.shape == (n_samples, n_modes)
45 | assert state.weights.shape == (n_modes, n_snapshots)
46 | assert state.train_params.shape == (1, n_snapshots)
47 | assert state.params_range.shape == (1,)
48 |
49 | def test_custom_config(self, linear_data):
50 | """Should accept custom config."""
51 | snapshot, params, _ = linear_data
52 | config = TrainConfig(energy_threshold=0.9)
53 | result = train(snapshot, params, config=config)
54 |
55 | assert result.state.truncated_energy >= 0.9
56 |
57 | def test_fixed_shape_factor(self, linear_data):
58 | """Should accept fixed shape factor."""
59 | snapshot, params, _ = linear_data
60 | result = train(snapshot, params, shape_factor=0.5)
61 |
62 | assert result.state.shape_factor == 0.5
63 |
64 | def test_multi_param(self):
65 | """Should work with multiple parameters."""
66 | x = jnp.linspace(0, 1, 30)
67 | p1 = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
68 | p2 = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5])
69 | params = jnp.stack([p1, p2], axis=0)
70 |
71 | # f(x, p1, p2) = p1 * x + p2
72 | snapshot = x[:, None] * p1[None, :] + p2[None, :]
73 |
74 | result = train(snapshot, params)
75 |
76 | assert result.state.train_params.shape == (2, 5)
77 | assert result.state.params_range.shape == (2,)
78 |
79 |
80 | class TestInference:
81 | """Test inference functions."""
82 |
83 | @pytest.fixture
84 | def trained_model(self):
85 | """Create and train a model on linear data."""
86 | x = jnp.linspace(0, 1, 50)
87 | params = jnp.linspace(1, 10, 10)
88 | snapshot = x[:, None] * params[None, :]
89 | result = train(snapshot, params)
90 | return result.state, x, params, snapshot
91 |
92 | def test_inference_single_shape(self, trained_model):
93 | """inference_single should return 1D array."""
94 | state, x, _, _ = trained_model
95 | pred = inference_single(state, jnp.array(5.0))
96 |
97 | assert pred.shape == (len(x),)
98 |
99 | def test_inference_batch_shape(self, trained_model):
100 | """inference should return 2D array for batch input."""
101 | state, x, _, _ = trained_model
102 | pred = inference(state, jnp.array([[2.0, 5.0, 8.0]]))
103 |
104 | assert pred.shape == (len(x), 3)
105 |
106 | def test_interpolation_at_training_points(self, trained_model):
107 | """Inference at training points should match training data."""
108 | state, x, params, snapshot = trained_model
109 |
110 | for i, p in enumerate(params):
111 | pred = inference_single(state, p)
112 | expected = snapshot[:, i]
113 | assert jnp.allclose(pred, expected, rtol=1e-3), f"Mismatch at training point {i}"
114 |
115 | def test_interpolation_between_points(self, trained_model):
116 | """Inference between training points should be accurate for linear data."""
117 | state, x, _, _ = trained_model
118 |
119 | # Test at midpoint
120 | pred = inference_single(state, jnp.array(5.5))
121 | expected = 5.5 * x
122 |
123 | assert jnp.allclose(pred, expected, rtol=1e-2), "Interpolation at midpoint should be accurate"
124 |
125 | def test_scalar_param(self, trained_model):
126 | """Should handle scalar parameter input."""
127 | state, x, _, _ = trained_model
128 | pred = inference_single(state, 5.0)
129 |
130 | assert pred.shape == (len(x),)
131 |
132 |
133 | class TestGradients:
134 | """Test autodifferentiation capabilities."""
135 |
136 | @pytest.fixture
137 | def trained_model(self):
138 | """Create trained model for gradient tests."""
139 | x = jnp.linspace(0, 1, 50)
140 | params = jnp.linspace(1, 10, 10)
141 | snapshot = x[:, None] * params[None, :]
142 | result = train(snapshot, params)
143 | return result.state, x
144 |
145 | def test_grad_wrt_param(self, trained_model):
146 | """Gradient of inference w.r.t. parameter should exist and be non-zero."""
147 | state, _ = trained_model
148 |
149 | def loss(p):
150 | pred = inference_single(state, p)
151 | return jnp.sum(pred**2)
152 |
153 | grad_fn = jax.grad(loss)
154 | grad = grad_fn(jnp.array(5.0))
155 |
156 | assert not jnp.isnan(grad), "Gradient should not be NaN"
157 | assert grad != 0.0, "Gradient should be non-zero"
158 |
159 | def test_inverse_problem(self, trained_model):
160 | """Test solving inverse problem via gradient descent."""
161 | state, x = trained_model
162 |
163 | # Target: parameter = 7.5
164 | target = 7.5 * x
165 |
166 | def loss(p):
167 | pred = inference_single(state, p)
168 | return jnp.mean((pred - target) ** 2)
169 |
170 | # Gradient descent
171 | p = jnp.array(5.0) # Initial guess
172 | lr = 0.5
173 |
174 | for _ in range(50):
175 | g = jax.grad(loss)(p)
176 | p = p - lr * g
177 |
178 | assert jnp.abs(p - 7.5) < 0.1, f"Recovered parameter {p} should be close to 7.5"
179 |
180 | def test_jacobian(self, trained_model):
181 | """Jacobian should have correct shape."""
182 | state, x = trained_model
183 |
184 | jacobian = jax.jacobian(lambda p: inference_single(state, p))(jnp.array(5.0))
185 |
186 | assert jacobian.shape == (len(x),), f"Jacobian shape should be ({len(x)},), got {jacobian.shape}"
187 |
188 | def test_value_and_grad(self, trained_model):
189 | """value_and_grad should work."""
190 | state, _ = trained_model
191 |
192 | def loss(p):
193 | pred = inference_single(state, p)
194 | return jnp.sum(pred**2)
195 |
196 | val, grad = jax.value_and_grad(loss)(jnp.array(5.0))
197 |
198 | assert not jnp.isnan(val)
199 | assert not jnp.isnan(grad)
200 |
201 |
202 | class TestSchurComplement:
203 | """Test Schur complement solver integration."""
204 |
205 | @pytest.fixture
206 | def linear_data(self):
207 | """Create simple linear test data: f(x, p) = p * x."""
208 | x = jnp.linspace(0, 1, 50)
209 | params = jnp.linspace(1, 10, 10)
210 | snapshot = x[:, None] * params[None, :]
211 | return snapshot, params, x
212 |
213 | def test_model_state_has_poly_fields(self, linear_data):
214 | """Model state should have polynomial coefficient fields."""
215 | snapshot, params, _ = linear_data
216 | result = train(snapshot, params)
217 | state = result.state
218 |
219 | assert hasattr(state, "poly_coeffs"), "ModelState should have poly_coeffs field"
220 | assert hasattr(state, "poly_degree"), "ModelState should have poly_degree field"
221 | assert state.poly_degree == 2, "Default poly_degree should be 2"
222 | assert state.poly_coeffs is not None, "poly_coeffs should not be None with default config"
223 |
224 | def test_poly_coeffs_shape(self, linear_data):
225 | """Polynomial coefficients should have correct shape."""
226 | snapshot, params, _ = linear_data
227 | result = train(snapshot, params)
228 | state = result.state
229 |
230 | n_modes = result.n_modes
231 | # For 1D params with degree 2: 3 polynomial terms [1, p, p²]
232 | n_poly = 3
233 | assert state.poly_coeffs.shape == (n_modes, n_poly), f"Expected ({n_modes}, {n_poly}), got {state.poly_coeffs.shape}"
234 |
235 | def test_poly_degree_0_fallback(self, linear_data):
236 | """poly_degree=0 should use pinv fallback."""
237 | snapshot, params, _ = linear_data
238 | config = TrainConfig(poly_degree=0)
239 | result = train(snapshot, params, config=config)
240 | state = result.state
241 |
242 | assert state.poly_degree == 0
243 | assert state.poly_coeffs is None, "poly_coeffs should be None when poly_degree=0"
244 |
245 | def test_interpolation_with_schur(self, linear_data):
246 | """Schur complement solver should interpolate training points accurately."""
247 | snapshot, params, x = linear_data
248 | result = train(snapshot, params)
249 | state = result.state
250 |
251 | for i, p in enumerate(params):
252 | pred = inference_single(state, p)
253 | expected = snapshot[:, i]
254 | assert jnp.allclose(pred, expected, rtol=1e-3), f"Mismatch at training point {i}"
255 |
256 | def test_interpolation_between_points_schur(self, linear_data):
257 | """Schur solver should interpolate accurately between training points."""
258 | snapshot, params, x = linear_data
259 | result = train(snapshot, params)
260 | state = result.state
261 |
262 | pred = inference_single(state, jnp.array(5.5))
263 | expected = 5.5 * x
264 |
265 | assert jnp.allclose(pred, expected, rtol=1e-2), "Interpolation at midpoint should be accurate"
266 |
267 | def test_grad_with_schur(self, linear_data):
268 | """Autodiff should work through Schur complement solver."""
269 | snapshot, params, _ = linear_data
270 | result = train(snapshot, params)
271 | state = result.state
272 |
273 | def loss(p):
274 | pred = inference_single(state, p)
275 | return jnp.sum(pred**2)
276 |
277 | grad_fn = jax.grad(loss)
278 | grad = grad_fn(jnp.array(5.0))
279 |
280 | assert not jnp.isnan(grad), "Gradient should not be NaN"
281 | assert grad != 0.0, "Gradient should be non-zero"
282 |
283 | def test_multi_param_with_schur(self):
284 | """Schur solver should work with multiple parameters."""
285 | x = jnp.linspace(0, 1, 30)
286 | # Use uncorrelated parameters to avoid rank-deficient polynomial basis
287 | p1 = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
288 | p2 = jnp.array([0.5, 0.1, 0.4, 0.2, 0.3])
289 | params = jnp.stack([p1, p2], axis=0)
290 |
291 | snapshot = x[:, None] * p1[None, :] + p2[None, :]
292 |
293 | result = train(snapshot, params)
294 | state = result.state
295 |
296 | # For 2D params with degree 2: 6 polynomial terms
297 | n_poly = 6
298 | assert state.poly_coeffs.shape[1] == n_poly, f"Expected {n_poly} poly terms, got {state.poly_coeffs.shape[1]}"
299 |
300 | # Test interpolation at training points
301 | for i in range(len(p1)):
302 | pred = inference_single(state, jnp.array([p1[i], p2[i]]))
303 | expected = snapshot[:, i]
304 | assert jnp.allclose(pred, expected, rtol=1e-3), f"Mismatch at training point {i}"
305 |
306 |
307 | class TestJIT:
308 | """Test JIT compilation."""
309 |
310 | @pytest.fixture
311 | def trained_model(self):
312 | """Create trained model for JIT tests."""
313 | x = jnp.linspace(0, 1, 50)
314 | params = jnp.linspace(1, 10, 10)
315 | snapshot = x[:, None] * params[None, :]
316 | result = train(snapshot, params)
317 | return result.state
318 |
319 | def test_inference_single_jit(self, trained_model):
320 | """inference_single should JIT compile via closure pattern."""
321 | state = trained_model
322 |
323 | # Create a closure that captures state (recommended pattern for JAX)
324 | @jax.jit
325 | def jitted_inference(p):
326 | return inference_single(state, p)
327 |
328 | # First call compiles
329 | pred1 = jitted_inference(jnp.array(5.0))
330 | # Second call uses cached compilation
331 | pred2 = jitted_inference(jnp.array(6.0))
332 |
333 | assert pred1.shape == (50,)
334 | assert pred2.shape == (50,)
335 |
336 | def test_inference_batch_jit(self, trained_model):
337 | """inference should JIT compile via closure pattern."""
338 | state = trained_model
339 |
340 | @jax.jit
341 | def jitted_inference(params):
342 | return inference(state, params)
343 |
344 | pred = jitted_inference(jnp.array([[2.0, 5.0, 8.0]]))
345 |
346 | assert pred.shape == (50, 3)
347 |
348 | def test_grad_jit(self, trained_model):
349 | """Gradient computation should JIT compile."""
350 | state = trained_model
351 |
352 |         @jax.jit
353 |         def loss_fn(p):
354 |             pred = inference_single(state, p)
355 |             loss_val = jnp.sum(pred**2)
356 |             return loss_val
357 | 
358 |         grad_fn = jax.jit(jax.grad(loss_fn))
359 | 
360 |         loss = loss_fn(jnp.array(5.0))
361 |         grad = grad_fn(jnp.array(5.0))
362 |
363 | assert not jnp.isnan(loss)
364 | assert not jnp.isnan(grad)
365 |
366 |
367 | class TestKernelTypes:
368 | """Test training and inference with different kernel types."""
369 |
370 | @pytest.fixture
371 | def linear_data(self):
372 | """Create simple linear test data: f(x, p) = p * x."""
373 | x = jnp.linspace(0, 1, 50)
374 | params = jnp.linspace(1, 10, 10)
375 | snapshot = x[:, None] * params[None, :]
376 | return snapshot, params, x
377 |
378 | def test_train_with_gaussian_kernel(self, linear_data):
379 | """Train with Gaussian kernel."""
380 | snapshot, params, _ = linear_data
381 | config = TrainConfig(kernel="gaussian")
382 | result = train(snapshot, params, config=config)
383 |
384 | assert result.state.kernel == "gaussian"
385 | assert result.state.shape_factor is not None
386 | assert result.n_modes > 0
387 |
388 | def test_train_with_phs_kernel(self, linear_data):
389 | """Train with polyharmonic spline kernel."""
390 | snapshot, params, _ = linear_data
391 | config = TrainConfig(kernel="polyharmonic_spline", kernel_order=3)
392 | result = train(snapshot, params, config=config)
393 |
394 | assert result.state.kernel == "polyharmonic_spline"
395 | assert result.state.kernel_order == 3
396 | assert result.state.shape_factor is None # PHS doesn't use shape parameter
397 | assert result.n_modes > 0
398 |
399 | def test_train_with_imq_kernel_explicit(self, linear_data):
400 | """Train with IMQ kernel explicitly specified."""
401 | snapshot, params, _ = linear_data
402 | config = TrainConfig(kernel="imq")
403 | result = train(snapshot, params, config=config)
404 |
405 | assert result.state.kernel == "imq"
406 | assert result.state.shape_factor is not None
407 | assert result.n_modes > 0
408 |
409 | def test_inference_with_gaussian_kernel(self, linear_data):
410 | """Inference should work with Gaussian kernel."""
411 | snapshot, params, x = linear_data
412 | config = TrainConfig(kernel="gaussian")
413 | result = train(snapshot, params, config=config)
414 |
415 | # Test inference at training points
416 | pred = inference_single(result.state, params[5])
417 | expected = snapshot[:, 5]
418 | assert jnp.allclose(pred, expected, rtol=1e-3)
419 |
420 | def test_inference_with_phs_kernel(self, linear_data):
421 | """Inference should work with PHS kernel."""
422 | snapshot, params, x = linear_data
423 | config = TrainConfig(kernel="polyharmonic_spline", kernel_order=3)
424 | result = train(snapshot, params, config=config)
425 |
426 | # Test inference at training points
427 | pred = inference_single(result.state, params[5])
428 | expected = snapshot[:, 5]
429 | assert jnp.allclose(pred, expected, rtol=1e-3)
430 |
431 | def test_all_kernels_interpolate_training_data(self, linear_data):
432 | """All kernels should interpolate training data accurately."""
433 | snapshot, params, x = linear_data
434 |
435 | kernels = [
436 | ("imq", {}),
437 | ("gaussian", {}),
438 | ("polyharmonic_spline", {"kernel_order": 3}),
439 | ]
440 |
441 | for kernel, kwargs in kernels:
442 | config = TrainConfig(kernel=kernel, **kwargs)
443 | result = train(snapshot, params, config=config)
444 |
445 | # Check interpolation at a few training points
446 | for i in [0, 5, 9]:
447 | pred = inference_single(result.state, params[i])
448 | expected = snapshot[:, i]
449 | assert jnp.allclose(
450 | pred, expected, rtol=1e-3
451 | ), f"Kernel {kernel} failed to interpolate training point {i}"
452 |
453 | def test_phs_different_orders(self, linear_data):
454 | """Test PHS with different orders."""
455 | snapshot, params, x = linear_data
456 |
457 | for order in [1, 3, 5]:
458 | config = TrainConfig(kernel="polyharmonic_spline", kernel_order=order)
459 | result = train(snapshot, params, config=config)
460 |
461 | assert result.state.kernel_order == order
462 | assert result.state.shape_factor is None
463 |
464 | # Should still interpolate training data
465 | pred = inference_single(result.state, params[5])
466 | expected = snapshot[:, 5]
467 | assert jnp.allclose(pred, expected, rtol=1e-3), f"PHS order {order} failed"
468 |
469 |
470 | class TestKernelGradients:
471 | """Test autodiff with different kernels."""
472 |
473 | @pytest.fixture
474 | def linear_data(self):
475 | """Create simple linear test data."""
476 | x = jnp.linspace(0, 1, 50)
477 | params = jnp.linspace(1, 10, 10)
478 | snapshot = x[:, None] * params[None, :]
479 | return snapshot, params, x
480 |
481 | def test_grad_with_gaussian_kernel(self, linear_data):
482 | """Gradients should work with Gaussian kernel."""
483 | snapshot, params, x = linear_data
484 | config = TrainConfig(kernel="gaussian")
485 | result = train(snapshot, params, config=config)
486 |
487 | def loss(p):
488 | pred = inference_single(result.state, p)
489 | return jnp.sum(pred**2)
490 |
491 | grad = jax.grad(loss)(jnp.array(5.0))
492 | assert not jnp.isnan(grad)
493 | assert grad != 0.0
494 |
495 | def test_grad_with_phs_kernel(self, linear_data):
496 | """Gradients should work with PHS kernel."""
497 | snapshot, params, x = linear_data
498 | config = TrainConfig(kernel="polyharmonic_spline", kernel_order=3)
499 | result = train(snapshot, params, config=config)
500 |
501 | def loss(p):
502 | pred = inference_single(result.state, p)
503 | return jnp.sum(pred**2)
504 |
505 | grad = jax.grad(loss)(jnp.array(5.0))
506 | assert not jnp.isnan(grad)
507 | assert grad != 0.0
508 |
509 | def test_inverse_problem_all_kernels(self, linear_data):
510 | """Inverse problem should work with all kernels."""
511 | snapshot, params, x = linear_data
512 | target_param = 7.5
513 | target = target_param * x
514 |
515 | kernels = [
516 | ("imq", {}),
517 | ("gaussian", {}),
518 | ("polyharmonic_spline", {"kernel_order": 3}),
519 | ]
520 |
521 | for kernel, kwargs in kernels:
522 | config = TrainConfig(kernel=kernel, **kwargs)
523 | result = train(snapshot, params, config=config)
524 |
525 | def loss(p):
526 | pred = inference_single(result.state, p)
527 | return jnp.mean((pred - target) ** 2)
528 |
529 | # Gradient descent
530 | p = jnp.array(5.0)
531 | lr = 0.5
532 |
533 | for _ in range(50):
534 | g = jax.grad(loss)(p)
535 | p = p - lr * g
536 |
537 | assert jnp.abs(p - target_param) < 0.2, f"Kernel {kernel}: recovered {p}, expected {target_param}"
538 |
539 |
540 | class TestBackwardsCompatibility:
541 | """Test backwards compatibility with default IMQ kernel."""
542 |
543 | def test_default_kernel_is_imq(self):
544 | """Default config should use IMQ kernel."""
545 | config = TrainConfig()
546 | assert config.kernel == "imq"
547 | assert config.kernel_order == 3
548 |
549 | def test_train_without_config_uses_imq(self):
550 | """Training without config should use IMQ kernel."""
551 | x = jnp.linspace(0, 1, 50)
552 | params = jnp.linspace(1, 10, 10)
553 | snapshot = x[:, None] * params[None, :]
554 |
555 | result = train(snapshot, params)
556 | assert result.state.kernel == "imq"
557 | assert result.state.shape_factor is not None
558 |
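559 | 
560 | class TestBatchConsistency:
561 |     """Sketch: batch and single-point inference should agree.
562 | 
563 |     A hedged addition: assumes inference and inference_single evaluate the
564 |     same trained model, so each column of a batched prediction matches the
565 |     corresponding single-point call.
566 |     """
567 | 
568 |     def test_batch_matches_single(self):
569 |         x = jnp.linspace(0, 1, 50)
570 |         params = jnp.linspace(1, 10, 10)
571 |         snapshot = x[:, None] * params[None, :]
572 |         state = train(snapshot, params).state
573 | 
574 |         query = jnp.array([[2.0, 5.0, 8.0]])  # (n_params=1, n_points=3)
575 |         batch = inference(state, query)  # (n_samples, 3)
576 | 
577 |         for j in range(query.shape[1]):
578 |             single = inference_single(state, query[0, j])
579 |             assert jnp.allclose(
580 |                 batch[:, j], single, rtol=1e-6, atol=1e-8
581 |             ), f"Batch column {j} disagrees with single-point inference"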
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |     <one line to give the program's name and a brief idea of what it does.>
635 |     Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 |     along with this program.  If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 |     <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/philosophy/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------