├── .github
├── ISSUE_TEMPLATE.md
├── TEST_FAIL_TEMPLATE.md
├── dependabot.yml
└── workflows
│ └── ci.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── examples
├── optimise_1d_grid_model.ipynb
└── optimise_2d_grid_model.ipynb
├── pyproject.toml
├── src
└── torch_cubic_spline_grids
│ ├── __init__.py
│ ├── _base_cubic_grid.py
│ ├── _constants.py
│ ├── b_spline_grids.py
│ ├── catmull_rom_grids.py
│ ├── interpolate_grids.py
│ ├── interpolate_pieces.py
│ ├── pad_grids.py
│ └── utils.py
└── tests
├── __init__.py
├── test_grid_optimisation.py
├── test_grids.py
├── test_interpolate_grid.py
├── test_interpolate_pieces.py
├── test_modules.py
├── test_pad_grid.py
└── test_utils.py
/.github/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | * torch-cubic-spline-grids version:
2 | * Python version:
3 | * Operating System:
4 |
5 | ### Description
6 |
7 | Describe what you were trying to get done.
8 | Tell us what happened, what went wrong, and what you expected to happen.
9 |
10 | ### What I Did
11 |
12 | ```
13 | Paste the command(s) you ran and the output.
14 | If there was a crash, please include the traceback here.
15 | ```
16 |
--------------------------------------------------------------------------------
/.github/TEST_FAIL_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: "{{ env.TITLE }}"
3 | labels: [bug]
4 | ---
5 | The {{ workflow }} workflow failed on {{ date | date("YYYY-MM-DD HH:mm") }} UTC
6 |
7 | The most recent failing test was on {{ env.PLATFORM }} py{{ env.PYTHON }}
8 | with commit: {{ sha }}
9 |
10 | Full run: https://github.com/{{ repo }}/actions/runs/{{ env.RUN_ID }}
11 |
12 | (This post will be updated if another test fails, as long as this issue remains open.)
13 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
2 |
3 | version: 2
4 | updates:
5 | - package-ecosystem: "github-actions"
6 | directory: "/"
7 | schedule:
8 | interval: "weekly"
9 | commit-message:
10 | prefix: "ci(dependabot):"
11 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | permissions:
4 | contents: write
5 |
6 | on:
7 | push:
8 | branches:
9 | - main
10 | tags:
11 | - "v*"
12 | pull_request: {}
13 | workflow_dispatch:
14 |
15 | jobs:
16 | check-manifest:
17 | name: Check Manifest
18 | runs-on: ubuntu-latest
19 | steps:
20 | - uses: actions/checkout@v5
21 | - uses: actions/setup-python@v6
22 | with:
23 | python-version: "3.x"
24 | - run: pip install check-manifest && check-manifest
25 |
test:
  name: ${{ matrix.platform }} (${{ matrix.python-version }})
  runs-on: ${{ matrix.platform }}
  strategy:
    fail-fast: false
    matrix:
      # versions are quoted so that 3.10 is not parsed as the float 3.1
      python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
      platform: [
        ubuntu-latest,
        # macos-latest,
        # windows-latest,
      ]

  steps:
    - name: Cancel Previous Runs
      uses: styfle/cancel-workflow-action@0.12.1
      with:
        access_token: ${{ github.token }}

    - uses: actions/checkout@v5

    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v6
      with:
        python-version: ${{ matrix.python-version }}

    - name: Install dependencies
      run: |
        python -m pip install -U pip
        python -m pip install -e .
        python -m pip install pytest pytest-cov

    # Coverage has to be requested explicitly: pyproject.toml sets no pytest
    # `addopts`, so a bare `pytest` run writes no report and the Coverage step
    # below would have nothing to upload.
    - name: Test with pytest
      run: python -m pytest --cov=torch_cubic_spline_grids --cov-report=xml
      env:
        PLATFORM: ${{ matrix.platform }}

    - name: Coverage
      uses: codecov/codecov-action@v5
65 |
66 | deploy:
67 | name: Deploy
68 | needs: test
69 | if: "success() && startsWith(github.ref, 'refs/tags/')"
70 | runs-on: ubuntu-latest
71 |
72 | steps:
73 | - uses: actions/checkout@v5
74 |
75 | - name: Set up Python
76 | uses: actions/setup-python@v6
77 | with:
78 | python-version: "3.x"
79 |
80 | - name: install
81 | run: |
82 | git tag
83 | pip install -U pip
84 | pip install -U build twine
85 | python -m build
86 | twine check dist/*
87 | ls -lh dist
88 |
89 | - name: Build and publish
90 | run: twine upload dist/*
91 | env:
92 | TWINE_USERNAME: __token__
93 | TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }}
94 |
95 | - uses: softprops/action-gh-release@v2
96 | with:
97 | generate_release_notes: true
98 |
99 | # [WIP]
100 | # https://python-semantic-release.readthedocs.io/en/latest/automatic-releases/github-actions.html
101 | # release:
102 | # runs-on: ubuntu-latest
103 | # concurrency: release
104 |
105 | # steps:
106 | # - uses: actions/checkout@v5
107 | # with:
108 | # fetch-depth: 0
109 |
110 | # - name: Python Semantic Release
111 | # uses: relekang/python-semantic-release@master
112 | # with:
113 | # github_token: ${{ secrets.GITHUB_TOKEN }}
114 | # repository_username: __token__
115 | # repository_password: ${{ secrets.TWINE_API_KEY }}
116 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 |
58 | # Flask stuff:
59 | instance/
60 | .webassets-cache
61 |
62 | # Scrapy stuff:
63 | .scrapy
64 |
65 | # Sphinx documentation
66 | docs/_build/
67 |
68 | # PyBuilder
69 | target/
70 |
71 | # Jupyter Notebook
72 | .ipynb_checkpoints
73 |
74 | # pyenv
75 | .python-version
76 |
77 | # celery beat schedule file
78 | celerybeat-schedule
79 |
80 | # SageMath parsed files
81 | *.sage.py
82 |
83 | # dotenv
84 | .env
85 |
86 | # virtualenv
87 | .venv
88 | venv/
89 | ENV/
90 |
91 | # Spyder project settings
92 | .spyderproject
93 | .spyproject
94 |
95 | # Rope project settings
96 | .ropeproject
97 |
98 | # mkdocs documentation
99 | /site
100 |
101 | # mypy
102 | .mypy_cache/
103 |
104 | # IDE settings
105 | .vscode/
106 |
107 | src/torch_cubic_spline_grids/_version.py
108 | src/torch_cubic_spline_grids/_version.py
109 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | ci:
2 | autoupdate_schedule: monthly
3 | autofix_commit_msg: "style(pre-commit.ci): auto fixes [...]"
4 | autoupdate_commit_msg: "ci(pre-commit.ci): autoupdate"
5 |
6 | default_install_hook_types: [pre-commit, commit-msg]
7 |
8 | repos:
9 | - repo: https://github.com/compilerla/conventional-pre-commit
10 | rev: v1.3.0
11 | hooks:
12 | - id: conventional-pre-commit
13 | stages: [commit-msg]
14 |
15 | - repo: https://github.com/pre-commit/pre-commit-hooks
16 | rev: v4.3.0
17 | hooks:
18 | - id: check-docstring-first
19 | - id: end-of-file-fixer
20 | - id: trailing-whitespace
21 | - id: debug-statements
22 |
23 | - repo: https://github.com/astral-sh/ruff-pre-commit
24 | rev: v0.4.10
25 | hooks:
26 | - id: ruff
27 | args: [--fix]
28 | - id: ruff-format
29 |
30 | - repo: https://github.com/abravalheri/validate-pyproject
31 | rev: v0.10.1
32 | hooks:
33 | - id: validate-pyproject
34 |
35 | - repo: https://github.com/pre-commit/mirrors-mypy
36 | rev: v1.15.0
37 | hooks:
38 | - id: mypy
39 | files: "^src/"
40 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD License
2 |
3 | Copyright (c) 2023, Alister Burt
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # torch-cubic-spline-grids
2 |
3 | [License](https://github.com/alisterburt/torch-cubic-spline-grids/raw/main/LICENSE)
4 | [PyPI](https://pypi.org/project/torch-cubic-spline-grids)
5 | [Python Version](https://python.org)
6 | [CI](https://github.com/alisterburt/torch-cubic-spline-grids/actions/workflows/ci.yml)
7 | [codecov](https://codecov.io/gh/alisterburt/torch-cubic-spline-grids)
8 |
9 | *Cubic spline interpolation on multidimensional grids in PyTorch.*
10 |
11 | The primary goal of this package is to provide learnable, continuous
12 | parametrisations of 1-4D spaces.
13 |
14 | ---
15 |
16 | ## Overview
17 |
18 | `torch_cubic_spline_grids` provides a set of PyTorch components called grids.
19 |
20 | Grids are defined by
21 | - their dimensionality (1d, 2d, 3d, 4d...)
22 | - the number of points covering each dimension (`resolution`)
23 | - the number of values stored on each grid point (`n_channels`)
24 | - how we interpolate between values on grid points
25 |
26 | All grids in this package consist of uniformly spaced points covering the full
27 | extent of each dimension.
28 |
29 | ### First steps
30 | Let's make a simple 2D grid with one value on each grid point.
31 |
32 | ```python
33 | import torch
34 | from torch_cubic_spline_grids import CubicBSplineGrid2d
35 |
36 | grid = CubicBSplineGrid2d(resolution=(5, 3), n_channels=1)
37 | ```
38 |
39 | - `grid.ndim` is `2`
40 | - `grid.resolution` is `(5, 3)` (or `(h, w)`)
41 | - `grid.n_channels` is `1`
42 | - `grid.data.shape` is `(1, 5, 3)` (or `(c, h, w)`)
43 |
44 | In words, the grid extends over two dimensions `(h, w)` with `5` points
45 | in `h` and `3` points in `w`.
46 | There is one value stored at each point on the 2D grid.
47 | The grid data is stored in a tensor of shape `(c, *grid_resolution)`.
48 |
49 | We can obtain the value (interpolant) at any continuous point on the grid.
50 | The grid coordinate system spans the interval `[0, 1]` along each grid dimension.
51 | The interpolant is obtained by sequential application of
52 | cubic spline interpolation along each dimension of the grid.
53 |
54 | ```python
55 | coords = torch.rand(size=(10, 2)) # values in [0, 1]
56 | interpolants = grid(coords)
57 | ```
58 |
59 | - `interpolants.shape` is `(10, 1)`
60 |
61 | ### Optimisation
62 |
63 | Values at each grid point can be optimised by minimising a loss function associated with grid interpolants.
64 | In this way the continuous space of the grid can be made to more accurately model a 1-4D space.
65 |
66 |
67 |
68 |
69 |
70 | The image above shows the values of 6 control points on a 1D grid being optimised such
71 | that interpolating between them with cubic B-spline interpolation approximates a single oscillation of a sine wave.
72 |
73 | Notebooks are available for this
74 | [1D example](./examples/optimise_1d_grid_model.ipynb)
75 | and a similar
76 | [2D example](./examples/optimise_2d_grid_model.ipynb).
77 |
78 | ### Types of grids
79 |
80 | `torch_cubic_spline_grids` provides grids which can be interpolated with **cubic
81 | B-spline** interpolation or **cubic Catmull-Rom spline** interpolation.
82 |
83 | | spline | continuity | interpolating? |
84 | |--------------------|------------|----------------|
85 | | cubic B-spline | C2 | No |
86 | | Catmull-Rom spline | C1 | Yes |
87 |
88 | If you need the resulting curve to intersect the data on the grid you should
89 | use the cubic Catmull-Rom spline grids:
90 |
91 | - `CubicCatmullRomGrid1d`
92 | - `CubicCatmullRomGrid2d`
93 | - `CubicCatmullRomGrid3d`
94 | - `CubicCatmullRomGrid4d`
95 |
96 | If you require continuous second derivatives then the cubic B-spline grids are more
97 | suitable.
98 |
99 | - `CubicBSplineGrid1d`
100 | - `CubicBSplineGrid2d`
101 | - `CubicBSplineGrid3d`
102 | - `CubicBSplineGrid4d`
103 |
104 | ### Regularisation
105 |
106 | The number of points in each dimension should be chosen such that interpolating on the
107 | grid can approximate the underlying phenomenon being modelled without overfitting.
108 | A low resolution grid provides a regularising effect by smoothing the model.
109 |
110 |
111 | ## Installation
112 |
113 | `torch_cubic_spline_grids` is available on PyPI
114 |
115 | ```shell
116 | pip install torch-cubic-spline-grids
117 | ```
118 |
119 |
120 | ## Related work
121 |
122 | This is a PyTorch implementation of the way
123 | [Warp](http://warpem.com/warp/#) models continuous deformation
124 | fields and locally variable optical parameters in cryo-EM images.
125 | The approach is described in
126 | [Dimitry Tegunov's paper](https://doi.org/10.1038/s41592-019-0580-y):
127 |
128 | > Many methods in Warp are based on a continuous parametrization of 1- to
129 | > 3-dimensional spaces.
130 | > This parameterization is achieved by spline interpolation between points on a coarse,
131 | > uniform grid, which is computationally efficient.
132 | > A grid extends over the entirety of each dimension that needs to be modeled.
133 | > The grid resolution is defined by the number of control points in each dimension
134 | > and is scaled according to physical constraints
135 | > (for example, the number of frames or pixels) and available signal.
136 | > The latter provides regularization to prevent overfitting of sparse data with too many
137 | > parameters.
138 | > When a parameter described by the grid is retrieved for a point in space (and time),
139 | > for example for a particle (frame), B-spline interpolation is performed at that point
140 | > on the grid.
141 | > To fit a grid’s parameters, in general, a cost function associated with the
142 | > interpolants at specific positions on the grid is optimized.
143 |
144 | ---
145 |
146 | For a fantastic introduction to splines I recommend
147 | [Freya Holmér](https://www.youtube.com/watch?v=jvPPXbo87ds)'s YouTube video.
148 |
149 | [The Continuity of Splines - YouTube](https://youtu.be/jvPPXbo87ds)
150 |
--------------------------------------------------------------------------------
/examples/optimise_1d_grid_model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "outputs": [],
7 | "source": [
8 | "import torch\n",
9 | "from matplotlib import pyplot as plt\n",
10 | "\n",
11 | "from torch_cubic_spline_grids import CubicBSplineGrid1d"
12 | ],
13 | "metadata": {
14 | "collapsed": false,
15 | "pycharm": {
16 | "name": "#%%\n"
17 | }
18 | }
19 | },
20 | {
21 | "cell_type": "markdown",
22 | "source": [
23 | "some variables we can set..."
24 | ],
25 | "metadata": {
26 | "collapsed": false,
27 | "pycharm": {
28 | "name": "#%% md\n"
29 | }
30 | }
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 2,
35 | "outputs": [],
36 | "source": [
37 | "N_CONTROL_POINTS = 6\n",
38 | "N_OBSERVATIONS_PER_ITERATION = 20"
39 | ],
40 | "metadata": {
41 | "collapsed": false,
42 | "pycharm": {
43 | "name": "#%%\n"
44 | }
45 | }
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "source": [
50 | "initialise our optimisable parameters, a 1D grid of `N_CONTROL_POINTS` uniformly spaced\n",
51 | "points"
52 | ],
53 | "metadata": {
54 | "collapsed": false,
55 | "pycharm": {
56 | "name": "#%% md\n"
57 | }
58 | }
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 3,
63 | "outputs": [],
64 | "source": [
65 | "grid_1d = CubicBSplineGrid1d(resolution=N_CONTROL_POINTS)"
66 | ],
67 | "metadata": {
68 | "collapsed": false,
69 | "pycharm": {
70 | "name": "#%%\n"
71 | }
72 | }
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "source": [
77 | "define a function for making observations over the interval `[0, 1]` covering our 1d\n",
78 | "grid"
79 | ],
80 | "metadata": {
81 | "collapsed": false,
82 | "pycharm": {
83 | "name": "#%% md\n"
84 | }
85 | }
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 4,
90 | "outputs": [],
91 | "source": [
92 | "def make_observations(n, add_noise: bool = False):\n",
93 | " x = torch.rand(n) # in range [0, 1]\n",
94 | " y = torch.sin(2 * torch.pi * x)\n",
95 | " if add_noise is True:\n",
96 | " y += torch.normal(torch.zeros(n), std=0.5)\n",
97 | " return x, y"
98 | ],
99 | "metadata": {
100 | "collapsed": false,
101 | "pycharm": {
102 | "name": "#%%\n"
103 | }
104 | }
105 | },
106 | {
107 | "cell_type": "markdown",
108 | "source": [
109 | "initialise the optimiser"
110 | ],
111 | "metadata": {
112 | "collapsed": false,
113 | "pycharm": {
114 | "name": "#%% md\n"
115 | }
116 | }
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 5,
121 | "outputs": [],
122 | "source": [
123 | "optimiser = torch.optim.Adam(grid_1d.parameters(), lr=0.02)"
124 | ],
125 | "metadata": {
126 | "collapsed": false,
127 | "pycharm": {
128 | "name": "#%%\n"
129 | }
130 | }
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "source": [
135 | "optimise the values at the control points such that interpolating between them with\n",
136 | "cubic B-spline interpolation fits the data"
137 | ],
138 | "metadata": {
139 | "collapsed": false,
140 | "pycharm": {
141 | "name": "#%% md\n"
142 | }
143 | }
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": 6,
148 | "outputs": [
149 | {
150 | "data": {
151 | "text/plain": "[]"
152 | },
153 | "execution_count": 6,
154 | "metadata": {},
155 | "output_type": "execute_result"
156 | },
157 | {
158 | "data": {
159 | "text/plain": "",
160 | "image/png": "\n"
161 | },
162 | "metadata": {},
163 | "output_type": "display_data"
164 | }
165 | ],
166 | "source": [
167 | "# some plotting stuff\n",
168 | "fig, ax = plt.subplots()\n",
169 | "control_point_x = torch.linspace(0, 1, N_CONTROL_POINTS)\n",
170 | "decay = 1\n",
171 | "ax.scatter(control_point_x, grid_1d.data, color='white')\n",
172 | "\n",
173 | "# actually optimising!\n",
174 | "for i in range(500):\n",
175 | " # make (noisy) observations of the data we want to model\n",
176 | " x, y = make_observations(N_OBSERVATIONS_PER_ITERATION, add_noise=True)\n",
177 | "\n",
178 | " # what does the model predict for our observations?\n",
179 | " prediction = grid_1d(x).squeeze()\n",
180 | "\n",
181 | " # zero gradients and calculate loss between observations and model prediction\n",
182 | " optimiser.zero_grad()\n",
183 | " loss = torch.sum((prediction - y)**2)**0.5\n",
184 | "\n",
185 | " # backpropagate loss and update values at points on grid\n",
186 | " loss.backward()\n",
187 | " optimiser.step()\n",
188 | "\n",
189 | " # plot\n",
190 | " if i % 10 == 0:\n",
191 | " decay *= 0.99\n",
192 | " ax.scatter(control_point_x, grid_1d.data, color='blue', alpha=1 - decay)\n",
193 | "\n",
194 | " x = torch.linspace(0, 1, 1000)\n",
195 | " y = grid_1d(x).squeeze()\n",
196 | " ax.plot(x, y.detach(), alpha=1 - decay, color='blue')\n",
197 | "\n",
198 | "x = torch.linspace(0, 1, 1000)\n",
199 | "y = torch.sin(x * 2 * torch.pi)\n",
200 | "ax.plot(x, y, ls='--', color='orange')"
201 | ],
202 | "metadata": {
203 | "collapsed": false,
204 | "pycharm": {
205 | "name": "#%%\n"
206 | }
207 | }
208 | },
209 | {
210 | "cell_type": "markdown",
211 | "source": [
212 | "this model has very little capacity to overfit to noisy data because of the small\n",
213 | "number of control points on our grid (parameters)"
214 | ],
215 | "metadata": {
216 | "collapsed": false,
217 | "pycharm": {
218 | "name": "#%% md\n"
219 | }
220 | }
221 | }
222 | ],
223 | "metadata": {
224 | "kernelspec": {
225 | "display_name": "Python 3",
226 | "language": "python",
227 | "name": "python3"
228 | },
229 | "language_info": {
230 | "codemirror_mode": {
231 | "name": "ipython",
232 | "version": 2
233 | },
234 | "file_extension": ".py",
235 | "mimetype": "text/x-python",
236 | "name": "python",
237 | "nbconvert_exporter": "python",
238 | "pygments_lexer": "ipython2",
239 | "version": "2.7.6"
240 | }
241 | },
242 | "nbformat": 4,
243 | "nbformat_minor": 0
244 | }
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # https://peps.python.org/pep-0517/
2 | [build-system]
3 | requires = ["hatchling", "hatch-vcs"]
4 | build-backend = "hatchling.build"
5 |
6 | # https://peps.python.org/pep-0621/
7 | [project]
8 | name = "torch-cubic-spline-grids"
9 | description = "Cubic spline interpolation on multidimensional grids in PyTorch"
10 | readme = "README.md"
11 | requires-python = ">=3.9"
12 | license = {text = "BSD 3-Clause License"}
13 | authors = [
14 | {email = "alisterburt@gmail.com"},
15 | {name = "Alister Burt"},
16 | ]
17 | classifiers = [
18 | "Development Status :: 3 - Alpha",
19 | "License :: OSI Approved :: BSD License",
20 | "Programming Language :: Python :: 3",
21 | "Programming Language :: Python :: 3.9",
22 | "Programming Language :: Python :: 3.10",
23 | "Programming Language :: Python :: 3.11",
24 | "Programming Language :: Python :: 3.12",
25 | "Programming Language :: Python :: 3.13",
26 | ]
27 | dynamic = ["version"]
28 | dependencies = [
29 | "torch",
30 | "numpy",
31 | "einops",
32 | "typing-extensions",
33 | ]
34 |
35 | # extras
36 | # https://peps.python.org/pep-0621/#dependencies-optional-dependencies
37 | [project.optional-dependencies]
38 | test = ["pytest>=6.0", "pytest-cov"]
39 | dev = [
40 | "ipython",
41 | "mypy",
42 | "pdbpp",
43 | "pre-commit",
44 | "pytest-cov",
45 | "pytest",
46 | "rich",
47 | "ruff",
48 | ]
49 |
50 | [project.urls]
51 | homepage = "https://github.com/alisterburt/torch-cubic-spline-grids"
52 | repository = "https://github.com/alisterburt/torch-cubic-spline-grids"
53 |
54 | # same as console_scripts entry point
55 | # [project.scripts]
56 | # spam-cli = "spam:main_cli"
57 |
58 | # Entry points
59 | # https://peps.python.org/pep-0621/#entry-points
60 | # [project.entry-points."spam.magical"]
61 | # tomatoes = "spam:main_tomatoes"
62 |
63 | # https://hatch.pypa.io/latest/config/metadata/
64 | [tool.hatch.version]
65 | source = "vcs"
66 |
67 | # https://hatch.pypa.io/latest/config/build/#file-selection
68 | # [tool.hatch.build.targets.sdist]
69 | # include = ["/src", "/tests"]
70 |
71 |
72 | # https://github.com/charliermarsh/ruff
[tool.ruff]
line-length = 88
# Keep in sync with `requires-python` in the [project] table (currently >=3.9).
target-version = "py39"
76 |
77 | [tool.ruff.lint]
78 | extend-select = [
79 | "E", # style errors
80 | "F", # flakes
81 | "D", # pydocstyle
82 | "I001", # isort
83 | "U", # pyupgrade
84 | # "N", # pep8-naming
85 | # "S", # bandit
86 | "C", # flake8-comprehensions
87 | "B", # flake8-bugbear
88 | "A001", # flake8-builtins
89 | "RUF", # ruff-specific rules
90 | ]
91 | extend-ignore = [
92 | "D100", # Missing docstring in public module
93 | "D107", # Missing docstring in __init__
94 | "D203", # 1 blank line required before class docstring
95 | "D212", # Multi-line docstring summary should start at the first line
96 | "D213", # Multi-line docstring summary should start at the second line
97 | "D413", # Missing blank line after last section
98 | "D416", # Section name should end with a colon
99 | ]
100 |
101 | [tool.ruff.lint.per-file-ignores]
102 | "tests/*.py" = ["D"]
103 |
104 | [tool.ruff.lint.isort]
105 | combine-as-imports = true
106 |
107 | [tool.ruff.format]
108 | # Prefer single quotes over double quotes.
109 | quote-style = "single"
110 |
111 | # https://docs.pytest.org/en/6.2.x/customize.html
112 | [tool.pytest.ini_options]
113 | minversion = "6.0"
114 | pythonpath = "src"
115 | testpaths = ["tests"]
116 | filterwarnings = [
117 | "error",
118 | ]
119 |
120 | # https://mypy.readthedocs.io/en/stable/config_file.html
121 | [tool.mypy]
122 | files = "src/**/"
123 | strict = true
124 | disallow_any_generics = false
125 | disallow_subclassing_any = false
126 | show_error_codes = true
127 | pretty = true
128 |
129 |
130 | # https://coverage.readthedocs.io/en/6.4/config.html
131 | [tool.coverage.report]
132 | exclude_lines = [
133 | "pragma: no cover",
134 | "if TYPE_CHECKING:",
135 | "@overload",
136 | "except ImportError",
137 | ]
138 |
139 | # https://github.com/mgedmin/check-manifest#configuration
140 | [tool.check-manifest]
141 | ignore = [
142 | ".github_changelog_generator",
143 | ".pre-commit-config.yaml",
144 | ".ruff_cache/**/*",
145 | "tests/**/*",
146 | "tox.ini",
147 | ]
148 |
149 | # https://python-semantic-release.readthedocs.io/en/latest/configuration.html
150 | [tool.semantic_release]
151 | version_source = "tag_only"
152 | branch = "main"
153 | changelog_sections="feature,fix,breaking,documentation,performance,chore,:boom:,:sparkles:,:children_crossing:,:lipstick:,:iphone:,:egg:,:chart_with_upwards_trend:,:ambulance:,:lock:,:bug:,:zap:,:goal_net:,:alien:,:wheelchair:,:speech_balloon:,:mag:,:apple:,:penguin:,:checkered_flag:,:robot:,:green_apple:,Other"
154 | # commit_parser=semantic_release.history.angular_parser
155 | build_command = "pip install build && python -m build"
156 |
--------------------------------------------------------------------------------
/src/torch_cubic_spline_grids/__init__.py:
--------------------------------------------------------------------------------
1 | """Cubic B-spline interpolation on multidimensional grids in PyTorch."""
2 |
3 | from importlib.metadata import PackageNotFoundError, version
4 |
5 | try:
6 | __version__ = version('torch-cubic-b-spline-grid')
7 | except PackageNotFoundError:
8 | __version__ = 'uninstalled'
9 |
10 | __author__ = 'Alister Burt'
11 | __email__ = 'alisterburt@gmail.com'
12 |
13 | from .b_spline_grids import (
14 | CubicBSplineGrid1d,
15 | CubicBSplineGrid2d,
16 | CubicBSplineGrid3d,
17 | CubicBSplineGrid4d,
18 | )
19 | from .catmull_rom_grids import (
20 | CubicCatmullRomGrid1d,
21 | CubicCatmullRomGrid2d,
22 | CubicCatmullRomGrid3d,
23 | CubicCatmullRomGrid4d,
24 | )
25 |
26 | __all__ = [
27 | 'CubicBSplineGrid1d',
28 | 'CubicBSplineGrid2d',
29 | 'CubicBSplineGrid3d',
30 | 'CubicBSplineGrid4d',
31 | 'CubicCatmullRomGrid1d',
32 | 'CubicCatmullRomGrid2d',
33 | 'CubicCatmullRomGrid3d',
34 | 'CubicCatmullRomGrid4d',
35 | ]
36 |
--------------------------------------------------------------------------------
/src/torch_cubic_spline_grids/_base_cubic_grid.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional, Tuple
2 |
3 | import einops
4 | import torch
5 | from typing_extensions import Self
6 |
7 | from torch_cubic_spline_grids.utils import (
8 | MonotonicityType,
9 | batch,
10 | coerce_to_multichannel_grid,
11 | )
12 |
13 |
class CubicSplineGrid(torch.nn.Module):
    """Base class for continuous parametrisations of multidimensional spaces.

    Grid values are stored in a learnable parameter of shape
    ``(n_channels, *resolution)``. Calling the module with coordinates
    interpolates the grid at those coordinates; the coordinates are split up
    with the ``batch`` helper using ``n=minibatch_size``.

    Concrete subclasses must define the class attributes ``ndim``,
    ``_interpolation_function`` and ``_interpolation_matrix``, which are read
    here but never set.

    Parameters
    ----------
    resolution: Optional[Tuple[int, ...]]
        Number of grid points along each grid dimension. Defaults to 2 points
        along each of the ``ndim`` dimensions.
    n_channels: int
        Number of values stored on each grid point.
    minibatch_size: int
        Number of coordinates handed to the interpolation function at once.
        Must be at least 1.
    monotonicity: Optional[MonotonicityType]
        Passed through unchanged to the interpolation function.

    Raises
    ------
    ValueError
        If ``minibatch_size`` is smaller than 1.
    """

    ndim: int
    _data: torch.nn.Parameter
    _interpolation_function: Callable
    _interpolation_matrix: torch.Tensor
    _minibatch_size: int

    def __init__(
        self,
        resolution: Optional[Tuple[int, ...]] = None,
        n_channels: int = 1,
        minibatch_size: int = 1_000_000,
        monotonicity: Optional[MonotonicityType] = None,
    ):
        super().__init__()
        # Fail early: coordinates cannot be split into minibatches of
        # non-positive size, and the failure would otherwise only surface
        # on the first forward pass.
        if minibatch_size < 1:
            raise ValueError(
                f'minibatch_size must be a positive integer, got {minibatch_size}'
            )
        if resolution is None:
            resolution = (2,) * self.ndim
        grid_shape = (n_channels, *resolution)
        # goes through the `data` setter, which wraps the tensor in a Parameter
        self.data = torch.zeros(size=grid_shape)
        self._minibatch_size = minibatch_size
        self._monotonicity = monotonicity
        # non-persistent buffer: excluded from the state dict
        self.register_buffer(
            name='interpolation_matrix',
            tensor=self._interpolation_matrix,
            persistent=False,
        )

    def _interpolate(self, u: torch.Tensor) -> torch.Tensor:
        """Interpolate the grid data at one minibatch of coordinates."""
        return self._interpolation_function(
            self._data,
            u,
            matrix=self.interpolation_matrix,
            monotonicity=self._monotonicity,
        )

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        """Interpolate the grid at the coordinates ``u``.

        The leading (batch) dimensions of ``u`` are flattened for
        interpolation and restored on the output.
        """
        u = self._coerce_to_batched_coordinates(u)  # (b, d)

        # interpolate minibatch-wise, then stitch the results back together
        # along the batch dimension
        # presumably each result is (b, c): the README example yields one
        # value per channel for every coordinate
        chunks = [
            self._interpolate(minibatch_u)
            for minibatch_u in batch(u, n=self._minibatch_size)
        ]
        interpolated = torch.cat(chunks, dim=0)
        return self._unpack_interpolated_output(interpolated)

    @classmethod
    def from_grid_data(cls, data: torch.Tensor) -> Self:
        """Instantiate a grid from existing grid data.

        Parameters
        ----------
        data: torch.Tensor
            (c, *grid_dimensions) or (*grid_dimensions) array of multichannel values at
            each grid point.
        """
        # build a default grid, then replace its values via the `data` setter
        grid = cls()
        grid.data = data
        return grid

    @property
    def data(self) -> torch.Tensor:
        """Grid values, detached from the autograd graph."""
        return self._data.detach()

    @data.setter
    def data(self, grid_data: torch.Tensor) -> None:
        # a new Parameter is created on every assignment
        grid_data = coerce_to_multichannel_grid(grid_data, grid_ndim=self.ndim)
        self._data = torch.nn.Parameter(grid_data)

    @property
    def n_channels(self) -> int:
        """Size of the first dimension of the grid data."""
        return int(self._data.size(0))

    @property
    def resolution(self) -> Tuple[int, ...]:
        """Shape of the grid data without its first (channel) dimension."""
        return tuple(self._data.shape[1:])

    def _coerce_to_batched_coordinates(self, u: torch.Tensor) -> torch.Tensor:
        """Convert ``u`` into a float32 tensor of shape ``(b, ndim)``.

        Two pieces of state are stored on the instance so that
        ``_unpack_interpolated_output`` can undo the reshaping: whether the
        last dimension of the input already matched ``ndim``, and the batch
        shapes recorded by ``einops.pack``.

        Raises
        ------
        ValueError
            If the trailing dimension of the coordinates is not ``ndim``.
        """
        u = torch.atleast_1d(torch.as_tensor(u, dtype=torch.float32))
        self._input_is_coordinate_like = u.shape[-1] == self.ndim
        if self._input_is_coordinate_like is False and self.ndim == 1:
            u = einops.rearrange(u, '... -> ... 1')  # add singleton coord dimension
        else:
            u = torch.atleast_2d(u)  # add batch dimension if missing
        # flatten all leading dimensions into a single batch dimension
        u, self._packed_shapes = einops.pack([u], pattern='* coords')
        if u.shape[-1] != self.ndim:
            ndim = u.shape[-1]
            raise ValueError(
                f'Cannot interpolate on a {self.ndim}D grid with {ndim}D coordinates'
            )
        return u

    def _unpack_interpolated_output(self, interpolated: torch.Tensor) -> torch.Tensor:
        """Restore the batch shape recorded by ``_coerce_to_batched_coordinates``."""
        [interpolated] = einops.unpack(
            interpolated, packed_shapes=self._packed_shapes, pattern='* coords'
        )
        if self._input_is_coordinate_like is False and self.ndim == 1:
            # drop the singleton dimension that was added to the 1D input
            interpolated = einops.rearrange(interpolated, '... 1 -> ...')
        return interpolated
114 |
--------------------------------------------------------------------------------
/src/torch_cubic_spline_grids/_constants.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
# Characteristic matrices for uniform cubic splines.
# Described in Freya Holmer's video "The Continuity of Splines"
# https://youtu.be/jvPPXbo87ds?t=3462
#
# Both matrices are used as `[1, t, t^2, t^3] @ M @ [p0, p1, p2, p3]`:
# rows correspond to powers of the interpolation coordinate `t` and columns
# to the four control points.

# Cubic B-spline: at t = 0 the value is (p0 + 4 * p1 + p2) / 6, so the curve
# does not generally pass through the control points.
CUBIC_B_SPLINE_MATRIX = (1 / 6) * torch.tensor([[1, 4, 1, 0],
                                                [-3, 0, 3, 0],
                                                [3, -6, 3, 0],
                                                [-1, 3, -3, 1]])

# Cubic Catmull-Rom spline: at t = 0 the value is exactly p1, so the curve
# passes through the control points.
CUBIC_CATMULL_ROM_MATRIX = (1 / 2) * torch.tensor([[0, 2, 0, 0],
                                                   [-1, 0, 1, 0],
                                                   [2, -5, 4, -1],
                                                   [-1, 3, -3, 1]])
15 |
--------------------------------------------------------------------------------
/src/torch_cubic_spline_grids/b_spline_grids.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional, Sequence, Tuple, Union
2 |
3 | import torch
4 |
5 | from torch_cubic_spline_grids._base_cubic_grid import CubicSplineGrid
6 | from torch_cubic_spline_grids._constants import CUBIC_B_SPLINE_MATRIX
7 | from torch_cubic_spline_grids.interpolate_grids import (
8 | interpolate_grid_1d as _interpolate_grid_1d,
9 | interpolate_grid_2d as _interpolate_grid_2d,
10 | interpolate_grid_3d as _interpolate_grid_3d,
11 | interpolate_grid_4d as _interpolate_grid_4d,
12 | )
13 | from torch_cubic_spline_grids.utils import MonotonicityType
14 |
15 | CoordinateLike = Union[float, Sequence[float], torch.Tensor]
16 |
17 |
class _CubicBSplineGrid(CubicSplineGrid):
    """Shared base for grids interpolated with the cubic B-spline matrix."""

    _interpolation_matrix = CUBIC_B_SPLINE_MATRIX
20 |
21 |
class CubicBSplineGrid1d(_CubicBSplineGrid):
    """Continuous parametrisation of a 1D space with a specific resolution.

    Values are interpolated with `interpolate_grid_1d` using the cubic
    B-spline characteristic matrix.
    """

    ndim: int = 1
    _interpolation_function: Callable = staticmethod(_interpolate_grid_1d)

    def __init__(
        self,
        resolution: Optional[Union[int, Tuple[int]]] = None,
        n_channels: int = 1,
        minibatch_size: int = 1_000_000,
        monotonicity: Optional[MonotonicityType] = None,
    ):
        """Create a 1D grid; a bare integer `resolution` is wrapped in a 1-tuple.

        All arguments are forwarded to the base class constructor.
        """
        resolution_tuple = (
            (resolution,) if isinstance(resolution, int) else resolution
        )
        super().__init__(
            resolution=resolution_tuple,
            n_channels=n_channels,
            minibatch_size=minibatch_size,
            monotonicity=monotonicity,
        )
43 |
44 |
class CubicBSplineGrid2d(_CubicBSplineGrid):
    """Continuous parametrisation of a 2D space with a specific resolution.

    Values are interpolated with `interpolate_grid_2d` using the cubic
    B-spline characteristic matrix.
    """

    ndim: int = 2
    _interpolation_function: Callable = staticmethod(_interpolate_grid_2d)
50 |
51 |
class CubicBSplineGrid3d(_CubicBSplineGrid):
    """Continuous parametrisation of a 3D space with a specific resolution.

    Values are interpolated with `interpolate_grid_3d` using the cubic
    B-spline characteristic matrix.
    """

    ndim: int = 3
    _interpolation_function: Callable = staticmethod(_interpolate_grid_3d)
57 |
58 |
class CubicBSplineGrid4d(_CubicBSplineGrid):
    """Continuous parametrisation of a 4D space with a specific resolution.

    Values are interpolated with `interpolate_grid_4d` using the cubic
    B-spline characteristic matrix.
    """

    ndim: int = 4
    _interpolation_function: Callable = staticmethod(_interpolate_grid_4d)
64 |
--------------------------------------------------------------------------------
/src/torch_cubic_spline_grids/catmull_rom_grids.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional, Sequence, Tuple, Union
2 |
3 | import torch
4 |
5 | from torch_cubic_spline_grids._base_cubic_grid import CubicSplineGrid
6 | from torch_cubic_spline_grids._constants import CUBIC_CATMULL_ROM_MATRIX
7 | from torch_cubic_spline_grids.interpolate_grids import (
8 | interpolate_grid_1d as _interpolate_grid_1d,
9 | interpolate_grid_2d as _interpolate_grid_2d,
10 | interpolate_grid_3d as _interpolate_grid_3d,
11 | interpolate_grid_4d as _interpolate_grid_4d,
12 | )
13 | from torch_cubic_spline_grids.utils import MonotonicityType
14 |
15 | CoordinateLike = Union[float, Sequence[float], torch.Tensor]
16 |
17 |
class _CubicCatmullRomGrid(CubicSplineGrid):
    """Shared base for grids interpolated with the cubic Catmull-Rom matrix."""

    _interpolation_matrix = CUBIC_CATMULL_ROM_MATRIX
20 |
21 |
class CubicCatmullRomGrid1d(_CubicCatmullRomGrid):
    """Continuous parametrisation of a 1D space with a specific resolution.

    Values are interpolated with `interpolate_grid_1d` using the cubic
    Catmull-Rom characteristic matrix.
    """

    ndim: int = 1
    _interpolation_function: Callable = staticmethod(_interpolate_grid_1d)

    def __init__(
        self,
        resolution: Optional[Union[int, Tuple[int]]] = None,
        n_channels: int = 1,
        minibatch_size: int = 1_000_000,
        monotonicity: Optional[MonotonicityType] = None,
    ):
        """Create a 1D grid; a bare integer `resolution` is wrapped in a 1-tuple.

        All arguments are forwarded to the base class constructor.
        """
        resolution_tuple = (
            (resolution,) if isinstance(resolution, int) else resolution
        )
        super().__init__(
            resolution=resolution_tuple,
            n_channels=n_channels,
            minibatch_size=minibatch_size,
            monotonicity=monotonicity,
        )
43 |
44 |
class CubicCatmullRomGrid2d(_CubicCatmullRomGrid):
    """Continuous parametrisation of a 2D space with a specific resolution.

    Values are interpolated with `interpolate_grid_2d` using the cubic
    Catmull-Rom characteristic matrix.
    """

    ndim: int = 2
    _interpolation_function: Callable = staticmethod(_interpolate_grid_2d)
50 |
51 |
class CubicCatmullRomGrid3d(_CubicCatmullRomGrid):
    """Continuous parametrisation of a 3D space with a specific resolution.

    Values are interpolated with `interpolate_grid_3d` using the cubic
    Catmull-Rom characteristic matrix.
    """

    ndim: int = 3
    _interpolation_function: Callable = staticmethod(_interpolate_grid_3d)
57 |
58 |
class CubicCatmullRomGrid4d(_CubicCatmullRomGrid):
    """Continuous parametrisation of a 4D space with a specific resolution.

    Values are interpolated with `interpolate_grid_4d` using the cubic
    Catmull-Rom characteristic matrix.
    """

    ndim: int = 4
    _interpolation_function: Callable = staticmethod(_interpolate_grid_4d)
64 |
--------------------------------------------------------------------------------
/src/torch_cubic_spline_grids/interpolate_grids.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import einops
4 | import torch
5 |
6 | from torch_cubic_spline_grids.interpolate_pieces import (
7 | interpolate_pieces_1d,
8 | interpolate_pieces_2d,
9 | interpolate_pieces_3d,
10 | interpolate_pieces_4d,
11 | )
12 | from torch_cubic_spline_grids.pad_grids import (
13 | pad_grid_1d,
14 | pad_grid_2d,
15 | pad_grid_3d,
16 | pad_grid_4d,
17 | )
18 | from torch_cubic_spline_grids.utils import (
19 | MonotonicityType,
20 | interpolants_to_interpolation_data_1d,
21 | transform_to_monotonic_nd,
22 | )
23 |
24 |
def interpolate_grid_1d(
    grid: torch.Tensor,
    u: torch.Tensor,
    matrix: torch.Tensor,
    monotonicity: Optional[MonotonicityType] = None,
) -> torch.Tensor:
    """Uniform cubic spline interpolation on a 1D grid.

    The range [0, 1] covers all data points in the 1D grid.

    Parameters
    ----------
    grid: torch.Tensor
        `(c, w)` array of `w` values in `c` channels to be interpolated.
        A `(w, )` array is treated as a single channel.
    u: torch.Tensor
        `(b, 1)` array of query points in the range `[0, 1]` covering the `w`
        dimension of `grid`.
    matrix: torch.Tensor
        `(4, 4)` characteristic matrix for the spline.
    monotonicity: str
        when either 'increasing' or 'decreasing' is specified, the padded
        control points are made monotonic before interpolation.

    Returns
    -------
    interpolated: torch.Tensor
        `(b, c)` array of interpolated values in each channel.
    """
    if grid.ndim == 1:
        grid = einops.rearrange(grid, 'w -> 1 w')
    _n_channels, n_samples = grid.shape

    # extend the grid by one control point at each end, following local gradients,
    # so that queries in the first and last intervals have four control points
    padded = pad_grid_1d(grid)
    if monotonicity:
        padded = transform_to_monotonic_nd(
            padded, ndims=1, monotonicity=monotonicity
        )

    # control point indices on the padded grid and position within [p1, p2]
    point_idx, t = interpolants_to_interpolation_data_1d(
        u[:, 0], n_samples=n_samples
    )

    gathered = padded[..., point_idx]  # (c, b, 4)
    control_points = einops.rearrange(gathered, 'c b p -> b c p')
    return interpolate_pieces_1d(control_points, t, matrix=matrix)
70 |
71 |
def interpolate_grid_2d(
    grid: torch.Tensor,
    u: torch.Tensor,
    matrix: torch.Tensor,
    monotonicity: Optional[MonotonicityType] = None,
) -> torch.Tensor:
    """Uniform cubic spline interpolation on a 2D grid.

    Parameters
    ----------
    grid: torch.Tensor
        `(c, h, w)` multichannel 2D grid. A `(h, w)` array is treated as a
        single channel.
    u: torch.Tensor
        `(b, 2)` array of values in the range `[0, 1]`.
        `[0, 1]` in `u[:, 0]` covers dim -2 (h) of `grid`
        `[0, 1]` in `u[:, 1]` covers dim -1 (w) of `grid`
    matrix: torch.Tensor
        `(4, 4)` characteristic matrix for the spline.
    monotonicity: str
        when either 'increasing' or 'decreasing' is specified, the padded
        control points are made monotonic before interpolation.

    Returns
    -------
    interpolated: torch.Tensor
        `(b, c)` array of interpolated values in each channel.
    """
    if grid.ndim == 2:
        grid = einops.rearrange(grid, 'h w -> 1 h w')
    _n_channels, n_samples_h, n_samples_w = grid.shape

    # pad the grid so queries near the edges have a full (4, 4) neighbourhood
    padded = pad_grid_2d(grid)
    if monotonicity:
        padded = transform_to_monotonic_nd(
            padded, ndims=2, monotonicity=monotonicity
        )

    # per-dimension control point indices and interpolation coordinates
    rows, t_h = interpolants_to_interpolation_data_1d(
        u[:, 0], n_samples=n_samples_h
    )
    cols, t_w = interpolants_to_interpolation_data_1d(
        u[:, 1], n_samples=n_samples_w
    )

    # broadcast the 1D indices into (b, 4, 4) index grids
    rows = einops.repeat(rows, 'b h -> b h w', w=4)
    cols = einops.repeat(cols, 'b w -> b h w', h=4)

    gathered = padded[..., rows, cols]  # (c, b, 4, 4)
    control_points = einops.rearrange(gathered, 'c b h w -> b c h w')

    t = einops.rearrange([t_h, t_w], 'hw b -> b hw')
    return interpolate_pieces_2d(control_points, t, matrix=matrix)
119 |
120 |
def interpolate_grid_3d(
    grid: torch.Tensor,
    u: torch.Tensor,
    matrix: torch.Tensor,
    monotonicity: Optional[MonotonicityType] = None,
) -> torch.Tensor:
    """Uniform cubic spline interpolation on a 3D grid.

    Parameters
    ----------
    grid: torch.Tensor
        `(c, d, h, w)` multichannel 3D grid. A `(d, h, w)` array is treated as
        a single channel.
    u: torch.Tensor
        `(b, 3)` array of values in the range [0, 1].
        [0, 1] in `u[:, 0]` covers depth dim `d` of `grid`
        [0, 1] in `u[:, 1]` covers height dim `h` of `grid`
        [0, 1] in `u[:, 2]` covers width dim `w` of `grid`
    matrix: torch.Tensor
        `(4, 4)` characteristic matrix for the spline.
    monotonicity: str
        when either 'increasing' or 'decreasing' is specified, the padded
        control points are made monotonic before interpolation.

    Returns
    -------
    interpolated: torch.Tensor
        `(b, c)` array of interpolated values in each channel.
    """
    if grid.ndim == 3:
        grid = einops.rearrange(grid, 'd h w -> 1 d h w')
    _n_channels, n_samples_d, n_samples_h, n_samples_w = grid.shape

    # pad the grid so queries near the edges have a full (4, 4, 4) neighbourhood
    padded = pad_grid_3d(grid)
    if monotonicity:
        padded = transform_to_monotonic_nd(
            padded, ndims=3, monotonicity=monotonicity
        )

    # per-dimension control point indices and interpolation coordinates
    planes, t_d = interpolants_to_interpolation_data_1d(u[:, 0], n_samples_d)
    rows, t_h = interpolants_to_interpolation_data_1d(u[:, 1], n_samples_h)
    cols, t_w = interpolants_to_interpolation_data_1d(u[:, 2], n_samples_w)

    # broadcast the 1D indices into (b, 4, 4, 4) index grids
    planes = einops.repeat(planes, 'b d -> b d h w', h=4, w=4)
    rows = einops.repeat(rows, 'b h -> b d h w', d=4, w=4)
    cols = einops.repeat(cols, 'b w -> b d h w', d=4, h=4)

    gathered = padded[:, planes, rows, cols]  # (c, b, 4, 4, 4)
    control_points = einops.rearrange(gathered, 'c b d h w -> b c d h w')

    t = einops.rearrange([t_d, t_h, t_w], 'dhw b -> b dhw')
    return interpolate_pieces_3d(control_points, t, matrix=matrix)
171 |
172 |
def interpolate_grid_4d(
    grid: torch.Tensor,
    u: torch.Tensor,
    matrix: torch.Tensor,
    monotonicity: Optional[MonotonicityType] = None,
) -> torch.Tensor:
    """Uniform cubic spline interpolation on a 4D grid.

    Parameters
    ----------
    grid: torch.Tensor
        `(c, t, d, h, w)` multichannel 4D grid. A `(t, d, h, w)` array is
        treated as a single channel.
    u: torch.Tensor
        `(b, 4)` array of values in the range [0, 1].
        [0, 1] in `u[:, 0]` covers time dim `t` of `grid`
        [0, 1] in `u[:, 1]` covers depth dim `d` of `grid`
        [0, 1] in `u[:, 2]` covers height dim `h` of `grid`
        [0, 1] in `u[:, 3]` covers width dim `w` of `grid`
    matrix: torch.Tensor
        `(4, 4)` characteristic matrix for the spline.
    monotonicity: str
        when either 'increasing' or 'decreasing' is specified, the padded
        control points are made monotonic along all four grid dimensions
        before interpolation.

    Returns
    -------
    interpolated: torch.Tensor
        `(b, c)` array of interpolated values in each channel.
    """
    if grid.ndim == 4:
        grid = einops.rearrange(grid, 't d h w -> 1 t d h w')
    _, t, d, h, w = grid.shape

    # expand grid to handle interpolation at edges
    grid = pad_grid_4d(grid)

    # find control point indices and interpolation coordinate in each dim
    idx_t, t_t = interpolants_to_interpolation_data_1d(u[:, 0], n_samples=t)
    idx_d, t_d = interpolants_to_interpolation_data_1d(u[:, 1], n_samples=d)
    idx_h, t_h = interpolants_to_interpolation_data_1d(u[:, 2], n_samples=h)
    idx_w, t_w = interpolants_to_interpolation_data_1d(u[:, 3], n_samples=w)

    # construct (4, 4, 4, 4) grids of control points and 4D interpolant then interpolate
    idx_t = einops.repeat(idx_t, 'b t -> b t d h w', d=4, h=4, w=4)
    idx_d = einops.repeat(idx_d, 'b d -> b t d h w', t=4, h=4, w=4)
    idx_h = einops.repeat(idx_h, 'b h -> b t d h w', t=4, d=4, w=4)
    idx_w = einops.repeat(idx_w, 'b w -> b t d h w', t=4, d=4, h=4)
    if monotonicity:
        # all four grid dimensions (t, d, h, w) must be transformed, matching
        # ndims=1/2/3 in the lower-dimensional interpolation functions
        grid = transform_to_monotonic_nd(grid, ndims=4, monotonicity=monotonicity)
    control_points = grid[:, idx_t, idx_d, idx_h, idx_w]  # (c, b, 4, 4, 4, 4)
    control_points = einops.rearrange(control_points, 'c b t d h w -> b c t d h w')

    t = einops.rearrange([t_t, t_d, t_h, t_w], 'tdhw b -> b tdhw')
    return interpolate_pieces_4d(control_points, t, matrix=matrix)
226 |
--------------------------------------------------------------------------------
/src/torch_cubic_spline_grids/interpolate_pieces.py:
--------------------------------------------------------------------------------
1 | """Interpolate 'pieces' for piecewise uniform cubic B-spline interpolation."""
2 |
3 | import einops
4 | import torch
5 |
6 |
def interpolate_pieces_1d(
    control_points: torch.Tensor, t: torch.Tensor, matrix: torch.Tensor
) -> torch.Tensor:
    """Batched uniform 1D cubic spline interpolation.

    ```
    [1, t, t^2, t^3] * [a00, a01, a02, a03] * [p0]
                       [a10, a11, a12, a13]   [p1]
                       [a20, a21, a22, a23]   [p2]
                       [a30, a31, a32, a33]   [p3]
    ```
    c.f. Freya Holmer - "The Continuity of Splines": https://youtu.be/jvPPXbo87ds?t=3462

    Parameters
    ----------
    control_points: torch.Tensor
        `(b, c, 4)` batch of 4 uniformly spaced control points `[p0, p1, p2, p3]`
        in `c` channels.
    t: torch.Tensor
        `(b, )` batch of interpolants in the range [0, 1] covering the interpolation
        interval between `p1` and `p2`
    matrix: torch.Tensor
        `(4, 4)` characteristic matrix for the spline.

    Returns
    -------
    interpolated: torch.Tensor
        `(b, c)` array of per-channel interpolants of `control_points` at `t`.
    """
    # (b, 1, 1, 4) row vectors of powers of t, broadcast over channels
    powers = einops.rearrange([t**0, t, t**2, t**3], 'u b -> b 1 1 u')
    # (b, c, 4, 1) column vectors of control points
    columns = einops.rearrange(control_points, 'b c p -> b c p 1')

    basis_weights = powers @ matrix  # (b, 1, 1, 4)
    interpolated = basis_weights @ columns  # (b, c, 1, 1)
    return einops.rearrange(interpolated, 'b c 1 1 -> b c')
40 |
41 |
def interpolate_pieces_2d(
    control_points: torch.Tensor, t: torch.Tensor, matrix: torch.Tensor
) -> torch.Tensor:
    """Batched uniform 2D cubic spline interpolation.

    Interpolation is separable: each of the four rows is interpolated along the
    width dimension, then the four results are interpolated along height.

    Parameters
    ----------
    control_points: torch.Tensor
        `(b, c, 4, 4)` batch of 2D multichannel grids of uniformly spaced control
        points.
    t: torch.Tensor
        `(b, 2)` batch of values in the range `[0, 1]` defining the position of 2D
        points to be interpolated within the interval `[p1, p2]` along dim -2 and
        -1 of `control_points`.
    matrix: torch.Tensor
        `(4, 4)` characteristic matrix for the spline.

    Returns
    -------
    interpolated: torch.Tensor
        `(b, c)` batch of interpolated values in each channel.
    """
    # split t into its height and width components
    t_h, t_w = einops.rearrange(t, 'b hw -> hw b')

    # (b, c, 4) control points along the width dim at each of the four heights
    row_0, row_1, row_2, row_3 = einops.rearrange(
        control_points, 'b c h w -> h b c w'
    )

    # collapse the width dim of every row
    along_width = [
        interpolate_pieces_1d(control_points=row, t=t_w, matrix=matrix)
        for row in (row_0, row_1, row_2, row_3)
    ]

    # the four per-row results are the control points along the height dim
    height_points = einops.rearrange(along_width, 'p b c -> b c p')
    return interpolate_pieces_1d(control_points=height_points, t=t_h, matrix=matrix)
79 |
80 |
def interpolate_pieces_3d(
    control_points: torch.Tensor, t: torch.Tensor, matrix: torch.Tensor
) -> torch.Tensor:
    """Batched uniform 3D cubic spline interpolation.

    Each of the four `(h, w)` planes is interpolated in 2D, then the four
    results are interpolated along the depth dimension.

    Parameters
    ----------
    control_points: torch.Tensor
        `(b, c, 4, 4, 4)` batch of `(4, 4, 4)` multichannel grids of uniformly
        spaced control points.
    t: torch.Tensor
        `(b, 3)` batch of values in the range `[0, 1]` defining the position of 3D
        points to be interpolated within the interval `[p1, p2]` along dim -3,
        -2 and -1 of `control_points`
    matrix: torch.Tensor
        `(4, 4)` characteristic matrix for the spline.

    Returns
    -------
    interpolated: torch.Tensor
        `(b, c)` batch of interpolated values in each channel.
    """
    # split t into depth and (height, width) components
    t_d = t[:, 0]
    t_hw = t[:, [1, 2]]

    # (b, c, 4, 4) planes of control points at each of the four depths
    plane_0, plane_1, plane_2, plane_3 = einops.rearrange(
        control_points, 'b c d h w -> d b c h w'
    )

    # collapse the (height, width) dims of every plane
    per_plane = [
        interpolate_pieces_2d(control_points=plane, t=t_hw, matrix=matrix)
        for plane in (plane_0, plane_1, plane_2, plane_3)
    ]

    # the four per-plane results are the control points along the depth dim
    depth_points = einops.rearrange(per_plane, 'p b c -> b c p')
    return interpolate_pieces_1d(control_points=depth_points, t=t_d, matrix=matrix)
119 |
120 |
def interpolate_pieces_4d(
    control_points: torch.Tensor, t: torch.Tensor, matrix: torch.Tensor
) -> torch.Tensor:
    """Batched uniform 4D cubic spline interpolation.

    Each of the four `(d, h, w)` volumes is interpolated in 3D, then the four
    results are interpolated along the time dimension.

    Parameters
    ----------
    control_points: torch.Tensor
        `(b, c, 4, 4, 4, 4)` batch of multichannel `(4, 4, 4, 4)` grids of uniformly
        spaced control points.
    t: torch.Tensor
        `(b, 4)` batch of values in the range `[0, 1]` defining the position of 4D
        points to be interpolated within the interval `[p1, p2]` along dims -4, -3,
        -2 and -1 of `control_points`.
    matrix: torch.Tensor
        `(4, 4)` characteristic matrix for the spline.

    Returns
    -------
    interpolated: torch.Tensor
        `(b, c)` batch of interpolated values in each channel.
    """
    # split t into time and (depth, height, width) components
    t_t = t[:, 0]
    t_dhw = t[:, [1, 2, 3]]

    # (b, c, 4, 4, 4) volumes of control points at each of the four time points
    volume_0, volume_1, volume_2, volume_3 = einops.rearrange(
        control_points, 'b c u d h w -> u b c d h w'
    )

    # collapse the (depth, height, width) dims of every volume
    per_volume = [
        interpolate_pieces_3d(control_points=volume, t=t_dhw, matrix=matrix)
        for volume in (volume_0, volume_1, volume_2, volume_3)
    ]

    # the four per-volume results are the control points along the time dim
    time_points = einops.rearrange(per_volume, 'p b c -> b c p')
    return interpolate_pieces_1d(control_points=time_points, t=t_t, matrix=matrix)
159 |
--------------------------------------------------------------------------------
/src/torch_cubic_spline_grids/pad_grids.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 |
4 |
def pad_grid_1d(grid: torch.Tensor) -> torch.Tensor:
    """Pad in the last dimension according to local gradients.

    e.g. [0, 1, 2] -> [-1, 0, 1, 2, 3]

    A grid with a single sample in the last dimension is first duplicated to
    two samples so that a gradient can be computed; the output then has four
    samples in that dimension.

    Parameters
    ----------
    grid: torch.Tensor
        `(..., w)` array of values to be padded in last dimension.

    Returns
    -------
    padded_grid: torch.Tensor
        `(..., w+2)` padded array.
    """
    # duplicate a lone sample so that first differences exist
    if grid.shape[-1] == 1:
        grid = einops.repeat(grid, '... w -> ... (repeat w)', repeat=2)

    first, second = grid[..., 0], grid[..., 1]
    last, penultimate = grid[..., -1], grid[..., -2]

    # linearly extrapolate one sample beyond each end of the width dim
    before = first - (second - first)
    after = last + (last - penultimate)

    # indexing dropped the width dim, add it back as a singleton
    before = einops.rearrange(before, '... -> ... 1')
    after = einops.rearrange(after, '... -> ... 1')
    return torch.cat([before, grid, after], dim=-1)
31 |
32 |
def pad_grid_2d(grid: torch.Tensor) -> torch.Tensor:
    """Pad a 2D grid of values according to local gradients.

    ```
    e.g.               [[-3, -2, -1, 0]
         [[0, 1]        [-1,  0,  1, 2]
          [2, 3]]  ->   [ 1,  2,  3, 4]
                        [ 3,  4,  5, 6]]
    ```

    A grid with a single sample in the height dimension is first duplicated to
    two samples so that a gradient can be computed.

    Parameters
    ----------
    grid: torch.Tensor
        `(..., h, w)` array of values to be padded in height and width dimensions.

    Returns
    -------
    padded_grid: torch.Tensor
        `(..., h+2, w+2)` padded array.
    """
    # duplicate a lone row so that first differences exist along height
    if grid.shape[-2] == 1:
        grid = einops.repeat(grid, '... h w -> ... (repeat h) w', repeat=2)

    # pad the width dim first: (..., h, w+2)
    grid = pad_grid_1d(grid)

    top, below_top = grid[..., 0, :], grid[..., 1, :]
    bottom, above_bottom = grid[..., -1, :], grid[..., -2, :]

    # linearly extrapolate one row beyond each end of the height dim
    row_before = top - (below_top - top)
    row_after = bottom + (bottom - above_bottom)

    # indexing dropped the height dim, add it back as a singleton
    row_before = einops.rearrange(row_before, '... w -> ... 1 w')
    row_after = einops.rearrange(row_after, '... w -> ... 1 w')

    return torch.cat([row_before, grid, row_after], dim=-2)
69 |
70 |
def pad_grid_3d(grid: torch.Tensor) -> torch.Tensor:
    """Pad a 3D grid of values according to local gradients.

    A grid with a single sample in the depth dimension is first duplicated to
    two samples so that a gradient can be computed.

    Parameters
    ----------
    grid: torch.Tensor
        `(..., d, h, w)` array of values to be padded in depth, height and width
        dimensions.

    Returns
    -------
    padded_grid: torch.Tensor
        `(..., d+2, h+2, w+2)` padded array.
    """
    # duplicate a lone plane so that first differences exist along depth
    if grid.shape[-3] == 1:
        grid = einops.repeat(grid, '... d h w -> ... (repeat d) h w', repeat=2)

    # pad the height and width dims first: (..., d, h+2, w+2)
    grid = pad_grid_2d(grid)

    front, behind_front = grid[..., 0, :, :], grid[..., 1, :, :]
    back, before_back = grid[..., -1, :, :], grid[..., -2, :, :]

    # linearly extrapolate one plane beyond each end of the depth dim
    plane_before = front - (behind_front - front)
    plane_after = back + (back - before_back)

    # indexing dropped the depth dim, add it back as a singleton
    plane_before = einops.rearrange(plane_before, '... h w -> ... 1 h w')
    plane_after = einops.rearrange(plane_after, '... h w -> ... 1 h w')
    return torch.cat([plane_before, grid, plane_after], dim=-3)
101 |
102 |
def pad_grid_4d(grid: torch.Tensor) -> torch.Tensor:
    """Pad a 4D grid of values according to local gradients.

    A grid with a single sample in the time dimension is first duplicated to
    two samples so that a gradient can be computed.

    Parameters
    ----------
    grid: torch.Tensor
        `(..., u, d, h, w)` array of values to be padded in time, depth, height and
        width dimensions.

    Returns
    -------
    padded_grid: torch.Tensor
        `(..., u+2, d+2, h+2, w+2)` grid
    """
    # duplicate a lone volume so that first differences exist along time
    if grid.shape[-4] == 1:
        grid = einops.repeat(grid, '... u d h w -> ... (repeat u) d h w', repeat=2)

    # pad the depth, height and width dims first: (..., u, d+2, h+2, w+2)
    grid = pad_grid_3d(grid)

    earliest, latest = grid[..., 0, :, :, :], grid[..., -1, :, :, :]

    # linearly extrapolate one volume beyond each end of the time dim
    step_at_start = grid[..., 1, :, :, :] - earliest
    volume_before = earliest - step_at_start
    step_at_end = latest - grid[..., -2, :, :, :]
    volume_after = latest + step_at_end

    # indexing dropped the time dim, add it back as a singleton
    volume_before = einops.rearrange(volume_before, '... d h w -> ... 1 d h w')
    volume_after = einops.rearrange(volume_after, '... d h w -> ... 1 d h w')
    return torch.cat([volume_before, grid, volume_after], dim=-4)
135 |
--------------------------------------------------------------------------------
/src/torch_cubic_spline_grids/utils.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 | from typing import Callable, Iterable, Literal, Tuple
3 |
4 | import einops
5 | import torch
6 |
7 |
def generate_sample_positions_for_padded_grid_1d(
    n_samples: int, device: torch.device
) -> torch.Tensor:
    """Generate a 1D vector of sample coordinates for a padded grid.

    Coordinate system is [0, 1] covering each dimension, pre-padding.
    e.g. for 4 samples (6 samples once padded), up to the small nudge
    described below:
    `[-0.333, 0, 0.333, 0.666, 1, 1.333]`

    The coordinates of the first and last unpadded samples are moved outwards
    by `1e-6` (to just below 0 and just above 1).

    Parameters
    ----------
    n_samples: int
        The number of samples on the grid prior to padding. Must be greater
        than 1, the sample spacing is `1 / (n_samples - 1)`.
    device: torch.device
        The torch device on which to store the tensor.

    Returns
    -------
    sample_coordinates: torch.Tensor
        `(n_samples + 2, )` coordinates in the [0, 1] coordinate system of each
        sample on the padded grid.
    """
    spacing = 1 / (n_samples - 1)
    positions = torch.linspace(
        start=-spacing, end=1 + spacing, steps=n_samples + 2, device=device
    )

    # fix for numerical stability issues around 0 and 1
    # ensures valid control point indices are selected
    nudge = 1e-6
    positions[1] = -nudge
    positions[-2] = 1 + nudge
    return positions
40 |
41 |
def find_control_point_idx_1d(
    sample_positions: torch.Tensor, query_points: torch.Tensor
) -> torch.Tensor:
    """Find indices of four control points required for cubic interpolation.

    E.g. for sample positions `[0, 1, 2, 3, 4, 5]` and query point `2.5` the control
    point indices would be `[1, 2, 3, 4]` as `2.5` lies between `2` and `3`

    Parameters
    ----------
    sample_positions: torch.Tensor
        Monotonically increasing 1D array of sample positions.
    query_points: torch.Tensor
        `(b, )` array of query points for which control point indices are
        required.

    Returns
    -------
    control_point_idx: torch.Tensor
        `(b, 4)` array of indices for control points.
    """
    # index of the first sample position strictly greater than each query point,
    # i.e. the upper bound of the interval containing it
    upper_idx = torch.searchsorted(
        sample_positions.contiguous(), query_points.contiguous(), side='right'
    )

    # control points [s0, s1, s2, s3] sit at offsets -2, -1, 0, +1 from the
    # upper bound of the interval
    s0_idx, s1_idx = upper_idx - 2, upper_idx - 1
    s2_idx, s3_idx = upper_idx, upper_idx + 1
    per_point = [s0_idx, s1_idx, s2_idx, s3_idx]
    return einops.rearrange(per_point, 's b -> b s')
74 |
75 |
def interpolants_to_interpolation_data_1d(
    interpolants: torch.Tensor, n_samples: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Find the necessary data for piecewise cubic interpolation on a padded grid.

    Two pieces of data are required for piecewise cubic interpolation
    - four control points `[p0, p1, p2, p3]`
    - the interpolation coordinate

    The interpolation coordinate is a value in the range [0, 1] telling us how far into
    the interval `[p1, p2]` a query point is.

    This function returns the indices of the control points and the interpolation
    coordinate for a 1D grid. Returning the indices rather than the control points
    makes this more flexible for use in multidimensional grid interpolation which
    requires reusing and combining control point indices across dimensions.

    Interpolants outside [0, 1] are clamped to that range. Both returned tensors
    are created on the same device as `interpolants`.

    Parameters
    ----------
    interpolants: torch.Tensor
        `(b, )` batch of values in range [0, 1] covering the dimension being
        interpolated.
    n_samples: int
        The number of samples on the grid being interpolated (prior to padding).

    Returns
    -------
    control_point_idx, interpolation_coordinate: Tuple[torch.Tensor, torch.Tensor]
        The `(b, 4)` indices of control points `[p0, p1, p2, p3]` on a padded 1D
        grid and the `(b, )` interpolation coordinate associated with the interval
        `[p1, p2]`.
    """
    interpolants = torch.clamp(interpolants, min=0, max=1)
    device = interpolants.device
    if n_samples > 1:
        grid_u = generate_sample_positions_for_padded_grid_1d(n_samples, device=device)
        control_point_idx = find_control_point_idx_1d(
            sample_positions=grid_u, query_points=interpolants
        )
        # position of p1 for every query point, then the fractional distance
        # travelled from p1 towards p2 in units of the sample spacing
        u_p1 = grid_u[control_point_idx[:, 1]]
        du = 1 / (n_samples - 1)
        interpolation_coordinate = (interpolants - u_p1) / du
    else:
        # a single sample is padded out to four control points, so every
        # query uses indices [0, 1, 2, 3] and sits at the middle of [p1, p2]
        n_queries = len(interpolants)
        control_point_idx = torch.arange(4, device=device).repeat(n_queries, 1)
        interpolation_coordinate = torch.full((n_queries,), 0.5, device=device)
    return control_point_idx, interpolation_coordinate
125 |
126 |
def coerce_to_multichannel_grid(grid: torch.Tensor, grid_ndim: int) -> torch.Tensor:
    """If missing, add a channel dimension to a multidimensional grid.

    e.g. for a 2D (h, w) grid
          `h w -> 1 h w`
        `c h w -> c h w`

    Raises
    ------
    ValueError
        If `grid` has neither `grid_ndim` nor `grid_ndim + 1` dimensions.
    """
    # already multichannel: return the input untouched
    if grid.ndim == grid_ndim + 1:
        return grid
    # single channel: prepend the channel dimension
    if grid.ndim == grid_ndim:
        return einops.rearrange(grid, '... -> 1 ...')
    raise ValueError(f'expected a {grid_ndim}D grid, got {grid.ndim}')
141 |
142 |
MonotonicityType = Literal['increasing', 'decreasing']


def transform_to_monotonic_nd(
    tensor: torch.Tensor, ndims: int, monotonicity: MonotonicityType
) -> torch.Tensor:
    """Transform tensor values, so they are monotonic across dimensions.

    A cumulative maximum ('increasing') or cumulative minimum ('decreasing') is
    applied along each of the last `ndims` dimensions in turn, starting with
    the last dimension.

    Parameters
    ----------
    tensor: torch.Tensor
        a tensor of arbitrary shape.
    ndims: int
        the number of dimensions, counting from the last to the first, along
        which elements should be monotonic.
    monotonicity: str
        Either 'decreasing' or 'increasing'. An empty string means no
        monotonicity is requested and the tensor is returned unchanged.

    Returns
    -------
    tensor: torch.Tensor
        a tensor, with elements monotonic for the last `ndims` dimensions.

    Raises
    ------
    ValueError
        If `monotonicity` is not 'increasing', 'decreasing' or ''.
    """
    # an empty string is accepted as "no constraint"; return early rather than
    # reaching the loop below without a cumulative function selected
    if monotonicity == '':
        return tensor

    monotonicity_function: Callable
    if monotonicity == 'increasing':
        monotonicity_function = torch.cummax
    elif monotonicity == 'decreasing':
        monotonicity_function = torch.cummin
    else:
        raise ValueError(f'Unsupported monotonicity type "{monotonicity}" specified.')

    for dim in range(1, ndims + 1):
        # cummax/cummin return (values, indices); only the values are needed
        tensor, _ = monotonicity_function(tensor, dim=-dim)

    return tensor
179 |
180 |
def batch(iterable: Sequence, n: int = 1) -> Iterable[Iterable]:
    """Yield consecutive slices of `iterable` holding at most `n` items each.

    Every slice has length `n` except possibly the last, which holds whatever
    remains.
    """
    total = len(iterable)
    starts = range(0, total, n)
    yield from (iterable[start : min(start + n, total)] for start in starts)
186 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/teamtomo/torch-cubic-spline-grids/9894cef28da6ae8055a956d83bcdee522907c672/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_grid_optimisation.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 | from torch_cubic_spline_grids import (
4 | CubicBSplineGrid1d,
5 | CubicBSplineGrid2d,
6 | CubicBSplineGrid3d,
7 | CubicBSplineGrid4d,
8 | )
9 |
10 |
def test_1d_grid_optimisation():
    """A 1D B-spline grid should learn one period of a sine from noisy samples.

    The grid is fitted with SGD on a mean absolute error loss and must then
    match the noise-free function with a mean absolute error below 0.02.
    """
    batch_size = 100
    grid = CubicBSplineGrid1d(resolution=6, n_channels=1)

    def f(x: torch.Tensor, add_noise: bool = False):
        values = torch.sin(x * 2 * torch.pi)
        if add_noise:
            values += torch.normal(mean=torch.zeros(len(values)), std=0.3)
        return values

    optimiser = torch.optim.SGD(params=grid.parameters(), lr=0.1)
    for _ in range(5000):
        samples = torch.rand(size=(batch_size,))
        noisy_targets = f(samples, add_noise=True)
        fitted = grid(samples).squeeze()
        loss = (fitted - noisy_targets).abs().mean()
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

    # compare against the noise-free function on evenly spaced points
    x = torch.linspace(0, 1, steps=100)
    ground_truth = f(x)
    error = (grid(x).squeeze() - ground_truth).abs().mean()
    assert error.item() < 0.02
37 |
38 |
def test_1d_grid_optimization_decreasing():
    """A grid built with monotonicity='decreasing' must stay non-increasing.

    The grid is fitted to noisy samples of exp(-5x); afterwards, consecutive
    predictions on evenly spaced points may not rise by more than 1e-5.
    """
    batch_size = 100
    grid = CubicBSplineGrid1d(resolution=8, n_channels=1, monotonicity='decreasing')

    def f(x: torch.Tensor, add_noise: bool = False):
        values = torch.exp(-5 * x)
        if add_noise:
            values += torch.normal(mean=torch.zeros(len(values)), std=0.4)
        return values

    optimiser = torch.optim.SGD(params=grid.parameters(), lr=0.1)
    for _ in range(5000):
        samples = torch.rand(size=(batch_size,))
        noisy_targets = f(samples, add_noise=True)
        fitted = grid(samples).squeeze()
        loss = (fitted - noisy_targets).abs().mean()
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

    prediction = grid(torch.linspace(0, 1, steps=100)).squeeze()

    # small tolerance so numerical noise is not reported as an increase
    tolerance = torch.tensor(1e-5, dtype=prediction.dtype)
    steps = torch.diff(prediction, dim=-1)
    assert (steps <= tolerance).all().item()
68 |
69 |
def test_2d_grid_optimisation():
    """A 2D B-spline grid should learn the distance from the point (0.5, 0.5).

    The grid is fitted with SGD on a mean squared error loss using noisy
    samples and is then compared with the noise-free function on a 100 x 100
    lattice.
    """
    batch_size = 100
    grid = CubicBSplineGrid2d(resolution=(3, 3), n_channels=1)

    def f(x: torch.Tensor, add_noise: bool = False):
        # (x**2 + y**2) ** 0.5 around the centre of the unit square
        radius = (x - 0.5).pow(2).sum(dim=-1).sqrt()
        if add_noise:
            radius += torch.normal(mean=torch.zeros(len(radius)), std=0.3)
        return radius

    optimiser = torch.optim.SGD(params=grid.parameters(), lr=0.3)
    for _ in range(1000):
        samples = torch.rand(size=(batch_size, 2))
        noisy_targets = f(samples, add_noise=True)
        fitted = grid(samples).squeeze()
        loss = (fitted - noisy_targets).pow(2).mean()
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

    axis = torch.linspace(0, 1, steps=100)
    mesh = torch.meshgrid(axis, axis, indexing='xy')
    x = einops.rearrange(list(mesh), 'xy h w -> (h w) xy')
    ground_truth = f(x)
    error = (grid(x).squeeze() - ground_truth).abs().mean()
    assert error.item() < 0.02
99 |
100 |
def test_3d_grid_optimisation():
    """A 3D B-spline grid should learn the distance from (0.5, 0.5, 0.5).

    The grid is fitted with SGD on a mean squared error loss using noisy
    samples and is then compared with the noise-free function on a
    100 x 100 x 100 lattice.
    """
    batch_size = 1000
    grid = CubicBSplineGrid3d(resolution=(3, 3, 3), n_channels=1)

    def f(x: torch.Tensor, add_noise: bool = False):
        # (x**2 + y**2 + z**2) ** 0.5 around the centre of the unit cube
        radius = (x - 0.5).pow(2).sum(dim=-1).sqrt()
        if add_noise:
            radius += torch.normal(mean=torch.zeros(len(radius)), std=0.3)
        return radius

    optimiser = torch.optim.SGD(params=grid.parameters(), lr=0.3)
    for _ in range(1000):
        samples = torch.rand(size=(batch_size, 3))
        noisy_targets = f(samples, add_noise=True)
        fitted = grid(samples).squeeze()
        loss = (fitted - noisy_targets).pow(2).mean()
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

    axis = torch.linspace(0, 1, steps=100)
    mesh = torch.meshgrid(axis, axis, axis, indexing='xy')
    x = einops.rearrange(list(mesh), 'xyz d h w -> (d h w) xyz')
    ground_truth = f(x)
    error = (grid(x).squeeze() - ground_truth).abs().mean()
    assert error.item() < 0.02
130 |
131 |
def test_4d_grid_optimisation():
    """A 4D B-spline grid should learn the distance from the hypercube centre.

    The grid is fitted with SGD on a mean squared error loss using noisy
    samples and is then compared with the noise-free function on a lattice of
    10 points per dimension.
    """
    batch_size = 1000
    grid = CubicBSplineGrid4d(resolution=(3, 3, 3, 3), n_channels=1)

    def f(x: torch.Tensor, add_noise: bool = False):
        radius = (x - 0.5).pow(2).sum(dim=-1).sqrt()
        if add_noise:
            radius += torch.normal(mean=torch.zeros(len(radius)), std=0.3)
        return radius

    optimiser = torch.optim.SGD(params=grid.parameters(), lr=0.9)
    for _ in range(1000):
        samples = torch.rand(size=(batch_size, 4))
        noisy_targets = f(samples, add_noise=True)
        fitted = grid(samples).squeeze()
        loss = (fitted - noisy_targets).pow(2).mean()
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

    axis = torch.linspace(0, 1, steps=10)
    mesh = torch.meshgrid(axis, axis, axis, axis, indexing='xy')
    x = einops.rearrange(list(mesh), 'xyz u d h w -> (u d h w) xyz')
    ground_truth = f(x)
    error = (grid(x).squeeze() - ground_truth).abs().mean()
    assert error.item() < 0.02
161 |
--------------------------------------------------------------------------------
/tests/test_grids.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from torch_cubic_spline_grids import (
5 | CubicBSplineGrid1d,
6 | CubicBSplineGrid2d,
7 | CubicBSplineGrid3d,
8 | CubicBSplineGrid4d,
9 | CubicCatmullRomGrid1d,
10 | CubicCatmullRomGrid2d,
11 | CubicCatmullRomGrid3d,
12 | CubicCatmullRomGrid4d,
13 | )
14 |
15 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid1d, CubicCatmullRomGrid1d])
def test_1d_grid_direct_instantiation(grid_cls):
    """Instantiate 1D grids with default, int and tuple `resolution` arguments."""
    default_grid = grid_cls()
    assert isinstance(default_grid, grid_cls)
    assert default_grid.data.shape == (1, 2)

    # an int and a 1-tuple resolution must give the same data shape
    for resolution in (5, (5,)):
        grid = grid_cls(resolution=resolution, n_channels=3)
        assert isinstance(grid, grid_cls)
        assert grid.data.shape == (3, 5)
32 |
33 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid1d, CubicCatmullRomGrid1d])
def test_1d_grid_instantiation_from_existing_data(grid_cls):
    """Build a 1D grid from a (3, 5) tensor and check the derived attributes."""
    existing = torch.zeros(3, 5)
    grid = grid_cls.from_grid_data(data=existing)
    assert isinstance(grid._data, torch.nn.Parameter)
    assert grid.n_channels == 3
    assert grid.resolution == (5,)
    assert grid.ndim == 1
44 |
45 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid1d, CubicCatmullRomGrid1d])
def test_calling_1d_grid(grid_cls):
    """A default 1D grid accepts a float, a list or a tensor and returns zero."""
    grid = grid_cls()
    zero = torch.tensor([0.])
    queries = (0.5, [0.5], torch.tensor([0.5]))
    for query in queries:
        assert torch.allclose(grid(query), zero)
56 |
57 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid1d, CubicCatmullRomGrid1d])
def test_1d_grid_with_singleton_dimension(grid_cls):
    """Test that a 1D grid with a resolution of 1 can be used."""
    single_point_grid = grid_cls(resolution=1)
    value = single_point_grid(0.5)
    assert torch.allclose(value, torch.tensor([0.0]))
67 |
68 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid1d, CubicCatmullRomGrid1d])
def test_calling_1d_grid_with_stacked_coords(grid_cls):
    """Test calling a 1d grid with a multidimensional array of coordinates.

    The coordinates are passed first without and then with an explicit trailing
    coordinate dimension; in both cases the output keeps the input shape.
    """
    grid = grid_cls(resolution=1)
    h, w = 4, 4
    zero = torch.zeros(1)

    for coords_shape in ((h, w), (h, w, 1)):
        result = grid(torch.rand(size=coords_shape))
        assert result.shape == coords_shape
        assert torch.allclose(result, zero)
86 |
87 |
def test_interpolation_matrix_device():
    """The interpolation matrix should follow the Module to a new device."""
    cpu, meta = torch.device('cpu'), torch.device('meta')
    grid = CubicBSplineGrid1d(resolution=3)
    assert grid.interpolation_matrix.device == cpu
    grid.to(meta)
    assert grid.interpolation_matrix.device == meta
94 |
95 |
def test_grid_device():
    """The grid data should follow the Module to a new device."""
    cpu, meta = torch.device('cpu'), torch.device('meta')
    grid = CubicBSplineGrid1d(resolution=3)
    assert grid.data.device == cpu
    grid.to(meta)
    assert grid.data.device == meta
102 |
103 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid2d, CubicCatmullRomGrid2d])
def test_2d_grid_direct_instantiation(grid_cls):
    """Instantiate 2D grids with default and explicit arguments."""
    default_grid = grid_cls()
    assert isinstance(default_grid, grid_cls)
    assert default_grid.data.shape == (1, 2, 2)

    custom_grid = grid_cls(resolution=(5, 4), n_channels=3)
    assert isinstance(custom_grid, grid_cls)
    assert custom_grid.data.shape == (3, 5, 4)
115 |
116 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid2d, CubicCatmullRomGrid2d])
def test_2d_grid_instantiation_from_existing_data(grid_cls):
    """Build a 2D grid from a (3, 5, 4) tensor and check the derived attributes."""
    existing = torch.zeros(3, 5, 4)
    grid = grid_cls.from_grid_data(data=existing)
    assert isinstance(grid._data, torch.nn.Parameter)
    assert grid.n_channels == 3
    assert grid.resolution == (5, 4)
    assert grid.ndim == 2
127 |
128 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid2d, CubicCatmullRomGrid2d])
def test_calling_2d_grid(grid_cls):
    """A default 2D grid accepts a list or a tensor and returns zero."""
    grid = grid_cls()
    zeros = torch.zeros(2)
    queries = ([0.5, 0.5], torch.tensor([0.5, 0.5]))
    for query in queries:
        assert torch.allclose(grid(query), zeros)
139 |
140 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid2d, CubicCatmullRomGrid2d])
def test_2d_grid_with_singleton_dimension(grid_cls):
    """Test that a 2D grid with a singleton dimension can be used."""
    # singleton in the width dim first, then in the height dim
    for resolution in ((2, 1), (1, 2)):
        grid = grid_cls(resolution=resolution)
        result = grid([0.5, 0.5])
        assert torch.allclose(result, torch.tensor([0.0, 0.0]))
155 |
156 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid2d, CubicCatmullRomGrid2d])
def test_calling_2d_grid_with_stacked_coordinates(grid_cls):
    """A (5, 5, 2) stack of coordinates should yield a (5, 5, n_channels) output."""
    for n_channels in (1, 2):
        grid = grid_cls(resolution=(2, 2), n_channels=n_channels)
        result = grid(torch.rand(size=(5, 5, 2)))
        assert result.shape == (5, 5, n_channels)
169 |
170 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid3d, CubicCatmullRomGrid3d])
def test_3d_grid_direct_instantiation(grid_cls):
    """Instantiate 3D grids with default and explicit arguments."""
    default_grid = grid_cls()
    assert isinstance(default_grid, grid_cls)
    assert default_grid.data.shape == (1, 2, 2, 2)

    custom_grid = grid_cls(resolution=(5, 4, 3), n_channels=2)
    assert isinstance(custom_grid, grid_cls)
    assert custom_grid.data.shape == (2, 5, 4, 3)
182 |
183 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid3d, CubicCatmullRomGrid3d])
def test_3d_grid_instantiation_from_existing_data(grid_cls):
    """Build a 3D grid from a (2, 5, 4, 3) tensor and check the derived attributes."""
    existing = torch.zeros(2, 5, 4, 3)
    grid = grid_cls.from_grid_data(data=existing)
    assert isinstance(grid._data, torch.nn.Parameter)
    assert grid.n_channels == 2
    assert grid.resolution == (5, 4, 3)
    assert grid.ndim == 3
194 |
195 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid3d, CubicCatmullRomGrid3d])
def test_calling_3d_grid(grid_cls):
    """A default 3D grid accepts a list or a tensor and returns zero."""
    grid = grid_cls()
    zeros = torch.zeros(3)
    queries = ([0.5, 0.5, 0.5], torch.tensor([0.5, 0.5, 0.5]))
    for query in queries:
        assert torch.allclose(grid(query), zeros)
206 |
207 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid3d, CubicCatmullRomGrid3d])
def test_calling_3d_grid_with_stacked_coordinates(grid_cls):
    """A (d, h, w, 3) stack of coordinates should yield a (d, h, w, 1) output."""
    d, h, w = 4, 4, 4
    grid = grid_cls()
    coords = torch.rand(size=(d, h, w, 3))
    assert grid(coords).shape == (d, h, w, 1)
217 |
218 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid3d, CubicCatmullRomGrid3d])
def test_3d_grid_with_singleton_dimension(grid_cls):
    """Test that a 3D grid with a singleton dimension can be used."""
    # singleton in the width, height and depth dim, in that order
    resolutions = ((2, 2, 1), (2, 1, 2), (1, 2, 2))
    for resolution in resolutions:
        grid = grid_cls(resolution=resolution)
        result = grid([0.5, 0.5, 0.5])
        assert torch.allclose(result, torch.tensor([0.0, 0.0, 0.0]))
238 |
239 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid4d, CubicCatmullRomGrid4d])
def test_4d_grid_direct_instantiation(grid_cls):
    """Instantiate 4D grids with default and explicit arguments."""
    default_grid = grid_cls()
    assert isinstance(default_grid, grid_cls)
    assert default_grid.data.shape == (1, 2, 2, 2, 2)

    custom_grid = grid_cls(resolution=(6, 5, 4, 3), n_channels=2)
    assert isinstance(custom_grid, grid_cls)
    assert custom_grid.data.shape == (2, 6, 5, 4, 3)
251 |
252 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid4d, CubicCatmullRomGrid4d])
def test_4d_grid_instantiation_from_existing_data(grid_cls):
    """Build a 4D grid from a (2, 6, 5, 4, 3) tensor and check the derived attributes."""
    existing = torch.zeros(2, 6, 5, 4, 3)
    grid = grid_cls.from_grid_data(data=existing)
    assert isinstance(grid._data, torch.nn.Parameter)
    assert grid.n_channels == 2
    assert grid.resolution == (6, 5, 4, 3)
    assert grid.ndim == 4
263 |
264 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid4d, CubicCatmullRomGrid4d])
def test_calling_4d_grid(grid_cls):
    """A default 4D grid accepts a list or a tensor and returns zero."""
    grid = grid_cls()
    zeros = torch.zeros(4)
    queries = ([0.5, 0.5, 0.5, 0.5], torch.tensor([0.5, 0.5, 0.5, 0.5]))
    for query in queries:
        assert torch.allclose(grid(query), zeros)
275 |
276 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid4d, CubicCatmullRomGrid4d])
def test_calling_4d_grid_with_stacked_coordinates(grid_cls):
    """A (t, d, h, w, 4) stack of coordinates should yield a (t, d, h, w, 1) output."""
    t, d, h, w = 2, 4, 4, 4
    grid = grid_cls()
    coords = torch.rand(size=(t, d, h, w, 4))
    assert grid(coords).shape == (t, d, h, w, 1)
286 |
287 |
@pytest.mark.parametrize('grid_cls', [CubicBSplineGrid4d, CubicCatmullRomGrid4d])
def test_4d_grid_with_singleton_dimension(grid_cls):
    """Test that a 4D grid with a singleton dimension can be used."""
    resolutions = (
        (2, 2, 2, 1),  # singleton in width dim
        (2, 2, 1, 2),  # singleton in height dim
        (2, 1, 2, 2),  # singleton in depth dim
        (1, 2, 2, 2),  # singleton in time dim
        (1, 1, 1, 1),  # multiple singletons
    )
    for resolution in resolutions:
        grid = grid_cls(resolution=resolution)
        result = grid([0.5, 0.5, 0.5, 0.5])
        assert torch.allclose(result, torch.tensor([0.0, 0.0, 0.0, 0.0]))
317 |
--------------------------------------------------------------------------------
/tests/test_interpolate_grid.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_cubic_spline_grids import interpolate_grids
4 | from torch_cubic_spline_grids._constants import CUBIC_B_SPLINE_MATRIX
5 |
6 |
def test_interpolate_grid_1d():
    """Interpolating the ramp 0..5 at coordinate 0.5 should give 2.5."""
    ramp = torch.arange(6).float()  # [0, 1, 2, 3, 4, 5]
    query = torch.tensor([[0.5]])  # shape (1, 1)
    result = interpolate_grids.interpolate_grid_1d(
        ramp, query, matrix=CUBIC_B_SPLINE_MATRIX
    )
    assert torch.allclose(result, torch.tensor([[2.5]]))
16 |
17 |
def test_interpolate_grid_1d_approx():
    """A 50 point grid sampled from a sine should approximate the sine.

    The mean absolute error over 1000 evenly spaced query points must not
    exceed 0.01.
    """
    grid_y = torch.sin(torch.linspace(0, 2 * torch.pi, steps=50))
    sample_x = torch.linspace(0, 1, steps=1000).unsqueeze(-1)  # (1000, 1)
    sample_y = interpolate_grids.interpolate_grid_1d(
        grid_y, sample_x, matrix=CUBIC_B_SPLINE_MATRIX
    )
    ground_truth_y = torch.sin(sample_x * 2 * torch.pi)
    mean_absolute_error = (sample_y - ground_truth_y).abs().mean()
    assert mean_absolute_error <= 0.01
29 |
30 |
def test_interpolate_grid_2d():
    """Interpolating a 4x4 ramp at its centre should give 7.5."""
    # values 0..15 laid out row by row
    ramp = torch.arange(16).float().view(4, 4)
    query = torch.tensor([[0.5, 0.5]])  # shape (1, 2)
    result = interpolate_grids.interpolate_grid_2d(
        ramp, query, matrix=CUBIC_B_SPLINE_MATRIX
    )
    assert torch.allclose(result, torch.tensor([7.5]))
45 |
46 |
def test_interpolate_grid_3d():
    """Interpolating a 4x4x4 ramp at its centre should give 31.5."""
    # values 0..63, with the last dimension varying fastest
    ramp = torch.arange(64).float().view(4, 4, 4)
    query = torch.tensor([[0.5, 0.5, 0.5]])  # shape (1, 3)
    result = interpolate_grids.interpolate_grid_3d(
        ramp, query, matrix=CUBIC_B_SPLINE_MATRIX
    )
    assert torch.allclose(result, torch.tensor([31.5]))
72 |
73 |
def test_interpolate_grid_4d():
    """Interpolating a 4x4x4x4 ramp at its centre should give 127.5."""
    # values 0..255, with the last dimension varying fastest
    ramp = torch.arange(256).float().view(4, 4, 4, 4)
    query = torch.tensor([[0.5, 0.5, 0.5, 0.5]])  # shape (1, 4)
    result = interpolate_grids.interpolate_grid_4d(
        ramp, query, matrix=CUBIC_B_SPLINE_MATRIX
    )
    assert torch.allclose(result, torch.tensor([127.5]))
148 |
--------------------------------------------------------------------------------
/tests/test_interpolate_pieces.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_cubic_spline_grids import interpolate_pieces
4 | from torch_cubic_spline_grids._constants import CUBIC_B_SPLINE_MATRIX
5 |
6 |
def test_interpolate_pieces_1d():
    """Compare 1D piece interpolation with the explicit matrix product.

    The reference value is `[1, t, t^2, t^3] @ matrix @ control_points`.
    """
    t = 0.5
    points = torch.tensor([-1.5, 5.1, 2.2, 6.8])
    powers_of_t = torch.tensor([t ** exponent for exponent in range(4)])
    expected = powers_of_t @ CUBIC_B_SPLINE_MATRIX @ points

    result = interpolate_pieces.interpolate_pieces_1d(
        control_points=points.view(1, 1, 4),  # (b, c, 4)
        t=torch.tensor([t]),
        matrix=CUBIC_B_SPLINE_MATRIX,
    )
    assert torch.allclose(result, expected)
18 |
19 |
def test_interpolate_pieces_1d_with_batched_queries():
    """Evaluate one piece (set of 4 control points) at several positions."""
    pieces = torch.arange(4).float().view(1, 1, 4)  # (b, c, 4)
    t = torch.tensor([0, 0.5, 1])  # (b, )
    result = interpolate_pieces.interpolate_pieces_1d(
        pieces, t, matrix=CUBIC_B_SPLINE_MATRIX
    )

    # four collinear control points: cubic B-spline interpolation should be
    # equivalent to linear interpolation
    expected = torch.tensor([[1.0], [1.5], [2.0]])
    assert torch.allclose(result, expected)
32 |
33 |
def test_interpolate_pieces_1d_with_batched_pieces_and_queries():
    """Evaluate two pieces (sets of 4 control points), each at its own position."""
    pieces = torch.tensor(
        [[[0., 1., 2., 3.]],
         [[2., 3., 4., 5.]]]
    )  # (b, c, 4)
    t = torch.full((2,), 0.5)  # (b, )
    result = interpolate_pieces.interpolate_pieces_1d(
        pieces, t, matrix=CUBIC_B_SPLINE_MATRIX
    )  # (b, c)

    # four collinear control points: cubic B-spline interpolation should be
    # equivalent to linear interpolation
    expected = torch.tensor([[1.5], [3.5]])
    assert torch.allclose(result, expected)
49 |
50 |
def test_interpolate_pieces_2d():
    """Evaluate a 4x4 ramp of control points at t = (0.5, 0.5)."""
    # values 0..15 laid out row by row, with batch and channel dims of 1
    control_points = torch.arange(16).float().view(1, 1, 4, 4)
    t = torch.full((1, 2), 0.5)
    result = interpolate_pieces.interpolate_pieces_2d(
        control_points, t, matrix=CUBIC_B_SPLINE_MATRIX
    )
    assert torch.allclose(result, torch.tensor([7.5]))
65 |
66 |
def test_interpolate_pieces_3d():
    """Evaluate a 4x4x4 ramp of control points at t = (0.5, 0.5, 0.5)."""
    # values 0..63, last dimension varying fastest, batch and channel dims of 1
    control_points = torch.arange(64).float().view(1, 1, 4, 4, 4)
    t = torch.full((1, 3), 0.5)
    result = interpolate_pieces.interpolate_pieces_3d(
        control_points, t, matrix=CUBIC_B_SPLINE_MATRIX
    )
    assert torch.allclose(result, torch.tensor([31.5]))
93 |
94 |
def test_interpolate_pieces_4d():
    """Evaluate a 4x4x4x4 ramp of control points at t = (0.5, 0.5, 0.5, 0.5)."""
    # values 0..255, last dimension varying fastest, batch and channel dims of 1
    control_points = torch.arange(256).float().view(1, 1, 4, 4, 4, 4)
    t = torch.full((1, 4), 0.5)
    result = interpolate_pieces.interpolate_pieces_4d(
        control_points, t, matrix=CUBIC_B_SPLINE_MATRIX
    )
    assert torch.allclose(result, torch.tensor([127.5]))
169 |
--------------------------------------------------------------------------------
/tests/test_modules.py:
--------------------------------------------------------------------------------
1 | from torch_cubic_spline_grids import (
2 | CubicBSplineGrid1d,
3 | CubicBSplineGrid2d,
4 | CubicBSplineGrid3d,
5 | CubicBSplineGrid4d,
6 | )
7 |
8 |
def test_grid_class_instantiation():
    """Every B-spline grid class builds with defaults and exposes parameters."""
    for grid_class in (
        CubicBSplineGrid1d,
        CubicBSplineGrid2d,
        CubicBSplineGrid3d,
        CubicBSplineGrid4d,
    ):
        instance = grid_class()
        assert isinstance(instance, grid_class)
        n_parameters = len(list(instance.parameters()))
        assert n_parameters > 0
--------------------------------------------------------------------------------
/tests/test_pad_grid.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_cubic_spline_grids import pad_grids
4 |
5 |
def test_pad_1d():
    """Padding the ramp [0, 1, 2] should continue it by one value on each side."""
    ramp = torch.arange(3)
    expected = torch.arange(-1, 4)  # [-1, 0, 1, 2, 3]
    assert torch.allclose(pad_grids.pad_grid_1d(ramp), expected)
11 |
12 |
def test_pad_2d():
    """Padding a 2x2 ramp should continue it by one value on every side."""
    ramp = torch.arange(4).view(2, 2)  # [[0, 1], [2, 3]]
    expected = torch.tensor(
        [
            [-3, -2, -1, 0],
            [-1, 0, 1, 2],
            [1, 2, 3, 4],
            [3, 4, 5, 6],
        ]
    )
    assert torch.allclose(pad_grids.pad_grid_2d(ramp), expected)
26 |
27 |
def test_pad_3d():
    """Padding a 2x2x2 ramp should continue it by one value on every side.

    The input holds `4*d + 2*h + w` for indices in {0, 1}; the expected output
    holds the same expression for indices in {-1, 0, 1, 2}.
    """
    ramp = torch.arange(8).view(2, 2, 2)
    idx = torch.arange(-1, 3)
    expected = 4 * idx.view(4, 1, 1) + 2 * idx.view(1, 4, 1) + idx.view(1, 1, 4)
    assert torch.allclose(pad_grids.pad_grid_3d(ramp), expected)
58 |
59 |
def test_pad_4d():
    """Padding a 2x2x2x2 ramp should continue it by one value on every side.

    The input holds `8*t + 4*d + 2*h + w` for indices in {0, 1}; the expected
    output holds the same expression for indices in {-1, 0, 1, 2}.
    """
    ramp = torch.arange(16).view(2, 2, 2, 2)
    idx = torch.arange(-1, 3)
    expected = (
        8 * idx.view(4, 1, 1, 1)
        + 4 * idx.view(1, 4, 1, 1)
        + 2 * idx.view(1, 1, 4, 1)
        + idx.view(1, 1, 1, 4)
    )
    assert torch.allclose(pad_grids.pad_grid_4d(ramp), expected)
139 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 |
4 | from torch_cubic_spline_grids.utils import find_control_point_idx_1d, batch
5 |
6 |
def test_find_control_points():
    """Each query should map to the expected window of four control point indices."""
    grid_positions = torch.arange(7)  # positions 0..6

    # (query position, expected control point indices)
    cases = (
        # a sample between two points yields the four closest points
        (torch.tensor([2.5]), [1, 2, 3, 4]),
        # a sample exactly on a point treats that point as the lower bound
        # of its interval
        (torch.tensor([2]), [1, 2, 3, 4]),
        # 3 is the upper bound of that same interval, so it starts the next one
        (torch.tensor([3]), [2, 3, 4, 5]),
    )

    for query, window in cases:
        found = find_control_point_idx_1d(grid_positions, query)
        assert torch.allclose(found, torch.tensor([window]))
24 |
25 |
def test_batch():
    """Every item must appear, in order, across the minibatches of size 3."""
    items = list(range(7))
    chunks = list(batch(items, n=3))
    # 7 items in groups of 3 -> two full minibatches and one remainder
    assert chunks == [[0, 1, 2], [3, 4, 5], [6]]
32 |
33 |
def test_restacking_batch():
    """Concatenating the minibatches along dim 0 must reproduce the input."""
    original = torch.rand(size=(10, 3))  # (b, d)
    pieces = list(batch(original, n=3))
    rejoined = torch.cat(pieces, dim=0)
    assert torch.allclose(original, rejoined)
40 |
--------------------------------------------------------------------------------