├── .github ├── ISSUE_TEMPLATE.md ├── TEST_FAIL_TEMPLATE.md ├── dependabot.yml └── workflows │ └── ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── examples ├── optimise_1d_grid_model.ipynb └── optimise_2d_grid_model.ipynb ├── pyproject.toml ├── src └── torch_cubic_spline_grids │ ├── __init__.py │ ├── _base_cubic_grid.py │ ├── _constants.py │ ├── b_spline_grids.py │ ├── catmull_rom_grids.py │ ├── interpolate_grids.py │ ├── interpolate_pieces.py │ ├── pad_grids.py │ └── utils.py └── tests ├── __init__.py ├── test_grid_optimisation.py ├── test_grids.py ├── test_interpolate_grid.py ├── test_interpolate_pieces.py ├── test_modules.py ├── test_pad_grid.py └── test_utils.py /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | * torch-cubic-spline-grids version: 2 | * Python version: 3 | * Operating System: 4 | 5 | ### Description 6 | 7 | Describe what you were trying to get done. 8 | Tell us what happened, what went wrong, and what you expected to happen. 9 | 10 | ### What I Did 11 | 12 | ``` 13 | Paste the command(s) you ran and the output. 14 | If there was a crash, please include the traceback here. 15 | ``` 16 | -------------------------------------------------------------------------------- /.github/TEST_FAIL_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: "{{ env.TITLE }}" 3 | labels: [bug] 4 | --- 5 | The {{ workflow }} workflow failed on {{ date | date("YYYY-MM-DD HH:mm") }} UTC 6 | 7 | The most recent failing test was on {{ env.PLATFORM }} py{{ env.PYTHON }} 8 | with commit: {{ sha }} 9 | 10 | Full run: https://github.com/{{ repo }}/actions/runs/{{ env.RUN_ID }} 11 | 12 | (This post will be updated if another test fails, as long as this issue remains open.) 
13 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 2 | 3 | version: 2 4 | updates: 5 | - package-ecosystem: "github-actions" 6 | directory: "/" 7 | schedule: 8 | interval: "weekly" 9 | commit-message: 10 | prefix: "ci(dependabot):" 11 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | permissions: 4 | contents: write 5 | 6 | on: 7 | push: 8 | branches: 9 | - main 10 | tags: 11 | - "v*" 12 | pull_request: {} 13 | workflow_dispatch: 14 | 15 | jobs: 16 | check-manifest: 17 | name: Check Manifest 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v5 21 | - uses: actions/setup-python@v6 22 | with: 23 | python-version: "3.x" 24 | - run: pip install check-manifest && check-manifest 25 | 26 | test: 27 | name: ${{ matrix.platform }} (${{ matrix.python-version }}) 28 | runs-on: ${{ matrix.platform }} 29 | strategy: 30 | fail-fast: false 31 | matrix: 32 | python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] 33 | platform: [ 34 | ubuntu-latest, 35 | # macos-latest, 36 | # windows-latest, 37 | ] 38 | 39 | steps: 40 | - name: Cancel Previous Runs 41 | uses: styfle/cancel-workflow-action@0.12.1 42 | with: 43 | access_token: ${{ github.token }} 44 | 45 | - uses: actions/checkout@v5 46 | 47 | - name: Set up Python ${{ matrix.python-version }} 48 | uses: actions/setup-python@v6 49 | with: 50 | python-version: ${{ matrix.python-version }} 51 | 52 | - name: Install dependencies 53 | run: | 54 | python -m pip install -U pip 55 | python -m pip install -e . 
56 | python -m pip install pytest pytest-cov 57 | 58 | - name: Test with pytest 59 | run: python -m pytest 60 | env: 61 | PLATFORM: ${{ matrix.platform }} 62 | 63 | - name: Coverage 64 | uses: codecov/codecov-action@v5 65 | 66 | deploy: 67 | name: Deploy 68 | needs: test 69 | if: "success() && startsWith(github.ref, 'refs/tags/')" 70 | runs-on: ubuntu-latest 71 | 72 | steps: 73 | - uses: actions/checkout@v5 74 | 75 | - name: Set up Python 76 | uses: actions/setup-python@v6 77 | with: 78 | python-version: "3.x" 79 | 80 | - name: install 81 | run: | 82 | git tag 83 | pip install -U pip 84 | pip install -U build twine 85 | python -m build 86 | twine check dist/* 87 | ls -lh dist 88 | 89 | - name: Build and publish 90 | run: twine upload dist/* 91 | env: 92 | TWINE_USERNAME: __token__ 93 | TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} 94 | 95 | - uses: softprops/action-gh-release@v2 96 | with: 97 | generate_release_notes: true 98 | 99 | # [WIP] 100 | # https://python-semantic-release.readthedocs.io/en/latest/automatic-releases/github-actions.html 101 | # release: 102 | # runs-on: ubuntu-latest 103 | # concurrency: release 104 | 105 | # steps: 106 | # - uses: actions/checkout@v5 107 | # with: 108 | # fetch-depth: 0 109 | 110 | # - name: Python Semantic Release 111 | # uses: relekang/python-semantic-release@master 112 | # with: 113 | # github_token: ${{ secrets.GITHUB_TOKEN }} 114 | # repository_username: __token__ 115 | # repository_password: ${{ secrets.TWINE_API_KEY }} 116 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | 
wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | # IDE settings 105 | .vscode/ 106 | 107 | src/torch_cubic_spline_grids/_version.py 108 | src/torch_cubic_spline_grids/_version.py 109 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autoupdate_schedule: monthly 3 | autofix_commit_msg: "style(pre-commit.ci): auto fixes [...]" 4 | autoupdate_commit_msg: "ci(pre-commit.ci): autoupdate" 5 | 6 | default_install_hook_types: [pre-commit, commit-msg] 7 | 8 | repos: 9 | - repo: https://github.com/compilerla/conventional-pre-commit 10 | rev: v1.3.0 11 | 
hooks: 12 | - id: conventional-pre-commit 13 | stages: [commit-msg] 14 | 15 | - repo: https://github.com/pre-commit/pre-commit-hooks 16 | rev: v4.3.0 17 | hooks: 18 | - id: check-docstring-first 19 | - id: end-of-file-fixer 20 | - id: trailing-whitespace 21 | - id: debug-statements 22 | 23 | - repo: https://github.com/astral-sh/ruff-pre-commit 24 | rev: v0.4.10 25 | hooks: 26 | - id: ruff 27 | args: [--fix] 28 | - id: ruff-format 29 | 30 | - repo: https://github.com/abravalheri/validate-pyproject 31 | rev: v0.10.1 32 | hooks: 33 | - id: validate-pyproject 34 | 35 | - repo: https://github.com/pre-commit/mirrors-mypy 36 | rev: v1.15.0 37 | hooks: 38 | - id: mypy 39 | files: "^src/" 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | Copyright (c) 2023, Alister Burt 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torch-cubic-spline-grids 2 | 3 | [![License](https://img.shields.io/pypi/l/torch-cubic-spline-grids.svg?color=green)](https://github.com/alisterburt/torch-cubic-spline-grids/raw/main/LICENSE) 4 | [![PyPI](https://img.shields.io/pypi/v/torch-cubic-spline-grids.svg?color=green)](https://pypi.org/project/torch-cubic-spline-grids) 5 | [![Python Version](https://img.shields.io/pypi/pyversions/torch-cubic-spline-grids.svg?color=green)](https://python.org) 6 | [![CI](https://github.com/alisterburt/torch-cubic-spline-grids/actions/workflows/ci.yml/badge.svg)](https://github.com/alisterburt/torch-cubic-spline-grids/actions/workflows/ci.yml) 7 | [![codecov](https://codecov.io/gh/alisterburt/torch-cubic-spline-grids/branch/main/graph/badge.svg)](https://codecov.io/gh/alisterburt/torch-cubic-spline-grids) 8 | 9 | *Cubic spline interpolation on multidimensional grids in PyTorch.* 10 | 11 | The primary goal of this package is to provide learnable, continuous 12 | parametrisations of 1-4D spaces. 13 | 14 | --- 15 | 16 | ## Overview 17 | 18 | `torch_cubic_spline_grids` provides a set of PyTorch components called grids. 19 | 20 | Grids are defined by 21 | - their dimensionality (1d, 2d, 3d, 4d...) 
22 | - the number of points covering each dimension (`resolution`) 23 | - the number of values stored on each grid point (`n_channels`) 24 | - how we interpolate between values on grid points 25 | 26 | All grids in this package consist of uniformly spaced points covering the full 27 | extent of each dimension. 28 | 29 | ### First steps 30 | Let's make a simple 2D grid with one value on each grid point. 31 | 32 | ```python 33 | import torch 34 | from torch_cubic_spline_grids import CubicBSplineGrid2d 35 | 36 | grid = CubicBSplineGrid2d(resolution=(5, 3), n_channels=1) 37 | ``` 38 | 39 | - `grid.ndim` is `2` 40 | - `grid.resolution` is `(5, 3)` (or `(h, w)`) 41 | - `grid.n_channels` is `1` 42 | - `grid.data.shape` is `(1, 5, 3)` (or `(c, h, w)`) 43 | 44 | In words, the grid extends over two dimensions `(h, w)` with 5 points 45 | in `h` and `3` points in `w`. 46 | There is one value stored at each point on the 2D grid. 47 | The grid data is stored in a tensor of shape `(c, *grid_resolution)`. 48 | 49 | We can obtain the value (interpolant) at any continuous point on the grid. 50 | The grid coordinate system extends from `[0, 1]` along each grid dimension. 51 | The interpolant is obtained by sequential application of 52 | cubic spline interpolation along each dimension of the grid. 53 | 54 | ```python 55 | coords = torch.rand(size=(10, 2)) # values in [0, 1] 56 | interpolants = grid(coords) 57 | ``` 58 | 59 | - `interpolants.shape` is `(10, 1)` 60 | 61 | ### Optimisation 62 | 63 | Values at each grid point can be optimised by minimising a loss function associated with grid interpolants. 64 | In this way the continuous space of the grid can be made to more accurately model a 1-4D space. 65 | 66 |

67 | 68 |

 69 | 70 | The image above shows the values of 6 control points on a 1D grid being optimised such 71 | that interpolating between them with cubic B-spline interpolation approximates a single oscillation of a sine wave. 72 | 73 | Notebooks are available for this 74 | [1D example](./examples/optimise_1d_grid_model.ipynb) 75 | and a similar 76 | [2D example](./examples/optimise_2d_grid_model.ipynb). 77 | 78 | ### Types of grids 79 | 80 | `torch_cubic_spline_grids` provides grids which can be interpolated with **cubic 81 | B-spline** interpolation or **cubic Catmull-Rom spline** interpolation. 82 | 83 | | spline | continuity | interpolating? | 84 | |--------------------|------------|----------------| 85 | | cubic B-spline | C2 | No | 86 | | Catmull-Rom spline | C1 | Yes | 87 | 88 | If you need the resulting curve to intersect the data on the grid you should 89 | use the cubic Catmull-Rom spline grids: 90 | 91 | - `CubicCatmullRomGrid1d` 92 | - `CubicCatmullRomGrid2d` 93 | - `CubicCatmullRomGrid3d` 94 | - `CubicCatmullRomGrid4d` 95 | 96 | If you require continuous second derivatives then the cubic B-spline grids are more 97 | suitable. 98 | 99 | - `CubicBSplineGrid1d` 100 | - `CubicBSplineGrid2d` 101 | - `CubicBSplineGrid3d` 102 | - `CubicBSplineGrid4d` 103 | 104 | ### Regularisation 105 | 106 | The number of points in each dimension should be chosen such that interpolating on the 107 | grid can approximate the underlying phenomenon being modelled without overfitting. 108 | A low resolution grid provides a regularising effect by smoothing the model. 109 | 110 | 111 | ## Installation 112 | 113 | `torch_cubic_spline_grids` is available on PyPI 114 | 115 | ```shell 116 | pip install torch-cubic-spline-grids 117 | ``` 118 | 119 | 120 | ## Related work 121 | 122 | This is a PyTorch implementation of the way 123 | [Warp](http://warpem.com/warp/#) models continuous deformation 124 | fields and locally variable optical parameters in cryo-EM images. 
125 | The approach is described in 126 | [Dimitry Tegunov's paper](https://doi.org/10.1038/s41592-019-0580-y): 127 | 128 | > Many methods in Warp are based on a continuous parametrization of 1- to 129 | > 3-dimensional spaces. 130 | > This parameterization is achieved by spline interpolation between points on a coarse, 131 | > uniform grid, which is computationally efficient. 132 | > A grid extends over the entirety of each dimension that needs to be modeled. 133 | > The grid resolution is defined by the number of control points in each dimension 134 | > and is scaled according to physical constraints 135 | > (for example, the number of frames or pixels) and available signal. 136 | > The latter provides regularization to prevent overfitting of sparse data with too many 137 | > parameters. 138 | > When a parameter described by the grid is retrieved for a point in space (and time), 139 | > for example for a particle (frame), B-spline interpolation is performed at that point 140 | > on the grid. 141 | > To fit a grid’s parameters, in general, a cost function associated with the 142 | > interpolants at specific positions on the grid is optimized. 143 | 144 | --- 145 | 146 | For a fantastic introduction to splines I recommend 147 | [Freya Holmer](https://www.youtube.com/watch?v=jvPPXbo87ds)'s YouTube video. 
148 | 149 | [The Continuity of Splines - YouTube](https://youtu.be/jvPPXbo87ds) 150 | -------------------------------------------------------------------------------- /examples/optimise_1d_grid_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "outputs": [], 7 | "source": [ 8 | "import torch\n", 9 | "from matplotlib import pyplot as plt\n", 10 | "\n", 11 | "from torch_cubic_spline_grids import CubicBSplineGrid1d" 12 | ], 13 | "metadata": { 14 | "collapsed": false, 15 | "pycharm": { 16 | "name": "#%%\n" 17 | } 18 | } 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "source": [ 23 | "some variables we can set..." 24 | ], 25 | "metadata": { 26 | "collapsed": false, 27 | "pycharm": { 28 | "name": "#%% md\n" 29 | } 30 | } 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "outputs": [], 36 | "source": [ 37 | "N_CONTROL_POINTS = 6\n", 38 | "N_OBSERVATIONS_PER_ITERATION = 20" 39 | ], 40 | "metadata": { 41 | "collapsed": false, 42 | "pycharm": { 43 | "name": "#%%\n" 44 | } 45 | } 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "source": [ 50 | "initialise our optimisable parameters, a 1D grid of `N_CONTROL_POINTS` uniformly spaced\n", 51 | "points" 52 | ], 53 | "metadata": { 54 | "collapsed": false, 55 | "pycharm": { 56 | "name": "#%% md\n" 57 | } 58 | } 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 3, 63 | "outputs": [], 64 | "source": [ 65 | "grid_1d = CubicBSplineGrid1d(resolution=N_CONTROL_POINTS)" 66 | ], 67 | "metadata": { 68 | "collapsed": false, 69 | "pycharm": { 70 | "name": "#%%\n" 71 | } 72 | } 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "source": [ 77 | "define a function for making observations over the interval `[0, 1]` covering our 1d\n", 78 | "grid" 79 | ], 80 | "metadata": { 81 | "collapsed": false, 82 | "pycharm": { 83 | "name": "#%% md\n" 84 | } 85 | } 86 | }, 87 | { 88 | 
"cell_type": "code", 89 | "execution_count": 4, 90 | "outputs": [], 91 | "source": [ 92 | "def make_observations(n, add_noise: bool = False):\n", 93 | " x = torch.rand(n) # in range [0, 1]\n", 94 | " y = torch.sin(2 * torch.pi * x)\n", 95 | " if add_noise is True:\n", 96 | " y += torch.normal(torch.zeros(n), std=0.5)\n", 97 | " return x, y" 98 | ], 99 | "metadata": { 100 | "collapsed": false, 101 | "pycharm": { 102 | "name": "#%%\n" 103 | } 104 | } 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "source": [ 109 | "initialise the optimiser" 110 | ], 111 | "metadata": { 112 | "collapsed": false, 113 | "pycharm": { 114 | "name": "#%% md\n" 115 | } 116 | } 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 5, 121 | "outputs": [], 122 | "source": [ 123 | "optimiser = torch.optim.Adam(grid_1d.parameters(), lr=0.02)" 124 | ], 125 | "metadata": { 126 | "collapsed": false, 127 | "pycharm": { 128 | "name": "#%%\n" 129 | } 130 | } 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "source": [ 135 | "optimise the values at the control points such that interpolating between them with\n", 136 | "cubic B-spline interpolation fits the data" 137 | ], 138 | "metadata": { 139 | "collapsed": false, 140 | "pycharm": { 141 | "name": "#%% md\n" 142 | } 143 | } 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 6, 148 | "outputs": [ 149 | { 150 | "data": { 151 | "text/plain": "[]" 152 | }, 153 | "execution_count": 6, 154 | "metadata": {}, 155 | "output_type": "execute_result" 156 | }, 157 | { 158 | "data": { 159 | "text/plain": "
", 160 | "image/png": "\n" 161 | }, 162 | "metadata": {}, 163 | "output_type": "display_data" 164 | } 165 | ], 166 | "source": [ 167 | "# some plotting stuff\n", 168 | "fig, ax = plt.subplots()\n", 169 | "control_point_x = torch.linspace(0, 1, N_CONTROL_POINTS)\n", 170 | "decay = 1\n", 171 | "ax.scatter(control_point_x, grid_1d.data, color='white')\n", 172 | "\n", 173 | "# actually optimising!\n", 174 | "for i in range(500):\n", 175 | " # make (noisy) observations of the data we want to model\n", 176 | " x, y = make_observations(N_OBSERVATIONS_PER_ITERATION, add_noise=True)\n", 177 | "\n", 178 | " # what does the model predict for our observations?\n", 179 | " prediction = grid_1d(x).squeeze()\n", 180 | "\n", 181 | " # zero gradients and calculate loss between observations and model prediction\n", 182 | " optimiser.zero_grad()\n", 183 | " loss = torch.sum((prediction - y)**2)**0.5\n", 184 | "\n", 185 | " # backpropagate loss and update values at points on grid\n", 186 | " loss.backward()\n", 187 | " optimiser.step()\n", 188 | "\n", 189 | " # plot\n", 190 | " if i % 10 == 0:\n", 191 | " decay *= 0.99\n", 192 | " ax.scatter(control_point_x, grid_1d.data, color='blue', alpha=1 - decay)\n", 193 | "\n", 194 | " x = torch.linspace(0, 1, 1000)\n", 195 | " y = grid_1d(x).squeeze()\n", 196 | " ax.plot(x, y.detach(), alpha=1 - decay, color='blue')\n", 197 | "\n", 198 | "x = torch.linspace(0, 1, 1000)\n", 199 | "y = torch.sin(x * 2 * torch.pi)\n", 200 | "ax.plot(x, y, ls='--', color='orange')" 201 | ], 202 | "metadata": { 203 | "collapsed": false, 204 | "pycharm": { 205 | "name": "#%%\n" 206 | } 207 | } 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "source": [ 212 | "this model has very little capacity to overfit to noisy data because of the small\n", 213 | "number of control points on our grid (parameters)" 214 | ], 215 | "metadata": { 216 | "collapsed": false, 217 | "pycharm": { 218 | "name": "#%% md\n" 219 | } 220 | } 221 | } 222 | ], 223 | "metadata": { 224 | 
"kernelspec": { 225 | "display_name": "Python 3", 226 | "language": "python", 227 | "name": "python3" 228 | }, 229 | "language_info": { 230 | "codemirror_mode": { 231 | "name": "ipython", 232 | "version": 2 233 | }, 234 | "file_extension": ".py", 235 | "mimetype": "text/x-python", 236 | "name": "python", 237 | "nbconvert_exporter": "python", 238 | "pygments_lexer": "ipython2", 239 | "version": "2.7.6" 240 | } 241 | }, 242 | "nbformat": 4, 243 | "nbformat_minor": 0 244 | } -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # https://peps.python.org/pep-0517/ 2 | [build-system] 3 | requires = ["hatchling", "hatch-vcs"] 4 | build-backend = "hatchling.build" 5 | 6 | # https://peps.python.org/pep-0621/ 7 | [project] 8 | name = "torch-cubic-spline-grids" 9 | description = "Cubic spline interpolation on multidimensional grids in PyTorch" 10 | readme = "README.md" 11 | requires-python = ">=3.9" 12 | license = {text = "BSD 3-Clause License"} 13 | authors = [ 14 | {email = "alisterburt@gmail.com"}, 15 | {name = "Alister Burt"}, 16 | ] 17 | classifiers = [ 18 | "Development Status :: 3 - Alpha", 19 | "License :: OSI Approved :: BSD License", 20 | "Programming Language :: Python :: 3", 21 | "Programming Language :: Python :: 3.9", 22 | "Programming Language :: Python :: 3.10", 23 | "Programming Language :: Python :: 3.11", 24 | "Programming Language :: Python :: 3.12", 25 | "Programming Language :: Python :: 3.13", 26 | ] 27 | dynamic = ["version"] 28 | dependencies = [ 29 | "torch", 30 | "numpy", 31 | "einops", 32 | "typing-extensions", 33 | ] 34 | 35 | # extras 36 | # https://peps.python.org/pep-0621/#dependencies-optional-dependencies 37 | [project.optional-dependencies] 38 | test = ["pytest>=6.0", "pytest-cov"] 39 | dev = [ 40 | "ipython", 41 | "mypy", 42 | "pdbpp", 43 | "pre-commit", 44 | "pytest-cov", 45 | "pytest", 46 | "rich", 47 | "ruff", 
48 | ] 49 | 50 | [project.urls] 51 | homepage = "https://github.com/alisterburt/torch-cubic-spline-grids" 52 | repository = "https://github.com/alisterburt/torch-cubic-spline-grids" 53 | 54 | # same as console_scripts entry point 55 | # [project.scripts] 56 | # spam-cli = "spam:main_cli" 57 | 58 | # Entry points 59 | # https://peps.python.org/pep-0621/#entry-points 60 | # [project.entry-points."spam.magical"] 61 | # tomatoes = "spam:main_tomatoes" 62 | 63 | # https://hatch.pypa.io/latest/config/metadata/ 64 | [tool.hatch.version] 65 | source = "vcs" 66 | 67 | # https://hatch.pypa.io/latest/config/build/#file-selection 68 | # [tool.hatch.build.targets.sdist] 69 | # include = ["/src", "/tests"] 70 | 71 | 72 | # https://github.com/charliermarsh/ruff 73 | [tool.ruff] 74 | line-length = 88 75 | target-version = "py38" 76 | 77 | [tool.ruff.lint] 78 | extend-select = [ 79 | "E", # style errors 80 | "F", # flakes 81 | "D", # pydocstyle 82 | "I001", # isort 83 | "U", # pyupgrade 84 | # "N", # pep8-naming 85 | # "S", # bandit 86 | "C", # flake8-comprehensions 87 | "B", # flake8-bugbear 88 | "A001", # flake8-builtins 89 | "RUF", # ruff-specific rules 90 | ] 91 | extend-ignore = [ 92 | "D100", # Missing docstring in public module 93 | "D107", # Missing docstring in __init__ 94 | "D203", # 1 blank line required before class docstring 95 | "D212", # Multi-line docstring summary should start at the first line 96 | "D213", # Multi-line docstring summary should start at the second line 97 | "D413", # Missing blank line after last section 98 | "D416", # Section name should end with a colon 99 | ] 100 | 101 | [tool.ruff.lint.per-file-ignores] 102 | "tests/*.py" = ["D"] 103 | 104 | [tool.ruff.lint.isort] 105 | combine-as-imports = true 106 | 107 | [tool.ruff.format] 108 | # Prefer single quotes over double quotes. 
109 | quote-style = "single" 110 | 111 | # https://docs.pytest.org/en/6.2.x/customize.html 112 | [tool.pytest.ini_options] 113 | minversion = "6.0" 114 | pythonpath = "src" 115 | testpaths = ["tests"] 116 | filterwarnings = [ 117 | "error", 118 | ] 119 | 120 | # https://mypy.readthedocs.io/en/stable/config_file.html 121 | [tool.mypy] 122 | files = "src/**/" 123 | strict = true 124 | disallow_any_generics = false 125 | disallow_subclassing_any = false 126 | show_error_codes = true 127 | pretty = true 128 | 129 | 130 | # https://coverage.readthedocs.io/en/6.4/config.html 131 | [tool.coverage.report] 132 | exclude_lines = [ 133 | "pragma: no cover", 134 | "if TYPE_CHECKING:", 135 | "@overload", 136 | "except ImportError", 137 | ] 138 | 139 | # https://github.com/mgedmin/check-manifest#configuration 140 | [tool.check-manifest] 141 | ignore = [ 142 | ".github_changelog_generator", 143 | ".pre-commit-config.yaml", 144 | ".ruff_cache/**/*", 145 | "tests/**/*", 146 | "tox.ini", 147 | ] 148 | 149 | # https://python-semantic-release.readthedocs.io/en/latest/configuration.html 150 | [tool.semantic_release] 151 | version_source = "tag_only" 152 | branch = "main" 153 | changelog_sections="feature,fix,breaking,documentation,performance,chore,:boom:,:sparkles:,:children_crossing:,:lipstick:,:iphone:,:egg:,:chart_with_upwards_trend:,:ambulance:,:lock:,:bug:,:zap:,:goal_net:,:alien:,:wheelchair:,:speech_balloon:,:mag:,:apple:,:penguin:,:checkered_flag:,:robot:,:green_apple:,Other" 154 | # commit_parser=semantic_release.history.angular_parser 155 | build_command = "pip install build && python -m build" 156 | -------------------------------------------------------------------------------- /src/torch_cubic_spline_grids/__init__.py: -------------------------------------------------------------------------------- 1 | """Cubic B-spline interpolation on multidimensional grids in PyTorch.""" 2 | 3 | from importlib.metadata import PackageNotFoundError, version 4 | 5 | try: 6 | __version__ = 
version('torch-cubic-spline-grids')  # distribution name, see pyproject.toml 7 | except PackageNotFoundError: 8 | __version__ = 'uninstalled' 9 | 10 | __author__ = 'Alister Burt' 11 | __email__ = 'alisterburt@gmail.com' 12 | 13 | from .b_spline_grids import ( 14 | CubicBSplineGrid1d, 15 | CubicBSplineGrid2d, 16 | CubicBSplineGrid3d, 17 | CubicBSplineGrid4d, 18 | ) 19 | from .catmull_rom_grids import ( 20 | CubicCatmullRomGrid1d, 21 | CubicCatmullRomGrid2d, 22 | CubicCatmullRomGrid3d, 23 | CubicCatmullRomGrid4d, 24 | ) 25 | 26 | __all__ = [ 27 | 'CubicBSplineGrid1d', 28 | 'CubicBSplineGrid2d', 29 | 'CubicBSplineGrid3d', 30 | 'CubicBSplineGrid4d', 31 | 'CubicCatmullRomGrid1d', 32 | 'CubicCatmullRomGrid2d', 33 | 'CubicCatmullRomGrid3d', 34 | 'CubicCatmullRomGrid4d', 35 | ] 36 | -------------------------------------------------------------------------------- /src/torch_cubic_spline_grids/_base_cubic_grid.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Tuple 2 | 3 | import einops 4 | import torch 5 | from typing_extensions import Self 6 | 7 | from torch_cubic_spline_grids.utils import ( 8 | MonotonicityType, 9 | batch, 10 | coerce_to_multichannel_grid, 11 | ) 12 | 13 | 14 | class CubicSplineGrid(torch.nn.Module): 15 | """Base class for continuous parametrisations of multidimensional spaces.""" 16 | 17 | ndim: int 18 | _data: torch.nn.Parameter 19 | _interpolation_function: Callable 20 | _interpolation_matrix: torch.Tensor 21 | _minibatch_size: int 22 | 23 | def __init__( 24 | self, 25 | resolution: Optional[Tuple[int, ...]] = None, 26 | n_channels: int = 1, 27 | minibatch_size: int = 1_000_000, 28 | monotonicity: Optional[MonotonicityType] = None, 29 | ): 30 | super().__init__() 31 | if resolution is None: 32 | resolution = (2,) * self.ndim 33 | grid_shape = (n_channels, *resolution) 34 | self.data = torch.zeros(size=grid_shape) 35 | self._minibatch_size = minibatch_size 36 | self._monotonicity = monotonicity 37 | 
self.register_buffer( 38 | name='interpolation_matrix', 39 | tensor=self._interpolation_matrix, 40 | persistent=False, 41 | ) 42 | 43 | def _interpolate(self, u: torch.Tensor) -> torch.Tensor: 44 | return self._interpolation_function( 45 | self._data, 46 | u, 47 | matrix=self.interpolation_matrix, 48 | monotonicity=self._monotonicity, 49 | ) 50 | 51 | def forward(self, u: torch.Tensor) -> torch.Tensor: 52 | u = self._coerce_to_batched_coordinates(u) # (b, d) 53 | 54 | interpolated = [ 55 | self._interpolate(minibatch_u) 56 | for minibatch_u in batch(u, n=self._minibatch_size) 57 | ] # List[Tensor[(b, d)]] 58 | interpolated = torch.cat(interpolated, dim=0) # (b, d) 59 | return self._unpack_interpolated_output(interpolated) 60 | 61 | @classmethod 62 | def from_grid_data(cls, data: torch.Tensor) -> Self: 63 | """Instantiate a grid from existing grid data. 64 | 65 | Parameters 66 | ---------- 67 | data: torch.Tensor 68 | (c, *grid_dimensions) or (*grid_dimensions) array of multichannel values at 69 | each grid point. 70 | """ 71 | grid = cls() 72 | grid.data = data 73 | return grid 74 | 75 | @property 76 | def data(self) -> torch.Tensor: 77 | return self._data.detach() 78 | 79 | @data.setter 80 | def data(self, grid_data: torch.Tensor) -> None: 81 | grid_data = coerce_to_multichannel_grid(grid_data, grid_ndim=self.ndim) 82 | self._data = torch.nn.Parameter(grid_data) 83 | 84 | @property 85 | def n_channels(self) -> int: 86 | return int(self._data.size(0)) 87 | 88 | @property 89 | def resolution(self) -> Tuple[int, ...]: 90 | return tuple(self._data.shape[1:]) 91 | 92 | def _coerce_to_batched_coordinates(self, u: torch.Tensor) -> torch.Tensor: 93 | u = torch.atleast_1d(torch.as_tensor(u, dtype=torch.float32)) 94 | self._input_is_coordinate_like = u.shape[-1] == self.ndim 95 | if self._input_is_coordinate_like is False and self.ndim == 1: 96 | u = einops.rearrange(u, '... -> ... 
1') # add singleton coord dimension 97 | else: 98 | u = torch.atleast_2d(u) # add batch dimension if missing 99 | u, self._packed_shapes = einops.pack([u], pattern='* coords') 100 | if u.shape[-1] != self.ndim: 101 | ndim = u.shape[-1] 102 | raise ValueError( 103 | f'Cannot interpolate on a {self.ndim}D grid with {ndim}D coordinates' 104 | ) 105 | return u 106 | 107 | def _unpack_interpolated_output(self, interpolated: torch.Tensor) -> torch.Tensor: 108 | [interpolated] = einops.unpack( 109 | interpolated, packed_shapes=self._packed_shapes, pattern='* coords' 110 | ) 111 | if self._input_is_coordinate_like is False and self.ndim == 1: 112 | interpolated = einops.rearrange(interpolated, '... 1 -> ...') 113 | return interpolated 114 | -------------------------------------------------------------------------------- /src/torch_cubic_spline_grids/_constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # Described in Freya Holmer's video "The Continuity of Splines" 4 | # https://youtu.be/jvPPXbo87ds?t=3462 5 | 6 | CUBIC_B_SPLINE_MATRIX = (1 / 6) * torch.tensor([[1, 4, 1, 0], 7 | [-3, 0, 3, 0], 8 | [3, -6, 3, 0], 9 | [-1, 3, -3, 1]]) 10 | 11 | CUBIC_CATMULL_ROM_MATRIX = (1 / 2) * torch.tensor([[0, 2, 0, 0], 12 | [-1, 0, 1, 0], 13 | [2, -5, 4, -1], 14 | [-1, 3, -3, 1]]) 15 | -------------------------------------------------------------------------------- /src/torch_cubic_spline_grids/b_spline_grids.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Sequence, Tuple, Union 2 | 3 | import torch 4 | 5 | from torch_cubic_spline_grids._base_cubic_grid import CubicSplineGrid 6 | from torch_cubic_spline_grids._constants import CUBIC_B_SPLINE_MATRIX 7 | from torch_cubic_spline_grids.interpolate_grids import ( 8 | interpolate_grid_1d as _interpolate_grid_1d, 9 | interpolate_grid_2d as _interpolate_grid_2d, 10 | interpolate_grid_3d as 
_interpolate_grid_3d, 11 | interpolate_grid_4d as _interpolate_grid_4d, 12 | ) 13 | from torch_cubic_spline_grids.utils import MonotonicityType 14 | 15 | CoordinateLike = Union[float, Sequence[float], torch.Tensor] 16 | 17 | 18 | class _CubicBSplineGrid(CubicSplineGrid): 19 | _interpolation_matrix = CUBIC_B_SPLINE_MATRIX 20 | 21 | 22 | class CubicBSplineGrid1d(_CubicBSplineGrid): 23 | """Continuous parametrisation of a 1D space with a specific resolution.""" 24 | 25 | ndim: int = 1 26 | _interpolation_function: Callable = staticmethod(_interpolate_grid_1d) 27 | 28 | def __init__( 29 | self, 30 | resolution: Optional[Union[int, Tuple[int]]] = None, 31 | n_channels: int = 1, 32 | minibatch_size: int = 1_000_000, 33 | monotonicity: Optional[MonotonicityType] = None, 34 | ): 35 | if isinstance(resolution, int): 36 | resolution = (resolution,) 37 | super().__init__( 38 | resolution=resolution, 39 | n_channels=n_channels, 40 | minibatch_size=minibatch_size, 41 | monotonicity=monotonicity, 42 | ) 43 | 44 | 45 | class CubicBSplineGrid2d(_CubicBSplineGrid): 46 | """Continuous parametrisation of a 2D space with a specific resolution.""" 47 | 48 | ndim: int = 2 49 | _interpolation_function: Callable = staticmethod(_interpolate_grid_2d) 50 | 51 | 52 | class CubicBSplineGrid3d(_CubicBSplineGrid): 53 | """Continuous parametrisation of a 3D space with a specific resolution.""" 54 | 55 | ndim: int = 3 56 | _interpolation_function: Callable = staticmethod(_interpolate_grid_3d) 57 | 58 | 59 | class CubicBSplineGrid4d(_CubicBSplineGrid): 60 | """Continuous parametrisation of a 4D space with a specific resolution.""" 61 | 62 | ndim: int = 4 63 | _interpolation_function: Callable = staticmethod(_interpolate_grid_4d) 64 | -------------------------------------------------------------------------------- /src/torch_cubic_spline_grids/catmull_rom_grids.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Sequence, 
Tuple, Union 2 | 3 | import torch 4 | 5 | from torch_cubic_spline_grids._base_cubic_grid import CubicSplineGrid 6 | from torch_cubic_spline_grids._constants import CUBIC_CATMULL_ROM_MATRIX 7 | from torch_cubic_spline_grids.interpolate_grids import ( 8 | interpolate_grid_1d as _interpolate_grid_1d, 9 | interpolate_grid_2d as _interpolate_grid_2d, 10 | interpolate_grid_3d as _interpolate_grid_3d, 11 | interpolate_grid_4d as _interpolate_grid_4d, 12 | ) 13 | from torch_cubic_spline_grids.utils import MonotonicityType 14 | 15 | CoordinateLike = Union[float, Sequence[float], torch.Tensor] 16 | 17 | 18 | class _CubicCatmullRomGrid(CubicSplineGrid): 19 | _interpolation_matrix = CUBIC_CATMULL_ROM_MATRIX 20 | 21 | 22 | class CubicCatmullRomGrid1d(_CubicCatmullRomGrid): 23 | """Continuous parametrisation of a 1D space with a specific resolution.""" 24 | 25 | ndim: int = 1 26 | _interpolation_function: Callable = staticmethod(_interpolate_grid_1d) 27 | 28 | def __init__( 29 | self, 30 | resolution: Optional[Union[int, Tuple[int]]] = None, 31 | n_channels: int = 1, 32 | minibatch_size: int = 1_000_000, 33 | monotonicity: Optional[MonotonicityType] = None, 34 | ): 35 | if isinstance(resolution, int): 36 | resolution = (resolution,) 37 | super().__init__( 38 | resolution=resolution, 39 | n_channels=n_channels, 40 | minibatch_size=minibatch_size, 41 | monotonicity=monotonicity, 42 | ) 43 | 44 | 45 | class CubicCatmullRomGrid2d(_CubicCatmullRomGrid): 46 | """Continuous parametrisation of a 2D space with a specific resolution.""" 47 | 48 | ndim: int = 2 49 | _interpolation_function: Callable = staticmethod(_interpolate_grid_2d) 50 | 51 | 52 | class CubicCatmullRomGrid3d(_CubicCatmullRomGrid): 53 | """Continuous parametrisation of a 3D space with a specific resolution.""" 54 | 55 | ndim: int = 3 56 | _interpolation_function: Callable = staticmethod(_interpolate_grid_3d) 57 | 58 | 59 | class CubicCatmullRomGrid4d(_CubicCatmullRomGrid): 60 | """Continuous parametrisation of a 4D space 
with a specific resolution.""" 61 | 62 | ndim: int = 4 63 | _interpolation_function: Callable = staticmethod(_interpolate_grid_4d) 64 | -------------------------------------------------------------------------------- /src/torch_cubic_spline_grids/interpolate_grids.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import einops 4 | import torch 5 | 6 | from torch_cubic_spline_grids.interpolate_pieces import ( 7 | interpolate_pieces_1d, 8 | interpolate_pieces_2d, 9 | interpolate_pieces_3d, 10 | interpolate_pieces_4d, 11 | ) 12 | from torch_cubic_spline_grids.pad_grids import ( 13 | pad_grid_1d, 14 | pad_grid_2d, 15 | pad_grid_3d, 16 | pad_grid_4d, 17 | ) 18 | from torch_cubic_spline_grids.utils import ( 19 | MonotonicityType, 20 | interpolants_to_interpolation_data_1d, 21 | transform_to_monotonic_nd, 22 | ) 23 | 24 | 25 | def interpolate_grid_1d( 26 | grid: torch.Tensor, 27 | u: torch.Tensor, 28 | matrix: torch.Tensor, 29 | monotonicity: Optional[MonotonicityType] = None, 30 | ) -> torch.Tensor: 31 | """Uniform cubic spline interpolation on a 1D grid. 32 | 33 | The range [0, 1] covers all data points in the 1D grid. 34 | 35 | Parameters 36 | ---------- 37 | grid: torch.Tensor 38 | `(c, w)` array of `w` values in `c` channels to be interpolated. 39 | u: torch.Tensor 40 | `(b, 1)` array of query points in the range `[0, 1]` covering the `w` 41 | dimension of `grid`. 42 | matrix: torch.Tensor 43 | `(4, 4)` characteristic matrix for the spline. 44 | monotonicity: str 45 | when either 'increasing' or 'decreasing' is specified, ensures 46 | that control points of spline are monotonic. 47 | 48 | Returns 49 | ------- 50 | interpolated: torch.Tensor 51 | `(b, c)` array of interpolated values in each channel. 
52 | """ 53 | if grid.ndim == 1: 54 | grid = einops.rearrange(grid, 'w -> 1 w') 55 | _, w = grid.shape 56 | 57 | # handle interpolation at edges by extending grid of control points according to 58 | # local gradients 59 | grid = pad_grid_1d(grid) 60 | 61 | # find control point indices and interpolation coordinate 62 | idx, t = interpolants_to_interpolation_data_1d(u[:, 0], n_samples=w) 63 | if monotonicity: 64 | grid = transform_to_monotonic_nd(grid, ndims=1, monotonicity=monotonicity) 65 | control_points = grid[..., idx] # (c, b, 4) 66 | control_points = einops.rearrange(control_points, 'c b p -> b c p') 67 | 68 | # interpolate 69 | return interpolate_pieces_1d(control_points, t, matrix=matrix) 70 | 71 | 72 | def interpolate_grid_2d( 73 | grid: torch.Tensor, 74 | u: torch.Tensor, 75 | matrix: torch.Tensor, 76 | monotonicity: Optional[MonotonicityType] = None, 77 | ) -> torch.Tensor: 78 | """Uniform cubic B-spline interpolation on a 2D grid. 79 | 80 | Parameters 81 | ---------- 82 | grid: torch.Tensor 83 | `(c, h, w)` multichannel 2D grid. 84 | u: torch.Tensor 85 | `(b, 2)` array of values in the range `[0, 1]`. 86 | `[0, 1]` in `u[:, 0]` covers dim -2 (h) of `grid` 87 | `[0, 1]` in `u[:, 1]` covers dim -1 (w) of `grid` 88 | matrix: torch.Tensor 89 | `(4, 4)` characteristic matrix for the spline. 90 | monotonicity: str 91 | when either 'increasing' or 'decreasing' is specified, ensures 92 | that control points of spline are monotonic. 93 | 94 | Returns 95 | ------- 96 | `(b, c)` array of interpolated values in each channel. 97 | """ 98 | if grid.ndim == 2: 99 | grid = einops.rearrange(grid, 'h w -> 1 h w') 100 | _, h, w = grid.shape 101 | 102 | # pad grid to handle interpolation at edges. 
103 | grid = pad_grid_2d(grid) 104 | 105 | # find control point indices and interpolation coordinate in each dim 106 | idx_h, t_h = interpolants_to_interpolation_data_1d(u[:, 0], n_samples=h) 107 | idx_w, t_w = interpolants_to_interpolation_data_1d(u[:, 1], n_samples=w) 108 | 109 | # construct (4, 4) grids of control points and 2D interpolant then interpolate 110 | idx_h = einops.repeat(idx_h, 'b h -> b h w', w=4) 111 | idx_w = einops.repeat(idx_w, 'b w -> b h w', h=4) 112 | if monotonicity: 113 | grid = transform_to_monotonic_nd(grid, ndims=2, monotonicity=monotonicity) 114 | control_points = grid[..., idx_h, idx_w] # (c, b, 4, 4) 115 | control_points = einops.rearrange(control_points, 'c b h w -> b c h w') 116 | 117 | t = einops.rearrange([t_h, t_w], 'hw b -> b hw') 118 | return interpolate_pieces_2d(control_points, t, matrix=matrix) 119 | 120 | 121 | def interpolate_grid_3d( 122 | grid: torch.Tensor, 123 | u: torch.Tensor, 124 | matrix: torch.Tensor, 125 | monotonicity: Optional[MonotonicityType] = None, 126 | ) -> torch.Tensor: 127 | """Uniform cubic B-spline interpolation on a 3D grid. 128 | 129 | Parameters 130 | ---------- 131 | grid: torch.Tensor 132 | `(c, d, h, w)` multichannel 3D grid. 133 | u: torch.Tensor 134 | `(b, 3)` array of values in the range [0, 1]. 135 | [0, 1] in b[:, 0] covers depth dim `d` of `grid` 136 | [0, 1] in b[:, 1] covers height dim `h` of `grid` 137 | [0, 1] in b[:, 2] covers width dim `w` of `grid` 138 | matrix: torch.Tensor 139 | `(4, 4)` characteristic matrix for the spline. 140 | monotonicity: str 141 | when either 'increasing' or 'decreasing' is specified, ensures 142 | that control points of spline are monotonic. 
143 | 144 | Returns 145 | ------- 146 | `(b, c)` array of c-dimensional interpolated values 147 | """ 148 | if grid.ndim == 3: 149 | grid = einops.rearrange(grid, 'd h w -> 1 d h w') 150 | _, n_samples_d, n_samples_h, n_samples_w = grid.shape 151 | 152 | # expand grid to handle interpolation at edges 153 | grid = pad_grid_3d(grid) 154 | 155 | # find control point indices and interpolation coordinate in each dim 156 | idx_d, t_d = interpolants_to_interpolation_data_1d(u[:, 0], n_samples_d) 157 | idx_h, t_h = interpolants_to_interpolation_data_1d(u[:, 1], n_samples_h) 158 | idx_w, t_w = interpolants_to_interpolation_data_1d(u[:, 2], n_samples_w) 159 | 160 | # construct (4, 4, 4) grids of control points and 3D interpolant then interpolate 161 | idx_d = einops.repeat(idx_d, 'b d -> b d h w', h=4, w=4) 162 | idx_h = einops.repeat(idx_h, 'b h -> b d h w', d=4, w=4) 163 | idx_w = einops.repeat(idx_w, 'b w -> b d h w', d=4, h=4) 164 | if monotonicity: 165 | grid = transform_to_monotonic_nd(grid, ndims=3, monotonicity=monotonicity) 166 | control_points = grid[:, idx_d, idx_h, idx_w] # (c, b, 4, 4, 4) 167 | control_points = einops.rearrange(control_points, 'c b d h w -> b c d h w') 168 | 169 | t = einops.rearrange([t_d, t_h, t_w], 'dhw b -> b dhw') 170 | return interpolate_pieces_3d(control_points, t, matrix=matrix) 171 | 172 | 173 | def interpolate_grid_4d( 174 | grid: torch.Tensor, 175 | u: torch.Tensor, 176 | matrix: torch.Tensor, 177 | monotonicity: Optional[MonotonicityType] = None, 178 | ) -> torch.Tensor: 179 | """Uniform cubic B-spline interpolation on a 4D grid. 180 | 181 | Parameters 182 | ---------- 183 | grid: torch.Tensor 184 | `(c, u, d, h, w)` multichannel 4D grid. 185 | u: torch.Tensor 186 | `(b, 4)` array of values in the range [0, 1]. 
187 | [0, 1] in b[:, 0] covers time dim `u` of `grid` 188 | [0, 1] in b[:, 1] covers depth dim `d` of `grid` 189 | [0, 1] in b[:, 2] covers height dim `h` of `grid` 190 | [0, 1] in b[:, 3] covers width dim `w` of `grid` 191 | matrix: torch.Tensor 192 | `(4, 4)` characteristic matrix for the spline. 193 | monotonicity: str 194 | when either 'increasing' or 'decreasing' is specified, ensures 195 | that control points of spline are monotonic. 196 | 197 | Returns 198 | ------- 199 | `(b, c)` array of c-dimensional interpolated values 200 | """ 201 | if grid.ndim == 4: 202 | grid = einops.rearrange(grid, 't d h w -> 1 t d h w') 203 | _, t, d, h, w = grid.shape 204 | 205 | # expand grid to handle interpolation at edges 206 | grid = pad_grid_4d(grid) 207 | 208 | # find control point indices and interpolation coordinate in each dim 209 | idx_t, t_t = interpolants_to_interpolation_data_1d(u[:, 0], n_samples=t) 210 | idx_d, t_d = interpolants_to_interpolation_data_1d(u[:, 1], n_samples=d) 211 | idx_h, t_h = interpolants_to_interpolation_data_1d(u[:, 2], n_samples=h) 212 | idx_w, t_w = interpolants_to_interpolation_data_1d(u[:, 3], n_samples=w) 213 | 214 | # construct (4, 4, 4, 4) grids of control points and 4D interpolant then interpolate 215 | idx_t = einops.repeat(idx_t, 'b t -> b t d h w', d=4, h=4, w=4) 216 | idx_d = einops.repeat(idx_d, 'b d -> b t d h w', t=4, h=4, w=4) 217 | idx_h = einops.repeat(idx_h, 'b h -> b t d h w', t=4, d=4, w=4) 218 | idx_w = einops.repeat(idx_w, 'b w -> b t d h w', t=4, d=4, h=4) 219 | if monotonicity: 220 | grid = transform_to_monotonic_nd(grid, ndims=3, monotonicity=monotonicity) 221 | control_points = grid[:, idx_t, idx_d, idx_h, idx_w] # (c, b, 4, 4, 4, 4) 222 | control_points = einops.rearrange(control_points, 'c b t d h w -> b c t d h w') 223 | 224 | t = einops.rearrange([t_t, t_d, t_h, t_w], 'tdhw b -> b tdhw') 225 | return interpolate_pieces_4d(control_points, t, matrix=matrix) 226 | 
# --------------------------------------------------------------------------------
# /src/torch_cubic_spline_grids/interpolate_pieces.py:
# --------------------------------------------------------------------------------
"""Interpolate 'pieces' for piecewise uniform cubic B-spline interpolation."""

import einops
import torch


def interpolate_pieces_1d(
    control_points: torch.Tensor, t: torch.Tensor, matrix: torch.Tensor
) -> torch.Tensor:
    """Batched uniform 1D cubic spline interpolation.

    ```
    [1, t, t^2, t^3] * [a00, a01, a02, a03] * [p0]
                       [a10, a11, a12, a13]   [p1]
                       [a20, a21, a22, a23]   [p2]
                       [a30, a31, a32, a33]   [p3]
    ```
    c.f. Freya Holmer - "The Continuity of Splines": https://youtu.be/jvPPXbo87ds?t=3462

    Parameters
    ----------
    control_points: torch.Tensor
        `(b, c, 4)` batch of 4 uniformly spaced control points `[p0, p1, p2, p3]`
        in `c` channels.
    t: torch.Tensor
        `(b, )` batch of interpolants in the range [0, 1] covering the interpolation
        interval between `p1` and `p2`
    matrix: torch.Tensor
        `(4, 4)` characteristic matrix for the spline.

    Returns
    -------
    interpolated: torch.Tensor
        `(b, c)` array of per-channel interpolants of `control_points` at `t`.
    """
    # (b, 1, 1, 4) row vectors of powers of t, broadcast over channels
    t = einops.rearrange([t**0, t, t**2, t**3], 'u b -> b 1 1 u')
    # (b, c, 4, 1) column vectors of control points
    control_points = einops.rearrange(control_points, 'b c p -> b c p 1')
    interpolated = t @ matrix @ control_points
    return einops.rearrange(interpolated, 'b c 1 1 -> b c')


def interpolate_pieces_2d(
    control_points: torch.Tensor, t: torch.Tensor, matrix: torch.Tensor
) -> torch.Tensor:
    """Batched uniform 2D cubic spline interpolation.

    Parameters
    ----------
    control_points: torch.Tensor
        `(b, c, 4, 4)` batch of 2D multichannel grids of uniformly spaced control
        points `[p0, p1, p2, p3]` for cubic spline interpolation.
    t: torch.Tensor
        `(b, 2)` batch of values in the range `[0, 1]` defining the position of 2D
        points to be interpolated within the interval `[p1, p2]` along dims -2
        and -1 of the 2D grid of control points.
    matrix: torch.Tensor
        `(4, 4)` characteristic matrix for the spline.

    Returns
    -------
    interpolated:
        `(b, c)` batch of interpolated values in each channel.
    """
    # extract (b, c, 4) control points at each height along width dim of (h, w) grid
    h0, h1, h2, h3 = einops.rearrange(control_points, 'b c h w -> h b c w')

    # separate t into components along height and width dimensions
    t_h, t_w = einops.rearrange(t, 'b hw -> hw b')

    # 1d interpolation along width dim at each height
    p0 = interpolate_pieces_1d(control_points=h0, t=t_w, matrix=matrix)
    p1 = interpolate_pieces_1d(control_points=h1, t=t_w, matrix=matrix)
    p2 = interpolate_pieces_1d(control_points=h2, t=t_w, matrix=matrix)
    p3 = interpolate_pieces_1d(control_points=h3, t=t_w, matrix=matrix)

    # 1d interpolation of result along height dim
    control_points = einops.rearrange([p0, p1, p2, p3], 'p b c -> b c p')
    return interpolate_pieces_1d(control_points=control_points, t=t_h, matrix=matrix)


def interpolate_pieces_3d(
    control_points: torch.Tensor, t: torch.Tensor, matrix: torch.Tensor
) -> torch.Tensor:
    """Batched uniform 3D cubic spline interpolation.

    Parameters
    ----------
    control_points: torch.Tensor
        `(b, c, 4, 4, 4)` batch of `(4, 4, 4)` multichannel grids of uniformly
        spaced control points for cubic spline interpolation.
    t: torch.Tensor
        `(b, 3)` batch of values in the range `[0, 1]` defining the position of 3D
        points to be interpolated within the interval `[p1, p2]` along dim -3,
        -2 and -1 of `control_points`
    matrix: torch.Tensor
        `(4, 4)` characteristic matrix for the spline.

    Returns
    -------
    interpolated:
        `(b, c)` batch of interpolated values in each channel.
    """
    # extract (b, c, 4, 4) 2D control point planes at each point along the depth dim
    d0, d1, d2, d3 = einops.rearrange(control_points, 'b c d h w -> d b c h w')

    # separate t into components along depth and (height, width) dimensions
    t_d = t[:, 0]
    t_hw = t[:, [1, 2]]

    # 2d interpolation on each (height, width) plane at each depth
    p0 = interpolate_pieces_2d(control_points=d0, t=t_hw, matrix=matrix)
    p1 = interpolate_pieces_2d(control_points=d1, t=t_hw, matrix=matrix)
    p2 = interpolate_pieces_2d(control_points=d2, t=t_hw, matrix=matrix)
    p3 = interpolate_pieces_2d(control_points=d3, t=t_hw, matrix=matrix)

    # 1d interpolation of result along depth dim
    control_points = einops.rearrange([p0, p1, p2, p3], 'p b c -> b c p')
    return interpolate_pieces_1d(control_points=control_points, t=t_d, matrix=matrix)


def interpolate_pieces_4d(
    control_points: torch.Tensor, t: torch.Tensor, matrix: torch.Tensor
) -> torch.Tensor:
    """Batched 4D cubic spline interpolation.

    Parameters
    ----------
    control_points: torch.Tensor
        `(b, c, 4, 4, 4, 4)` batch of multichannel `(4, 4, 4, 4)` grids of uniformly
        spaced control points for cubic spline interpolation.
    t: torch.Tensor
        `(b, 4)` batch of values in the range `[0, 1]` defining the position of 4D
        points to be interpolated within the interval `[p1, p2]` along dims -4, -3,
        -2 and -1 of `control_points`.
    matrix: torch.Tensor
        `(4, 4)` characteristic matrix for the spline.

    Returns
    -------
    interpolated:
        `(b, c)` batch of interpolated values in each channel.
    """
    # extract (b, c, 4, 4, 4) 3D control point grids at each point along the time dim
    t0, t1, t2, t3 = einops.rearrange(control_points, 'b c u d h w -> u b c d h w')

    # separate t into components along time and (depth, height, width) dimensions
    t_t = t[:, 0]
    t_dhw = t[:, [1, 2, 3]]

    # 3D interpolation on each 3D grid along time dimension
    p0 = interpolate_pieces_3d(control_points=t0, t=t_dhw, matrix=matrix)
    p1 = interpolate_pieces_3d(control_points=t1, t=t_dhw, matrix=matrix)
    p2 = interpolate_pieces_3d(control_points=t2, t=t_dhw, matrix=matrix)
    p3 = interpolate_pieces_3d(control_points=t3, t=t_dhw, matrix=matrix)

    # 1d interpolation of result along time dim
    control_points = einops.rearrange([p0, p1, p2, p3], 'p b c -> b c p')
    return interpolate_pieces_1d(control_points=control_points, t=t_t, matrix=matrix)

# --------------------------------------------------------------------------------
# /src/torch_cubic_spline_grids/pad_grids.py:
# --------------------------------------------------------------------------------
import einops
import torch


def pad_grid_1d(grid: torch.Tensor) -> torch.Tensor:
    """Pad in the last dimension according to local gradients.

    e.g. [0, 1, 2] -> [-1, 0, 1, 2, 3]

    A width of 1 has no gradient to extrapolate, so the single value is
    duplicated first and the padded result then holds four copies of it.

    grid: torch.Tensor
        `(..., w)` array of values to be padded in last dimension.

    Returns
    -------
    padded_grid: torch.Tensor
        `(..., w+2)` padded array (`(..., 4)` when `w == 1`).
    """
    # duplicate a singleton dimension so that a gradient can be computed
    w = grid.shape[-1]
    if w == 1:
        grid = einops.repeat(grid, '... w -> ... (repeat w)', repeat=2)

    # find values for padding at each end of width dim
    start = grid[..., 0] - (grid[..., 1] - grid[..., 0])
    end = grid[..., -1] + (grid[..., -1] - grid[..., -2])

    # reintroduce width dim lost during indexing
    start = einops.rearrange(start, '... -> ... 1')
    end = einops.rearrange(end, '... -> ... 1')
    return torch.cat([start, grid, end], dim=-1)


def pad_grid_2d(grid: torch.Tensor) -> torch.Tensor:
    """Pad a 2D grid of values according to local gradients.

    ```
    e.g.           [[-3, -2, -1, 0]
    [[0, 1]         [-1,  0,  1, 2]
     [2, 3]]   ->   [ 1,  2,  3, 4]
                    [ 3,  4,  5, 6]]
    ```

    Parameters
    ----------
    grid: torch.Tensor
        `(..., h, w)` array of values to be padded in height and width dimensions.

    Returns
    -------
    padded_grid: torch.Tensor
        `(..., h+2, w+2)` padded array.
    """
    # duplicate a singleton height dimension so that a gradient can be computed
    h = grid.shape[-2]
    if h == 1:
        grid = einops.repeat(grid, '... h w -> ... (repeat h) w', repeat=2)
    grid = pad_grid_1d(grid)  # pad width dim (..., h, w+2)

    # find values for padding at each end of height dim
    h_start = grid[..., 0, :] - (grid[..., 1, :] - grid[..., 0, :])
    h_end = grid[..., -1, :] + (grid[..., -1, :] - grid[..., -2, :])

    # reintroduce height dim lost through indexing
    h_start = einops.rearrange(h_start, '... w -> ... 1 w')
    h_end = einops.rearrange(h_end, '... w -> ... 1 w')

    # pad height dim
    return torch.cat([h_start, grid, h_end], dim=-2)


def pad_grid_3d(grid: torch.Tensor) -> torch.Tensor:
    """Pad a 3D grid of values according to local gradients.

    Parameters
    ----------
    grid: torch.Tensor
        `(..., d, h, w)` array of values to be padded in depth, height and width
        dimensions.

    Returns
    -------
    padded_grid: torch.Tensor
        `(..., d+2, h+2, w+2)` padded array.
    """
    # duplicate a singleton depth dimension so that a gradient can be computed
    d = grid.shape[-3]
    if d == 1:
        grid = einops.repeat(grid, '... d h w -> ... (repeat d) h w', repeat=2)

    # pad in height and width dims
    grid = pad_grid_2d(grid)

    # find values for padding at each end of depth dim
    d_start = grid[..., 0, :, :] - (grid[..., 1, :, :] - grid[..., 0, :, :])
    d_end = grid[..., -1, :, :] + (grid[..., -1, :, :] - grid[..., -2, :, :])

    # reintroduce depth dim dropped by indexing
    d_start = einops.rearrange(d_start, '... h w -> ... 1 h w')
    d_end = einops.rearrange(d_end, '... h w -> ... 1 h w')
    return torch.cat([d_start, grid, d_end], dim=-3)


def pad_grid_4d(grid: torch.Tensor) -> torch.Tensor:
    """Pad a 4D grid of values according to local gradients.

    Parameters
    ----------
    grid: torch.Tensor
        `(..., u, d, h, w)` array of values to be padded in time, depth, height and
        width dimensions.

    Returns
    -------
    padded_grid: torch.Tensor
        `(..., u+2, d+2, h+2, w+2)` grid
    """
    # duplicate a singleton time dimension so that a gradient can be computed
    t = grid.shape[-4]
    if t == 1:
        grid = einops.repeat(grid, '... u d h w -> ... (repeat u) d h w', repeat=2)

    # pad in depth, height and width dims
    grid = pad_grid_3d(grid)  # (..., u, d+2, h+2, w+2)

    # find values for padding at each end of time dim
    dt_start = grid[..., 1, :, :, :] - grid[..., 0, :, :, :]
    t_start = grid[..., 0, :, :, :] - dt_start
    dt_end = grid[..., -1, :, :, :] - grid[..., -2, :, :, :]
    t_end = grid[..., -1, :, :, :] + dt_end

    # reintroduce time dim dropped by indexing
    t_start = einops.rearrange(t_start, '... d h w -> ... 1 d h w')
    t_end = einops.rearrange(t_end, '... d h w -> ... 1 d h w')
    return torch.cat([t_start, grid, t_end], dim=-4)

# --------------------------------------------------------------------------------
# /src/torch_cubic_spline_grids/utils.py:
# --------------------------------------------------------------------------------
from collections.abc import Sequence
from typing import Callable, Iterable, Literal, Tuple

import einops
import torch


def generate_sample_positions_for_padded_grid_1d(
    n_samples: int, device: torch.device
) -> torch.Tensor:
    """Generate a 1D vector of sample coordinates for a padded grid.

    Coordinate system is [0, 1] covering each dimension, pre-padding.
    e.g. for 4 samples (6 positions on the padded grid)
    `[-0.333, 0, 0.333, 0.666, 1, 1.333]`


    Parameters
    ----------
    n_samples: int
        The number of samples on the grid prior to padding. Must be greater
        than 1 because the sample spacing is `1 / (n_samples - 1)`.
    device: torch.device
        The torch device on which to store the tensor.

    Returns
    -------
    sample_coordinates: torch.Tensor
        The coordinates in the [0, 1] coordinate system of each sample on the
        padded grid.
    """
    du = 1 / (n_samples - 1)
    sample_coordinates = torch.linspace(-du, 1 + du, steps=n_samples + 2, device=device)

    # fix for numerical stability issues around 0 and 1
    # ensures valid control point indices are selected
    epsilon = 1e-6
    sample_coordinates[1] = 0 - epsilon
    sample_coordinates[-2] = 1 + epsilon
    return sample_coordinates


def find_control_point_idx_1d(
    sample_positions: torch.Tensor, query_points: torch.Tensor
) -> torch.Tensor:
    """Find indices of four control points required for cubic interpolation.

    E.g. for sample positions `[0, 1, 2, 3, 4, 5]` and query point `2.5` the control
    point indices would be `[1, 2, 3, 4]` as `2.5` lies between `2` and `3`

    Parameters
    ----------
    sample_positions: torch.Tensor
        Monotonically increasing 1D array of sample positions.
    query_points: torch.Tensor
        `(b, )` array of query points for which control point indices are
        required.

    Returns
    -------
    control_point_idx: torch.Tensor
        `(b, 4)` array of indices for control points.
    """
    # find index of upper bound of interval for each query point
    sample_positions = sample_positions.contiguous()
    query_points = query_points.contiguous()
    iub_idx = torch.searchsorted(sample_positions, query_points, side='right')

    # generate (b, 4) array of indices of control points [s0, s1, s2, s3]
    # required for cubic interpolation
    s0_idx = iub_idx - 2
    s1_idx = iub_idx - 1
    s2_idx = iub_idx
    s3_idx = iub_idx + 1
    return einops.rearrange([s0_idx, s1_idx, s2_idx, s3_idx], 's b -> b s')


def interpolants_to_interpolation_data_1d(
    interpolants: torch.Tensor, n_samples: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Find the necessary data for piecewise cubic interpolation on a padded grid.

    Two pieces of data are required for piecewise cubic interpolation
    - four control points `[p0, p1, p2, p3]`
    - the interpolation coordinate

    The interpolation coordinate is a value in the range [0, 1] telling us how far into
    the interval `[p1, p2]` a query point is.

    This function returns the indices of the control points and the interpolation
    coordinate for a 1D grid. Returning the indices rather than the control points
    makes this more flexible for use in multidimensional grid interpolation which
    requires reusing and combining control point indices across dimensions.

    Parameters
    ----------
    interpolants: torch.Tensor
        `(b, )` batch of values in range [0, 1] covering the dimension being
        interpolated. Values outside the range are clamped to it.
    n_samples: int
        The number of samples on the grid being interpolated (prior to padding).

    Returns
    -------
    control_point_idx, interpolation_coordinate: Tuple[torch.Tensor, torch.Tensor]
        The indices of control points `[p0, p1, p2, p3]` on a padded 1D grid and the
        interpolation coordinate associated with the interval `[p1, p2]`.
    """
    interpolants = torch.clamp(interpolants, min=0, max=1)
    device = interpolants.device
    if n_samples > 1:
        grid_u = generate_sample_positions_for_padded_grid_1d(n_samples, device=device)
        control_point_idx = find_control_point_idx_1d(
            sample_positions=grid_u, query_points=interpolants
        )
        u_p1 = grid_u[control_point_idx[:, 1]]
        du = 1 / (n_samples - 1)
        interpolation_coordinate = (interpolants - u_p1) / du
    else:
        # a single sample is padded to exactly four identical control points,
        # so every query uses indices [0, 1, 2, 3]; the indices are created on
        # the same device as the interpolants, like in the branch above
        control_point_idx = einops.repeat(
            torch.tensor([0, 1, 2, 3], device=device), 'p -> b p', b=len(interpolants)
        )
        interpolation_coordinate = einops.repeat(
            torch.tensor([0.5], device=device), '1 -> b', b=len(interpolants)
        )
    return control_point_idx, interpolation_coordinate


def coerce_to_multichannel_grid(grid: torch.Tensor, grid_ndim: int) -> torch.Tensor:
    """If missing, add a channel dimension to a multidimensional grid.

    e.g. for a 2D (h, w) grid
    `h w -> 1 h w`
    `c h w -> c h w`

    Raises
    ------
    ValueError
        If `grid` has neither `grid_ndim` nor `grid_ndim + 1` dimensions.
    """
    grid_is_multichannel = grid.ndim == grid_ndim + 1
    grid_is_single_channel = grid.ndim == grid_ndim
    if grid_is_single_channel is False and grid_is_multichannel is False:
        raise ValueError(f'expected a {grid_ndim}D grid, got {grid.ndim}')
    if grid_is_single_channel:
        grid = einops.rearrange(grid, '... -> 1 ...')
    return grid


MonotonicityType = Literal['increasing', 'decreasing']


def transform_to_monotonic_nd(
    tensor: torch.Tensor, ndims: int, monotonicity: MonotonicityType
) -> torch.Tensor:
    """Transform tensor values, so they are monotonic across dimensions.

    Parameters
    ----------
    tensor: torch.Tensor
        a tensor of the arbitrary shape.
    ndims: int
        the number of the dimensions counting from the last to the first, for which
        elements should be monotonic.
    monotonicity: str
        Either 'decreasing' or 'increasing'. An empty string leaves `tensor`
        unchanged.

    Returns
    -------
    tensor: torch.Tensor
        a tensor, with elements monotonic for the last `ndims` dimensions.

    Raises
    ------
    ValueError
        If `monotonicity` is not one of the supported values.
    """
    monotonicity_function: Callable

    # the empty string means "no constraint"; returning early avoids reaching
    # the loop below without a cumulative function selected
    if monotonicity == '':
        return tensor

    if monotonicity == 'increasing':
        monotonicity_function = torch.cummax
    elif monotonicity == 'decreasing':
        monotonicity_function = torch.cummin
    else:
        raise ValueError(f'Unsupported monotonicity type "{monotonicity}" specified.')

    # apply the running max/min along each of the last `ndims` dimensions
    for dim in range(1, ndims + 1):
        tensor, _ = monotonicity_function(tensor, dim=-dim)

    return tensor


def batch(iterable: Sequence, n: int = 1) -> Iterable[Iterable]:
    """Split a sequence into consecutive slices of at most `n` items."""
    max_len = len(iterable)
    for idx in range(0, max_len, n):
        yield iterable[idx : min(idx + n, max_len)]

# --------------------------------------------------------------------------------
# /tests/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/teamtomo/torch-cubic-spline-grids/9894cef28da6ae8055a956d83bcdee522907c672/tests/__init__.py
# --------------------------------------------------------------------------------
# /tests/test_grid_optimisation.py:
# --------------------------------------------------------------------------------
import einops
import torch
from torch_cubic_spline_grids import (
    CubicBSplineGrid1d,
    CubicBSplineGrid2d,
    CubicBSplineGrid3d,
    CubicBSplineGrid4d,
)


def test_1d_grid_optimisation():
    """Fit a 1D B-spline grid to noisy samples of one sine period with SGD."""
    grid_resolution = 6
    n_observations_per_iteration = 100
    grid = CubicBSplineGrid1d(resolution=grid_resolution, n_channels=1)

    def f(x: torch.Tensor, add_noise: bool = False):
        # target function: sin(2*pi*x), optionally with gaussian noise (std 0.3)
        y = torch.sin(x * 2 * torch.pi)
        if add_noise is True:
            y += torch.normal(mean=torch.zeros(len(y)), std=0.3)
        return y

    optimiser = torch.optim.SGD(lr=0.1, params=grid.parameters())
    for _i in range(5000):
        # fresh random sample positions in [0, 1) at every iteration
        x = torch.rand(size=(n_observations_per_iteration,))
        observations = f(x, add_noise=True)
        prediction = grid(x).squeeze()
        loss = torch.mean(torch.abs(prediction - observations))
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

    # compare against the noise-free function on a regular set of points
    x = torch.linspace(0, 1, steps=100)
    ground_truth = f(x)
    prediction = grid(x).squeeze()
    mean_absolute_error = torch.mean(torch.abs(prediction - ground_truth))
    assert mean_absolute_error.item() < 0.02


def test_1d_grid_optimization_decreasing():
    """A grid created with monotonicity='decreasing' must predict a non-increasing curve."""
    grid_resolution = 8
    n_observations_per_iteration = 100
    grid = CubicBSplineGrid1d(
        resolution=grid_resolution, n_channels=1, monotonicity='decreasing'
    )

    def f(x: torch.Tensor, add_noise: bool = False):
        # target function: exp(-5x), optionally with gaussian noise (std 0.4)
        y = torch.exp(-5 * x)
        if add_noise is True:
            y += torch.normal(mean=torch.zeros(len(y)), std=0.4)
        return y

    optimiser = torch.optim.SGD(lr=0.1, params=grid.parameters())
    for _i in range(5000):
        x = torch.rand(size=(n_observations_per_iteration,))
        observations = f(x, add_noise=True)
        prediction = grid(x).squeeze()
        loss = torch.mean(torch.abs(prediction - observations))
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

    x = torch.linspace(0, 1, steps=100)
    prediction = grid(x).squeeze()

    # successive differences may exceed zero only by a small tolerance
    eps = torch.tensor(1e-5, dtype=prediction.dtype)
    non_increasing = torch.diff(prediction, dim=-1) <= eps
    assert non_increasing.all().item()


def test_2d_grid_optimisation():
    """Fit a 2D B-spline grid to noisy samples of the distance from (0.5, 0.5)."""
    grid_resolution = (3, 3)
    n_observations_per_iteration = 100
    grid = CubicBSplineGrid2d(resolution=grid_resolution, n_channels=1)

    def f(x: torch.Tensor, add_noise: bool = False):
        centered = x - 0.5
        y = torch.sqrt(torch.sum(centered**2, dim=-1))  # (x**2 + y**2) ** 0.5
        if add_noise is True:
            y += torch.normal(mean=torch.zeros(len(y)), std=0.3)
        return y

    optimiser = torch.optim.SGD(lr=0.3, params=grid.parameters())
    for _i in range(1000):
        x = torch.rand(size=(n_observations_per_iteration, 2))
        observations = f(x, add_noise=True)
        prediction = grid(x).squeeze()
        loss = torch.mean((prediction - observations) ** 2)
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

    # evaluate on a regular 100 x 100 set of points flattened to (h * w, 2)
    _x = torch.linspace(0, 1, steps=100)
    x = torch.meshgrid(_x, _x, indexing='xy')
    x = einops.rearrange([*x], 'xy h w -> (h w) xy')
    ground_truth = f(x)
    prediction = grid(x).squeeze()
    mean_absolute_error = torch.mean(torch.abs(prediction - ground_truth))
    assert mean_absolute_error.item() < 0.02


def test_3d_grid_optimisation():
    """Fit a 3D B-spline grid to noisy samples of the distance from the centre."""
    grid_resolution = (3, 3, 3)
    n_observations_per_iteration = 1000
    grid = CubicBSplineGrid3d(resolution=grid_resolution, n_channels=1)

    def f(x: torch.Tensor, add_noise: bool = False):
        centered = x - 0.5
        y = torch.sqrt(torch.sum(centered**2, dim=-1))  # (x**2 + y**2 + z**2) ** 0.5
        if add_noise is True:
            y += torch.normal(mean=torch.zeros(len(y)), std=0.3)
        return y

    optimiser = torch.optim.SGD(lr=0.3, params=grid.parameters())
    for _i in range(1000):
        x = torch.rand(size=(n_observations_per_iteration, 3))
        observations = f(x, add_noise=True)
        prediction = grid(x).squeeze()
        loss = torch.mean((prediction - observations) ** 2)
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

    # evaluate on a regular 100 x 100 x 100 set of points flattened to (d * h * w, 3)
    _x = torch.linspace(0, 1, steps=100)
    x = torch.meshgrid(_x, _x, _x, indexing='xy')
    x = einops.rearrange([*x], 'xyz d h w -> (d h w) xyz')
    ground_truth = f(x)
    prediction = grid(x).squeeze()
    mean_absolute_error = torch.mean(torch.abs(prediction - ground_truth))
    assert mean_absolute_error.item() < 0.02


def test_4d_grid_optimisation():
    """Fit a 4D B-spline grid to noisy samples of the distance from the centre."""
    grid_resolution = (3, 3, 3, 3)
    n_observations_per_iteration = 1000
    grid = CubicBSplineGrid4d(resolution=grid_resolution, n_channels=1)

    def f(x: torch.Tensor, add_noise: bool = False):
        centered = x - 0.5
        y = torch.sqrt(torch.sum(centered**2, dim=-1))
        if add_noise is True:
            y += torch.normal(mean=torch.zeros(len(y)), std=0.3)
        return y

    optimiser = torch.optim.SGD(lr=0.9, params=grid.parameters())
    for _i in range(1000):
        x = torch.rand(size=(n_observations_per_iteration, 4))
        observations = f(x, add_noise=True)
        prediction = grid(x).squeeze()
        loss = torch.mean((prediction - observations) ** 2)
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

    # only 10 steps per axis here: the evaluation set has 10**4 points
    _x = torch.linspace(0, 1, steps=10)
    x = torch.meshgrid(_x, _x, _x, _x, indexing='xy')
    # the axis named 'xyz' in the pattern holds all four coordinates
    x = einops.rearrange([*x], 'xyz u d h w -> (u d h w) xyz')
    ground_truth = f(x)
    prediction = grid(x).squeeze()
    mean_absolute_error = torch.mean(torch.abs(prediction - ground_truth))
    assert mean_absolute_error.item() < 0.02

# --------------------------------------------------------------------------------
# /tests/test_grids.py:
# --------------------------------------------------------------------------------
import pytest

import torch

from torch_cubic_spline_grids import (
    CubicBSplineGrid1d,
    CubicBSplineGrid2d,
    CubicBSplineGrid3d,
    CubicBSplineGrid4d,
    CubicCatmullRomGrid1d,
    CubicCatmullRomGrid2d,
    CubicCatmullRomGrid3d,
    CubicCatmullRomGrid4d,
)

# Each test runs once per spline flavour of the matching dimensionality.
parametrize_1d = pytest.mark.parametrize(
    'grid_cls', [CubicBSplineGrid1d, CubicCatmullRomGrid1d]
)
parametrize_2d = pytest.mark.parametrize(
    'grid_cls', [CubicBSplineGrid2d, CubicCatmullRomGrid2d]
)
parametrize_3d = pytest.mark.parametrize(
    'grid_cls', [CubicBSplineGrid3d, CubicCatmullRomGrid3d]
)
parametrize_4d = pytest.mark.parametrize(
    'grid_cls', [CubicBSplineGrid4d, CubicCatmullRomGrid4d]
)


@parametrize_1d
def test_1d_grid_direct_instantiation(grid_cls):
    """Test grid instantiation with different types for resolution argument."""
    grid = grid_cls()
    assert isinstance(grid, grid_cls)
    assert grid.data.shape == (1, 2)

    # an int and a one-element tuple must be interchangeable
    for resolution in (5, (5,)):
        grid = grid_cls(resolution=resolution, n_channels=3)
        assert isinstance(grid, grid_cls)
        assert grid.data.shape == (3, 5)


@parametrize_1d
def test_1d_grid_instantiation_from_existing_data(grid_cls):
    """Test grid instantiation from existing data."""
    grid = grid_cls.from_grid_data(data=torch.zeros(3, 5))
    assert grid.ndim == 1
    assert grid.resolution == (5,)
    assert grid.n_channels == 3
    assert isinstance(grid._data, torch.nn.Parameter)


@parametrize_1d
def test_calling_1d_grid(grid_cls):
    """Test calling 1d grid with a float, a list and a tensor."""
    grid = grid_cls()
    expected = torch.tensor([0.])
    for arg in (0.5, [0.5], torch.tensor([0.5])):
        result = grid(arg)
        assert torch.allclose(result, expected)


@parametrize_1d
def test_1d_grid_with_singleton_dimension(grid_cls):
    """Test that a 1D grid with a single grid point can be used."""
    grid = grid_cls(resolution=1)
    result = grid(0.5)
    assert torch.allclose(result, torch.tensor([0.0]))


@parametrize_1d
def test_calling_1d_grid_with_stacked_coords(grid_cls):
    """Test calling a 1d grid with a multidimensional array of coordinates."""
    grid = grid_cls(resolution=1)
    h, w = 4, 4
    zero = torch.zeros(1)

    # no explicit coordinate dimension
    result = grid(torch.rand(size=(h, w)))
    assert result.shape == (h, w)
    assert torch.allclose(result, zero)

    # with explicit coordinate dimension
    result = grid(torch.rand(size=(h, w, 1)))
    assert result.shape == (h, w, 1)
    assert torch.allclose(result, zero)


def test_interpolation_matrix_device():
    """Interpolation matrix should move when Module moves to a different device."""
    grid = CubicBSplineGrid1d(resolution=3)
    assert grid.interpolation_matrix.device == torch.device('cpu')
    grid.to(torch.device('meta'))
    assert grid.interpolation_matrix.device == torch.device('meta')


def test_grid_device():
    """Grid data should move when Module moves to a different device."""
    grid = CubicBSplineGrid1d(resolution=3)
    assert grid.data.device == torch.device('cpu')
    grid.to(torch.device('meta'))
    assert grid.data.device == torch.device('meta')


@parametrize_2d
def test_2d_grid_direct_instantiation(grid_cls):
    """Test 2D grid instantiation with default and explicit arguments."""
    grid = grid_cls()
    assert isinstance(grid, grid_cls)
    assert grid.data.shape == (1, 2, 2)

    grid = grid_cls(resolution=(5, 4), n_channels=3)
    assert isinstance(grid, grid_cls)
    assert grid.data.shape == (3, 5, 4)


@parametrize_2d
def test_2d_grid_instantiation_from_existing_data(grid_cls):
    """Test grid instantiation from existing data."""
    grid = grid_cls.from_grid_data(data=torch.zeros(3, 5, 4))
    assert grid.ndim == 2
    assert grid.resolution == (5, 4)
    assert grid.n_channels == 3
    assert isinstance(grid._data, torch.nn.Parameter)


@parametrize_2d
def test_calling_2d_grid(grid_cls):
    """Test calling 2d grid with a list and a tensor."""
    grid = grid_cls()
    expected = torch.tensor([0., 0.])
    for arg in ([0.5, 0.5], torch.tensor([0.5, 0.5])):
        result = grid(arg)
        assert torch.allclose(result, expected)


@parametrize_2d
def test_2d_grid_with_singleton_dimension(grid_cls):
    """Test that a 2D grid with a singleton dimension can be used."""
    # singleton in the width dim, then in the height dim
    for resolution in [(2, 1), (1, 2)]:
        grid = grid_cls(resolution=resolution)
        result = grid([0.5, 0.5])
        assert torch.allclose(result, torch.tensor([0.0, 0.0]))


@parametrize_2d
def test_calling_2d_grid_with_stacked_coordinates(grid_cls):
    """Test calling a 2D grid with stacked coordinates."""
    # the trailing output dimension should follow the number of channels
    for n_channels in (1, 2):
        grid = grid_cls(resolution=(2, 2), n_channels=n_channels)
        result = grid(torch.rand(size=(5, 5, 2)))
        assert result.shape == (5, 5, n_channels)


@parametrize_3d
def test_3d_grid_direct_instantiation(grid_cls):
    """Test 3D grid instantiation with default and explicit arguments."""
    grid = grid_cls()
    assert isinstance(grid, grid_cls)
    assert grid.data.shape == (1, 2, 2, 2)

    grid = grid_cls(resolution=(5, 4, 3), n_channels=2)
    assert isinstance(grid, grid_cls)
    assert grid.data.shape == (2, 5, 4, 3)


@parametrize_3d
def test_3d_grid_instantiation_from_existing_data(grid_cls):
    """Test grid instantiation from existing data."""
    grid = grid_cls.from_grid_data(data=torch.zeros(2, 5, 4, 3))
    assert grid.ndim == 3
    assert grid.resolution == (5, 4, 3)
    assert grid.n_channels == 2
    assert isinstance(grid._data, torch.nn.Parameter)


@parametrize_3d
def test_calling_3d_grid(grid_cls):
    """Test calling 3d grid with a list and a tensor."""
    grid = grid_cls()
    expected = torch.tensor([0., 0., 0.])
    for arg in ([0.5, 0.5, 0.5], torch.tensor([0.5, 0.5, 0.5])):
        result = grid(arg)
        assert torch.allclose(result, expected)


@parametrize_3d
def test_calling_3d_grid_with_stacked_coordinates(grid_cls):
    """Test calling 3d grid with stacked coordinates."""
    grid = grid_cls()
    d, h, w = 4, 4, 4
    result = grid(torch.rand(size=(d, h, w, 3)))
    assert result.shape == (d, h, w, 1)


@parametrize_3d
def test_3d_grid_with_singleton_dimension(grid_cls):
    """Test that a 3D grid with a singleton dimension can be used."""
    # singleton in the width, height and depth dim respectively
    for resolution in [(2, 2, 1), (2, 1, 2), (1, 2, 2)]:
        grid = grid_cls(resolution=resolution)
        result = grid([0.5, 0.5, 0.5])
        assert torch.allclose(result, torch.tensor([0.0, 0.0, 0.0]))


@parametrize_4d
def test_4d_grid_direct_instantiation(grid_cls):
    """Test 4D grid instantiation with default and explicit arguments."""
    grid = grid_cls()
    assert isinstance(grid, grid_cls)
    assert grid.data.shape == (1, 2, 2, 2, 2)

    grid = grid_cls(resolution=(6, 5, 4, 3), n_channels=2)
    assert isinstance(grid, grid_cls)
    assert grid.data.shape == (2, 6, 5, 4, 3)


@parametrize_4d
def test_4d_grid_instantiation_from_existing_data(grid_cls):
    """Test grid instantiation from existing data."""
    grid = grid_cls.from_grid_data(data=torch.zeros(2, 6, 5, 4, 3))
    assert grid.ndim == 4
    assert grid.resolution == (6, 5, 4, 3)
    assert grid.n_channels == 2
    assert isinstance(grid._data, torch.nn.Parameter)


@parametrize_4d
def test_calling_4d_grid(grid_cls):
    """Test calling 4d grid with a list and a tensor."""
    grid = grid_cls()
    expected = torch.tensor([0., 0., 0., 0.])
    for arg in ([0.5, 0.5, 0.5, 0.5], torch.tensor([0.5, 0.5, 0.5, 0.5])):
        result = grid(arg)
        assert torch.allclose(result, expected)


@parametrize_4d
def test_calling_4d_grid_with_stacked_coordinates(grid_cls):
    """Test calling 4d grid with stacked coordinates."""
    grid = grid_cls()
    t, d, h, w = 2, 4, 4, 4
    result = grid(torch.rand(size=(t, d, h, w, 4)))
    assert result.shape == (t, d, h, w, 1)


@parametrize_4d
def test_4d_grid_with_singleton_dimension(grid_cls):
    """Test that a 4D grid with a singleton dimension can be used."""
    resolutions = [
        (2, 2, 2, 1),  # singleton in width dim
        (2, 2, 1, 2),  # singleton in height dim
        (2, 1, 2, 2),  # singleton in depth dim
        (1, 2, 2, 2),  # singleton in time dim
        (1, 1, 1, 1),  # multiple singletons
    ]
    for resolution in resolutions:
        grid = grid_cls(resolution=resolution)
        result = grid([0.5, 0.5, 0.5, 0.5])
        assert torch.allclose(result, torch.tensor([0.0, 0.0, 0.0, 0.0]))


# --------------------------------------------------------------------------------
# /tests/test_interpolate_grid.py:
# --------------------------------------------------------------------------------
import torch

from torch_cubic_spline_grids import interpolate_grids
from torch_cubic_spline_grids._constants import CUBIC_B_SPLINE_MATRIX


def test_interpolate_grid_1d():
    """Check that 1d interpolation works as expected."""
    grid = torch.arange(6).float()  # [0, 1, 2, 3, 4, 5]
    u = torch.tensor([0.5]).view((1, 1))
    result = interpolate_grids.interpolate_grid_1d(
        grid, u, matrix=CUBIC_B_SPLINE_MATRIX
    )
    expected = torch.tensor([[2.5]])
    assert torch.allclose(result, expected)


def test_interpolate_grid_1d_approx():
    """Check that 1D interpolation approximates a function."""
    grid_x = torch.linspace(0, 2 * torch.pi, steps=50)
    grid_y = torch.sin(grid_x)
    sample_x = torch.linspace(0, 1, steps=1000).view((-1, 1))
    sample_y = interpolate_grids.interpolate_grid_1d(
        grid_y, sample_x, matrix=CUBIC_B_SPLINE_MATRIX
    )
    ground_truth_y = torch.sin(sample_x * 2 * torch.pi)
    mean_absolute_error = torch.mean(torch.abs(sample_y - ground_truth_y))
    assert mean_absolute_error <= 0.01


def test_interpolate_grid_2d():
    """Check that 2D interpolation works."""
    # values 0..15 laid out row by row on a 4x4 grid
    grid = torch.arange(16).reshape(4, 4).float()
    u = torch.tensor([0.5, 0.5]).view(1, 2)
    result = interpolate_grids.interpolate_grid_2d(
        grid, u, matrix=CUBIC_B_SPLINE_MATRIX
    )
    expected = torch.tensor([7.5])
    assert torch.allclose(result, expected)


def test_interpolate_grid_3d():
    """Check that 3D interpolation works."""
    # values 0..63 laid out in C order on a 4x4x4 grid
    grid = torch.arange(64).reshape(4, 4, 4).float()
    u = torch.tensor([[0.5, 0.5, 0.5]]).view(1, 3)
    result = interpolate_grids.interpolate_grid_3d(
        grid, u, matrix=CUBIC_B_SPLINE_MATRIX
    )
    expected = torch.tensor([31.5])
    assert torch.allclose(result, expected)


def test_interpolate_grid_4d():
    """Check that 4D interpolation works as expected."""
    # values 0..255 laid out in C order on a 4x4x4x4 grid
    grid = torch.arange(256).reshape(4, 4, 4, 4).float()
    u = torch.tensor([0.5, 0.5, 0.5, 0.5]).view(1, 4)
    result = interpolate_grids.interpolate_grid_4d(
        grid, u, matrix=CUBIC_B_SPLINE_MATRIX
    )
    expected = torch.tensor([127.5])
    assert torch.allclose(result, expected)


# --------------------------------------------------------------------------------
# /tests/test_interpolate_pieces.py:
# --------------------------------------------------------------------------------
import torch

from torch_cubic_spline_grids import interpolate_pieces
from torch_cubic_spline_grids._constants import CUBIC_B_SPLINE_MATRIX


def test_interpolate_pieces_1d():
    """test that cubic B-spline interpolation results in expected values."""
    t = 0.5
    points = torch.tensor([-1.5, 5.1, 2.2, 6.8])
    result = interpolate_pieces.interpolate_pieces_1d(
        control_points=points.view(1, 1, 4),  # (b, c, 4)
        t=torch.tensor([t]),
        matrix=CUBIC_B_SPLINE_MATRIX,
    )
    # evaluate the cubic directly: power basis @ spline matrix @ control points
    power_basis = torch.tensor([1, t, t ** 2, t ** 3])
    expected = power_basis @ CUBIC_B_SPLINE_MATRIX @ points
    assert torch.allclose(result, expected)


def test_interpolate_pieces_1d_with_batched_queries():
    """test batched evaluation over one 'piece' (set of 4 control points)."""
    pieces = torch.tensor([[0, 1, 2, 3]]).float().view(1, 1, 4)  # (b, c, 4)
    t = torch.tensor([0, 0.5, 1])  # (b, )
    result = interpolate_pieces.interpolate_pieces_1d(
        pieces, t, matrix=CUBIC_B_SPLINE_MATRIX
    )

    # cubic b spline interpolation should be equivalent to linear interpolation
    # for four control points on the same line
    expected = torch.tensor([1, 1.5, 2]).view((3, 1))
    assert torch.allclose(result, expected)


def test_interpolate_pieces_1d_with_batched_pieces_and_queries():
    """test batched evaluation over batched 'pieces' (sets of 4 control points)."""
    pieces = torch.tensor(
        [[0, 1, 2, 3],
         [2, 3, 4, 5]]
    ).float().view(2, 1, 4)  # (b, c, 4)
    t = torch.tensor([0.5, 0.5])  # (b, )
    result = interpolate_pieces.interpolate_pieces_1d(
        pieces, t, matrix=CUBIC_B_SPLINE_MATRIX
    )  # (b, c)

    # cubic b spline interpolation should be equivalent to linear interpolation
    # for four control points on the same line
    expected = torch.tensor([[1.5, 3.5]]).view(2, 1)
    assert torch.allclose(result, expected)


def test_interpolate_pieces_2d():
    """test evaluation of 2D cubic B-spline interpolation."""
    # values 0..15 in C order, with leading batch and channel dimensions
    control_points = torch.arange(16).float().view(1, 1, 4, 4)
    t = torch.tensor([0.5, 0.5]).view(1, 2)
    result = interpolate_pieces.interpolate_pieces_2d(
        control_points, t, matrix=CUBIC_B_SPLINE_MATRIX
    )
    expected = torch.tensor([7.5])
    assert torch.allclose(result, expected)


def test_interpolate_pieces_3d():
    """test evaluation of 3D cubic B-spline interpolation."""
    # values 0..63 in C order, with leading batch and channel dimensions
    control_points = torch.arange(64).float().view(1, 1, 4, 4, 4)
    t = torch.tensor([[0.5, 0.5, 0.5]]).view(1, 3)
    result = interpolate_pieces.interpolate_pieces_3d(
        control_points, t, matrix=CUBIC_B_SPLINE_MATRIX
    )
    expected = torch.tensor([31.5])
    assert torch.allclose(result, expected)


def test_interpolate_pieces_4d():
    """test evaluation of 4D cubic B-spline interpolation."""
    # values 0..255 in C order, with leading batch and channel dimensions
    control_points = torch.arange(256).float().view(1, 1, 4, 4, 4, 4)
    t = torch.tensor([0.5, 0.5, 0.5, 0.5]).view(1, 4)
    result = interpolate_pieces.interpolate_pieces_4d(
        control_points, t, matrix=CUBIC_B_SPLINE_MATRIX
    )
    expected = torch.tensor([127.5])
    assert torch.allclose(result, expected)


# --------------------------------------------------------------------------------
# /tests/test_modules.py:
# --------------------------------------------------------------------------------
from torch_cubic_spline_grids import (
    CubicBSplineGrid1d,
    CubicBSplineGrid2d,
    CubicBSplineGrid3d,
    CubicBSplineGrid4d,
)


def test_grid_class_instantiation():
    """Every B-spline grid class should instantiate and expose parameters."""
    grid_classes = (
        CubicBSplineGrid1d,
        CubicBSplineGrid2d,
        CubicBSplineGrid3d,
        CubicBSplineGrid4d,
    )
    for grid_class in grid_classes:
        instance = grid_class()
        assert isinstance(instance, grid_class)
        # a grid without parameters could not be optimised
        assert len(list(instance.parameters())) > 0


# --------------------------------------------------------------------------------
# /tests/test_pad_grid.py:
# --------------------------------------------------------------------------------
import torch

from torch_cubic_spline_grids import pad_grids


def _linear_ramp(strides, start, stop):
    """Return the tensor whose entry at (i0, i1, ...) is sum(strides[k] * ik).

    Every index runs over `range(start, stop)`. With `start=0` and C-order
    strides this is simply `arange(...).reshape(...)`; starting at -1 and
    stopping one past the end gives that same ramp linearly extrapolated by
    one sample on every side, which is what the padding tests expect.
    """
    index = torch.arange(start, stop)
    axes = torch.meshgrid(*[index] * len(strides), indexing='ij')
    return sum(stride * axis for stride, axis in zip(strides, axes))


def test_pad_1d():
    """Padding a 1D ramp should extend it by one sample on each side."""
    grid = torch.arange(3)
    padded_grid = pad_grids.pad_grid_1d(grid)
    expected = torch.tensor([-1, 0, 1, 2, 3])
    assert torch.allclose(padded_grid, expected)


def test_pad_2d():
    """Padding a 2D ramp should extend it by one sample on every side."""
    grid = torch.tensor(
        [[0, 1],
         [2, 3]]
    )
    padded_grid = pad_grids.pad_grid_2d(grid)
    expected = torch.tensor(
        [[-3, -2, -1, 0],
         [-1, 0, 1, 2],
         [1, 2, 3, 4],
         [3, 4, 5, 6]]
    )
    assert torch.allclose(padded_grid, expected)


def test_pad_3d():
    """Padding a 3D ramp should extend it by one sample on every side."""
    grid = torch.arange(8).reshape(2, 2, 2)
    padded_grid = pad_grids.pad_grid_3d(grid)
    # corner values run from -7 at index (-1, -1, -1) to 14 at index (2, 2, 2)
    expected = _linear_ramp(strides=(4, 2, 1), start=-1, stop=3)
    assert torch.allclose(padded_grid, expected)


def test_pad_4d():
    """Padding a 4D ramp should extend it by one sample on every side."""
    grid = torch.arange(16).reshape(2, 2, 2, 2)
    padded_grid = pad_grids.pad_grid_4d(grid)
    # corner values run from -15 at index (-1, ..., -1) to 30 at index (2, ..., 2)
    expected = _linear_ramp(strides=(8, 4, 2, 1), start=-1, stop=3)
    assert torch.allclose(padded_grid, expected)


# --------------------------------------------------------------------------------
# /tests/test_utils.py:
# --------------------------------------------------------------------------------
import einops
import torch

from torch_cubic_spline_grids.utils import find_control_point_idx_1d, batch


def test_find_control_points():
    """The four control points around a query should be returned."""
    sample_positions = torch.tensor([0, 1, 2, 3, 4, 5, 6])

    cases = [
        # sample between points should yield four closest points
        (torch.tensor([2.5]), torch.tensor([[1, 2, 3, 4]])),
        # sample on point should be included as lower bound in interval
        (torch.tensor([2]), torch.tensor([[1, 2, 3, 4]])),
        # check the same is true for 3, the upper bound of the same interval
        (torch.tensor([3]), torch.tensor([[2, 3, 4, 5]])),
    ]
    for query, expected in cases:
        result = find_control_point_idx_1d(sample_positions, query)
        assert torch.allclose(result, expected)


def test_batch():
    """All items should be present in minibatches."""
    items = [0, 1, 2, 3, 4, 5, 6]
    minibatches = list(batch(items, n=3))
    expected = [[0, 1, 2], [3, 4, 5], [6]]
    assert minibatches == expected


def test_restacking_batch():
    """Ensure entries get restacked by cat the same way as they are unstacked."""
    batched_input = torch.rand(size=(10, 3))  # (b, d)
    minibatches = list(batch(batched_input, n=3))
    restacked_minibatches = torch.cat(minibatches, dim=0)
    assert torch.allclose(batched_input, restacked_minibatches)


# --------------------------------------------------------------------------------